capgrid-export /

Wrappers around gdal/ogr to read and write ESRI shapefiles

If you've built gdal from source and installed it to an
unusual location (say, your $HOME), and importing gdal in
python doesn't work, try (csh/tcsh syntax):
setenv LD_LIBRARY_PATH /path/to/pythondist/lib/

I could clean this up a bit, but gdal/ogr is just plain ugly.
I hate working with ogr's python bindings (python is _not_
C++... I don't want to write a crapload of boilerplate to do
simple things!), so I threw these classes together...

It doesn't really support complex polygons (i.e. anything 
with interior rings), but it mostly works...

Also, due to limitations of ogr, you can't really easily 
modify an existing shapefile, so it's usually easier to
just copy things to a new shapefile and make the changes
as you copy. 

Note that the new shapefile won't actually be written
to disk until its object is garbage collected (i.e.
when the script ends, or when it's deleted)


    # Calculates the azimuth of each line in a shapefile
    # and records it in a copy of the original shapefile.
    from math import atan2, degrees
    from shapefile import LineShp

    input_shp = LineShp('/path/to/file.shp')
    output_shp = LineShp('/path/to/newfile.shp', 'w')

    # Copy the fields (just the database structure, not the
    # actual values of the records) of input_shp to output_shp
    output_shp.copy_fields(input_shp)

    # Add a field called "azimuth" to output_shp that we'll store
    # our calculation in
    output_shp.add_field('azimuth', 'float')

    for line, record in input_shp:
        # "line" is a tuple of (x,y) coords
        x,y = line

        # Calculate azimuth
        azi = atan2(y[-1] - y[0], x[-1] - x[0])
        azi = degrees(azi)
        # Convert from 0=east (math) to 0=north (map)
        azi = 90 - azi
        if azi < 0:
            azi += 360

        # "record" is a dict of attribute data for the current feature
        record['azimuth'] = azi

        # Write the line and its attributes (record) to output_shp
        output_shp.add(x, y, record)


# Module metadata (license, author, copyright).
# NOTE(review): the angle brackets in __license__ are empty -- presumably a
# license URL belongs there; confirm against the upstream file.
__license__ = 'MIT License <>'
__author__ = 'Joe Kington'
__copyright__ = '2009, Free Software Foundation'

import os

# The GDAL/OGR python bindings are a hard requirement of this module; fail at
# import time with an explicit message when they are missing.
try:
    from osgeo import ogr, osr
except ImportError:
    raise ImportError("""The python bindings for GDAL do not appear to be installed!""")

class Shapefile(object):
    """Reads and writes shapefiles

    Example:
        shp = Shapefile('some/path.shp', type='line')
        for geom, record_dict in shp:
            print 'x = ', geom[0]
            print 'y = ', geom[1]
            print 'Attributes: ', record_dict

    Iterating over an instance yields ((x, y), record_dict) pairs, where x
    and y are lists of coordinates and record_dict maps field names to the
    attribute values of that feature.
    """

    # Geometry type name -> name of the matching ogr constant. The constants
    # are looked up by name so that an invalid type is rejected before ogr is
    # touched at all.
    _GEOMETRY_TYPES = {
        'line': 'wkbLineString',
        'point': 'wkbPoint',
        'poly': 'wkbPolygon',
        'linez': 'wkbLineString25D',
        'pointz': 'wkbPoint25D',
        'polyz': 'wkbPolygon25D',
    }

    # Field type name -> name of the matching ogr field-type constant.
    _FIELD_TYPES = {
        'string': 'OFTString',
        'integer': 'OFTInteger',
        'float': 'OFTReal',
    }

    def __init__(self, shpname, type, mode=None):
        """
        Input:
            shpname: the path to the shapefile
            type: One of <'point', 'line', 'poly'> (or the 2.5D variants
                'pointz', 'linez', 'polyz'). Case-insensitive.
            mode: (optional) Either 'r' or 'w' (for read and write,
                respectively) If mode is 'w', any existing shapefile will
                be removed and overwritten.
                Defaults to 'r' if shpname is an existing file, 'w'
                otherwise.
        Raises:
            ValueError: if type or mode is not one of the values above.
            IOError: if the shapefile could not be opened or created.
        """
        type = type.lower()
        try:
            ogr_name = self._GEOMETRY_TYPES[type]
        except KeyError:
            raise ValueError('Invalid type!')
        self._type = getattr(ogr, ogr_name)

        if mode is None:
            mode = 'r' if os.path.exists(shpname) else 'w'
        if mode not in ('r', 'w'):
            raise ValueError("Invalid mode! Must be 'r' or 'w'.")

        self._driver = ogr.GetDriverByName('ESRI Shapefile')
        if mode == 'w':
            if os.path.exists(shpname):
                # Write mode always starts from a clean slate.
                self._driver.DeleteDataSource(shpname)
            self._shp_datasource = self._driver.CreateDataSource(shpname)
        else:
            self._shp_datasource = self._driver.Open(shpname, 0)

        # Check the datasource *before* asking it for a layer; otherwise a
        # failed open surfaces as an AttributeError on None instead.
        if self._shp_datasource is None:
            raise IOError('Could not open shapefile!')

        if mode == 'w':
            # The layer is named after the file, minus directory and extension.
            name = os.path.basename(shpname).split('.')[0]
            self._layer = self._shp_datasource.CreateLayer(name, None, self._type)
        else:
            self._layer = self._shp_datasource.GetLayer(0)

    def __iter__(self):
        """Iterates through the ogr geometry objects in
        the shapefile and the associated records in the
        attribute table.

        Yields ((x, y), record_dict) for every feature, where x and y are
        lists of the coordinates of the feature's geometry. A feature
        without a geometry yields two empty lists.
        """
        # Rewind the layer's read cursor so that every iteration starts at
        # the first feature, then rely on sequential reads only.
        self._layer.ResetReading()
        feature = self._layer.GetNextFeature()
        while feature is not None:
            geom = feature.geometry()
            count = geom.GetPointCount() if geom is not None else 0
            x = [geom.GetX(i) for i in range(count)]
            y = [geom.GetY(i) for i in range(count)]
            yield (x, y), feature.items()
            feature = self._layer.GetNextFeature()

    def add_field(self, name, type, length=None):
        """Add a field to the shapefile. All fields must be added before
        adding any geometric features to the shapefile.
        Input:
            name: The name of the field
            type: The type of the field (one of 'string', 'integer', 'float')
            length: If type=='string', the width of the field. Ignored otherwise
        Raises:
            KeyError: if type is not one of the supported field types.
        """
        ogr_type = getattr(ogr, self._FIELD_TYPES[type.lower()])
        field_def = ogr.FieldDefn(name, ogr_type)
        # Compare against the ogr constant: the type *name* has already been
        # translated, so testing for the string 'string' could never match.
        if ogr_type == ogr.OFTString and length is not None:
            field_def.SetWidth(length)
        self._layer.CreateField(field_def)

    def add_fields(self, field_dict):
        """Add several fields to the shapefile. All fields must be
        added before adding any geometric features.
        Input:
            field_dict: A dict with the structure:
                field_dict = {
                    field_name_1 : {
                            type : <'string', 'integer', or 'float'>
                            length : (optional) (integer width of the field
                            }
                    field_name_2 : {
                            type : <'string', 'integer', or 'float'>
                            length : (optional) (integer width of the field
                            }
                ... etc
                }
        """
        # .items() works on both Python 2 and Python 3 dicts.
        for name, info in field_dict.items():
            self.add_field(name, info['type'], info.get('length', None))

    def copy_fields(self, shp):
        """Copy fields from an existing shapefile. Input may be a filename or a Shapefile object"""
        if not isinstance(shp, Shapefile):
            # Only the field definitions are read from the source file.
            shp = Shapefile(shp, 'line', 'r')
        for field_def in shp.field_definitions:
            self._layer.CreateField(field_def)

    def query(self, sql_string):
        """Query the attribute table of the shapefile with the SQL statement "sql_string"
        (May be buggy...)

        This is a generator: the query runs (and ValueError is raised on
        failure) when the first item is requested. Yields
        (ogr geometry, record_dict) for every feature in the result.
        """
        result_layer = self._shp_datasource.ExecuteSQL(sql_string)
        if result_layer is None:
            raise ValueError('SQL Query Failed. See stderr')
        try:
            feature = result_layer.GetNextFeature()
            while feature is not None:
                yield feature.geometry(), feature.items()
                feature = result_layer.GetNextFeature()
        finally:
            # Result layers returned by ExecuteSQL must be handed back to
            # the datasource once the caller is done with them.
            self._shp_datasource.ReleaseResultSet(result_layer)

    @property
    def _layer_def(self):
        """The layer definition object of the underlying layer."""
        return self._layer.GetLayerDefn()

    @property
    def num_fields(self):
        """The number of fields in the attribute table."""
        return self._layer_def.GetFieldCount()

    @property
    def field_definitions(self):
        """A generator over the field definition objects of the layer."""
        # Look the definition up once rather than once per field.
        layer_def = self._layer_def
        for i in range(layer_def.GetFieldCount()):
            yield layer_def.GetFieldDefn(i)

    @property
    def field_names(self):
        """A list of the names of the fields in the attribute table."""
        return [field_def.GetName() for field_def in self.field_definitions]

    def add(self, X, Y, record_dict, transform=None):
        """Add a feature to the shapefile
        Input:
            X: A list (or array) of x-values in the feature
            Y: A list (or array) of y-values in the feature
            record_dict: A dict containing field_name:value
                pairs for line
            transform: (optional) passed through to add_geometry
        """
        geometry = ogr.Geometry(self._type)
        try:
            points = list(zip(X, Y))
        except TypeError:
            # Only one point... (X and Y are scalars, not sequences)
            points = [(X, Y)]
        # NOTE(review): vertices are added directly to the geometry, which
        # suits point and line types; polygon types probably need their
        # vertices wrapped in a ring first -- confirm before relying on it.
        for x, y in points:
            geometry.AddPoint_2D(x, y)
        self.add_geometry(geometry, record_dict, transform)

    def add_geometry(self, geometry, record_dict, transform=None):
        """Add a feature to the shapefile from an ogr.Geometry object
        Input:
            geometry: An ogr.Geometry object
            record_dict: A dict containing field_name:value
                pairs for line
            transform: (optional) If given, it is applied to geometry
                (in place, via geometry.Transform) before the feature
                is written.
        """
        feature = ogr.Feature(self._layer_def)
        if transform is not None:
            geometry.Transform(transform)
        feature.SetGeometry(geometry)
        for key, value in record_dict.items():
            feature.SetField(key, value)
        self._layer.CreateFeature(feature)

class LineShp(Shapefile):
    """A Shapefile whose geometry type is fixed to 'line'."""

    def __init__(self, shpname, mode=None):
        """Forward shpname and mode to Shapefile with type 'line'."""
        super(LineShp, self).__init__(shpname, 'line', mode)

class PointShp(Shapefile):
    """A Shapefile whose geometry type is fixed to 'point'."""

    def __init__(self, shpname, mode=None):
        """Forward shpname and mode to Shapefile with type 'point'."""
        super(PointShp, self).__init__(shpname, 'point', mode)

class PolyShp(Shapefile):
    """A Shapefile whose geometry type is fixed to 'poly'."""

    def __init__(self, shpname, mode=None):
        """Forward shpname and mode to Shapefile with type 'poly'."""
        super(PolyShp, self).__init__(shpname, 'poly', mode)