# Source
#
# capgrid-export / shapefile.py

"""
Wrappers around gdal/ogr to read and write ESRI shapefiles.

If you've built gdal from source and installed it to an
unusual location (say, your $HOME), and importing gdal in
python doesn't work, try (csh/tcsh):
setenv LD_LIBRARY_PATH /path/to/pythondist/lib/
or (bash): export LD_LIBRARY_PATH=/path/to/pythondist/lib/

I could clean this up a bit, but gdal/ogr is just plain ugly.
I hate working with ogr's python bindings (python is _not_
C++... I don't want to write a crapload of boilerplate to do
simple things!), so I threw these classes together...

It doesn't really support complex polygons (i.e. anything 
with interior rings), but it mostly works...

Also, due to limitations of ogr, you can't really easily 
modify an existing shapefile, so it's usually easier to
just copy things to a new shapefile and make the changes
as you copy. 

Note that the new shapefile won't actually be written
to disk until its object is garbage collected (i.e.
when the script ends, or when it's deleted).

Example:

    # Calculates the azimuth of each line in a shapefile
    # and records it in a copy of the original shapefile.
    
    from math import atan2, degrees
    from shapefile import LineShp

    input_shp = LineShp('/path/to/file.shp')
    output_shp = LineShp('/path/to/newfile.shp', 'w')

    # Copy the fields (just the database structure, not the
    # actual values of the records) of input_shp to output_shp
    output_shp.copy_fields(input_shp)

    # Add a field called "azimuth" to output_shp that we'll store
    # our calculation in
    output_shp.add_field('azimuth', 'float')

    for line, record in input_shp:
        # "line" is a tuple of two lists: (x coords, y coords)
        x,y = line

        # Calculate azimuth
        azi = atan2(y[-1] - y[0], x[-1] - x[0])
        azi = degrees(azi)
        # Convert from 0=east (math) to 0=north (map)
        azi = 90 - azi
        if azi < 0:
            azi += 360

        # "record" is a dict of attribute data for the current feature
        record['azimuth'] = azi

        # Write the line and its attributes (record) to output_shp
        output_shp.add(x, y, record)

"""

__license__ = 'MIT License <http://www.opensource.org/licenses/mit-license.php>'
__author__ = 'Joe Kington'
__copyright__ = '2009, Free Software Foundation'

import os

try:
    from osgeo import ogr, osr
except ImportError:
    raise ImportError("""The python bindings for GDAL do not appear to be installed!""")

class Shapefile(object):
    """Reads and writes ESRI shapefiles through OGR.

    Iterating over a Shapefile yields ``((x, y), record_dict)`` for every
    feature, where ``x`` and ``y`` are lists of vertex coordinates and
    ``record_dict`` maps field names to the feature's attribute values.

    Example:
        shp = Shapefile('some/path.shp', type='line')
        for geom, record_dict in shp:
            print('x = %s' % (geom[0],))
            print('y = %s' % (geom[1],))
            print('Attributes: %s' % (record_dict,))
    """

    def __init__(self, shpname, type, mode=None):
        """
        Input:
            shpname: the path to the shapefile
            type: One of <'point', 'line', 'poly'> or the 2.5D variants
                <'pointz', 'linez', 'polyz'> (case-insensitive). It is the
                geometry type of a newly created layer and of the geometries
                built by add().
            mode: (optional) Either 'r' or 'w' (for read and write,
                respectively). If mode is 'w', any existing shapefile will
                be removed and overwritten.
                Defaults to 'r' if shpname is an existing file, 'w'
                otherwise.
        Raises:
            ValueError: if type or mode is not one of the accepted values.
            IOError: if the shapefile (or its layer) cannot be opened or
                created.
        """
        geometry_types = {
            'line': ogr.wkbLineString,
            'point': ogr.wkbPoint,
            'poly': ogr.wkbPolygon,
            'linez': ogr.wkbLineString25D,
            'pointz': ogr.wkbPoint25D,
            'polyz': ogr.wkbPolygon25D,
        }
        try:
            self._type = geometry_types[type.lower()]
        except KeyError:
            raise ValueError('Invalid type!')

        if mode is None:
            mode = 'r' if os.path.exists(shpname) else 'w'
        if mode not in ('r', 'w'):
            # Fail fast: any other mode would leave the object without a
            # datasource or layer.
            raise ValueError("mode must be 'r' or 'w', got %r" % (mode,))

        self._driver = ogr.GetDriverByName('ESRI Shapefile')
        self._layer = None
        # OGR signals failure by returning None, so each result is checked
        # before a method is called on it.
        if mode == 'w':
            if os.path.exists(shpname):
                self._driver.DeleteDataSource(shpname)
            self._shp_datasource = self._driver.CreateDataSource(shpname)
            if self._shp_datasource is not None:
                self._layer = self._shp_datasource.CreateLayer(
                    self._layer_name(shpname), None, self._type)
        else:
            self._shp_datasource = self._driver.Open(shpname, 0)
            if self._shp_datasource is not None:
                self._layer = self._shp_datasource.GetLayer(0)
        if self._shp_datasource is None or self._layer is None:
            raise IOError('Could not open shapefile!')

    @staticmethod
    def _layer_name(shpname):
        """Return the layer name for shpname: its file name minus extension.

        os.path is used so that Windows-style paths are handled too.
        """
        return os.path.splitext(os.path.basename(shpname))[0]

    @staticmethod
    def _vertices(geom):
        """Return ([x, ...], [y, ...]) for the vertices of an ogr geometry.

        A geometry with no vertices of its own but with sub-geometries (e.g.
        a polygon) is read through its first sub-geometry, which for a
        polygon is the exterior ring. A missing (None) geometry gives two
        empty lists.
        """
        if geom is None:
            return [], []
        if geom.GetPointCount() == 0 and geom.GetGeometryCount() > 0:
            geom = geom.GetGeometryRef(0)
        count = geom.GetPointCount()
        xs = [geom.GetX(i) for i in range(count)]
        ys = [geom.GetY(i) for i in range(count)]
        return xs, ys

    def __iter__(self):
        """Iterate over the features of the shapefile.

        Yields ((x, y), record_dict) per feature: x and y are lists of
        vertex coordinates, record_dict is the feature's attribute dict.
        For polygons only the exterior ring is read.
        """
        # OGR documents GetFeature() as interrupting sequential reads, so the
        # reader is rewound and walked strictly with GetNextFeature(). This
        # also makes repeated iteration over the same object work.
        self._layer.ResetReading()
        feature = self._layer.GetNextFeature()
        while feature is not None:
            x, y = self._vertices(feature.geometry())
            yield (x, y), feature.items()
            feature.Destroy()
            feature = self._layer.GetNextFeature()

    def add_field(self, name, type, length=None):
        """Add a field to the shapefile. All fields must be added before
        adding any geometric features to the shapefile.
        Input:
            name: The name of the field
            type: The type of the field (one of 'string', 'integer', 'float',
                case-insensitive)
            length: If type=='string', the width of the field. Ignored
                otherwise, and ignored if None.
        Raises:
            KeyError: if type is not one of the names above.
        """
        field_types = {
            'string': ogr.OFTString,
            'integer': ogr.OFTInteger,
            'float': ogr.OFTReal,
        }
        type_name = type.lower()
        field_def = ogr.FieldDefn(name, field_types[type_name])
        # Compare the type *name*, not the OGR constant, so string widths
        # are actually applied.
        if type_name == 'string' and length is not None:
            field_def.SetWidth(length)
        self._layer.CreateField(field_def)

    def add_fields(self, field_dict):
        """Add several fields to the shapefile. All fields must be
        added before adding any geometric features.
        Input:
            field_dict: A dict with the structure:
                field_dict = {
                    field_name_1 : {
                            type : <'string', 'integer', or 'float'>
                            length : (optional) integer width of the field
                        },
                    field_name_2 : {
                            type : <'string', 'integer', or 'float'>
                            length : (optional) integer width of the field
                        },
                ... etc
                }
        """
        for name, info in field_dict.items():
            self.add_field(name, info['type'], info.get('length', None))

    def copy_fields(self, shp):
        """Copy field definitions (not values) from an existing shapefile.

        Input may be a filename or a Shapefile object. A filename is opened
        read-only; the geometry type passed for it is irrelevant for reading.
        """
        if not isinstance(shp, Shapefile):
            shp = Shapefile(shp, 'line', 'r')
        for field_def in shp.field_definitions:
            self._layer.CreateField(field_def)

    def query(self, sql_string):
        """Query the attribute table of the shapefile with the SQL statement
        "sql_string", yielding (geometry, record_dict) per result feature.

        The geometries are clones, so they stay valid after iteration ends.
        Raises ValueError (on first iteration) if OGR rejects the query.
        (May be buggy...)
        """
        result_layer = self._shp_datasource.ExecuteSQL(sql_string)
        if result_layer is None:
            raise ValueError('SQL Query Failed. See stderr')
        try:
            feature = result_layer.GetNextFeature()
            while feature is not None:
                geom = feature.geometry()
                # The geometry belongs to the feature and the feature to the
                # result set released below; hand out an independent copy.
                if geom is not None:
                    geom = geom.Clone()
                yield geom, feature.items()
                feature = result_layer.GetNextFeature()
        finally:
            # Result layers from ExecuteSQL must be released explicitly.
            self._shp_datasource.ReleaseResultSet(result_layer)

    @property
    def _layer_def(self):
        return self._layer.GetLayerDefn()

    @property
    def num_fields(self):
        return self._layer_def.GetFieldCount()

    @property
    def field_definitions(self):
        for i in range(self.num_fields):
            yield self._layer_def.GetFieldDefn(i)

    @property
    def field_names(self):
        return [field_def.GetFieldName() for field_def in self.field_definitions]

    def add(self, X, Y, record_dict, transform=None):
        """Add a feature to the shapefile
        Input:
            X: A list (or array) of x-values in the feature, or a single
                x-value for a one-point feature
            Y: A list (or array) of y-values in the feature, or a single
                y-value for a one-point feature
            record_dict: A dict containing field_name:value
                pairs for the feature
            transform: (optional) passed through to add_geometry()
        Polygon vertices are written as a single (closed) exterior ring.
        """
        geometry = ogr.Geometry(self._type)
        is_polygon = self._type in (ogr.wkbPolygon, ogr.wkbPolygon25D)
        # Polygons cannot hold vertices directly; they go into a ring.
        target = ogr.Geometry(ogr.wkbLinearRing) if is_polygon else geometry
        try:
            points = list(zip(X, Y))
        except TypeError:
            # Only one point...
            points = [(X, Y)]
        for x, y in points:
            target.AddPoint(x, y)
        if is_polygon:
            geometry.AddGeometry(target)
            geometry.CloseRings()
        self.add_geometry(geometry, record_dict, transform)

    def add_geometry(self, geometry, record_dict, transform=None):
        """Add a feature to the shapefile from an ogr.Geometry object
        Input:
            geometry: An ogr.Geometry matching the shapefile's geometry type
            record_dict: A dict containing field_name:value
                pairs for the feature
            transform: (optional) passed to geometry.Transform(), which
                modifies geometry in place (presumably an
                osr.CoordinateTransformation — confirm with callers)
        """
        feature = ogr.Feature(self._layer_def)
        try:
            if transform is not None:
                geometry.Transform(transform)
            feature.SetGeometry(geometry)
            for key, value in record_dict.items():
                feature.SetField(key, value)
            self._layer.CreateFeature(feature)
        finally:
            feature.Destroy()

class LineShp(Shapefile):
    """Shapefile whose geometry type is fixed to 2D line strings."""

    def __init__(self, shpname, mode=None):
        # Delegate with the geometry type pinned to 'line'.
        Shapefile.__init__(self, shpname, type='line', mode=mode)

class PointShp(Shapefile):
    """Shapefile whose geometry type is fixed to 2D points."""

    def __init__(self, shpname, mode=None):
        # Delegate with the geometry type pinned to 'point'.
        Shapefile.__init__(self, shpname, type='point', mode=mode)

class PolyShp(Shapefile):
    """Shapefile whose geometry type is fixed to 2D polygons."""

    def __init__(self, shpname, mode=None):
        # Delegate with the geometry type pinned to 'poly'.
        Shapefile.__init__(self, shpname, type='poly', mode=mode)