Commits

Mike Bayer committed 0ba8b4f

- [feature] postgresql.ARRAY features an optional
"dimensions" argument, which assigns a fixed
number of dimensions to the array; this renders
in DDL as ARRAY[][]... and also improves the
performance of bind/result processing.
[ticket:2441]
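
  For illustration (not part of the commit), a minimal usage sketch of the new
  argument, with hypothetical table and column names:

      from sqlalchemy import Column, Integer, MetaData, Table
      from sqlalchemy.dialects import postgresql

      metadata = MetaData()
      # dimensions=2 makes the column render as INTEGER[][] in CREATE TABLE DDL
      grid = Table('grid', metadata,
          Column('id', Integer, primary_key=True),
          Column('matrix', postgresql.ARRAY(Integer, dimensions=2)))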

  • Parent commit 369e2cd

Files changed (3)

     with_lockmode("read_nowait").
     These emit "FOR SHARE" and "FOR SHARE NOWAIT",
     respectively.  Courtesy Diana Clarke 
-    [ticket:2445]
-    also in 0.7.7.
+    [ticket:2445] Also in 0.7.7.
+
+  - [feature] postgresql.ARRAY features an optional
+    "dimension" argument, will assign a specific
+    number of dimensions to the array which will
+    render in DDL as ARRAY[][]..., also improves
+    performance of bind/result processing.
+    [ticket:2441]
 
 - mysql
   - [bug] Fixed bug whereby column name inside 

File lib/sqlalchemy/dialects/postgresql/base.py

     """
     __visit_name__ = 'ARRAY'
 
-    def __init__(self, item_type, as_tuple=False):
+    def __init__(self, item_type, as_tuple=False, dimensions=None):
         """Construct an ARRAY.
 
         E.g.::
           as psycopg2 return lists by default. When tuples are
           returned, the results are hashable.
 
+        :param dimensions: if non-None, the ARRAY will assume a fixed
+         number of dimensions.  This will cause the DDL emitted for this
+         ARRAY to include the exact number of bracket clauses ``[]``,
+         and will also optimize the performance of the type overall. 
+         Note that PG arrays are always implicitly "non-dimensioned",
+         meaning they can store any number of dimensions no matter how
+         they were declared.
+
         """
         if isinstance(item_type, ARRAY):
             raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
             item_type = item_type()
         self.item_type = item_type
         self.as_tuple = as_tuple
+        self.dimensions = dimensions
 
     def compare_values(self, x, y):
         return x == y
 
+    def _proc_array(self, arr, itemproc, dim, collection):
+        if dim == 1 or (
+                    dim is None and
+                    (not arr or not isinstance(arr[0], (list, tuple)))
+                ):
+            if itemproc:
+                return collection(itemproc(x) for x in arr)
+            else:
+                return collection(arr)
+        else:
+            return collection(
+                    self._proc_array(
+                            x, itemproc, 
+                            dim - 1 if dim is not None else None, 
+                            collection) 
+                    for x in arr
+                )
+
     def bind_processor(self, dialect):
-        item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item_proc(item)
-        else:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item
+        item_proc = self.item_type.\
+                        dialect_impl(dialect).\
+                        bind_processor(dialect)
         def process(value):
             if value is None:
                 return value
-            return [convert_item(item) for item in value]
+            else:
+                return self._proc_array(
+                            value, 
+                            item_proc, 
+                            self.dimensions, 
+                            list)
         return process
 
     def result_processor(self, dialect, coltype):
-        item_proc = self.item_type.dialect_impl(dialect).result_processor(dialect, coltype)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, list):
-                    r = [convert_item(child) for child in item]
-                    if self.as_tuple:
-                        r = tuple(r)
-                    return r
-                else:
-                    return item_proc(item)
-        else:
-            def convert_item(item):
-                if isinstance(item, list):
-                    r = [convert_item(child) for child in item]
-                    if self.as_tuple:
-                        r = tuple(r)
-                    return r
-                else:
-                    return item
+        item_proc = self.item_type.\
+                        dialect_impl(dialect).\
+                        result_processor(dialect, coltype)
         def process(value):
             if value is None:
                 return value
-            r = [convert_item(item) for item in value]
-            if self.as_tuple:
-                r = tuple(r)
-            return r
+            else:
+                return self._proc_array(
+                            value, 
+                            item_proc, 
+                            self.dimensions, 
+                            tuple if self.as_tuple else list)
         return process
+
 PGArray = ARRAY
 
 class ENUM(sqltypes.Enum):
         return "BYTEA"
 
     def visit_ARRAY(self, type_):
-        return self.process(type_.item_type) + '[]'
+        return self.process(type_.item_type) + ('[]' * (type_.dimensions 
+                                                if type_.dimensions 
+                                                is not None else 1))
 
 
 class PGIdentifierPreparer(compiler.IdentifierPreparer):
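
A note on the processing change above (illustrative only): _proc_array walks
nested sequences and applies the per-item bind/result processor at the
innermost level; when "dimensions" is given, the recursion depth is fixed
rather than inferred from the data. A standalone sketch of that traversal,
using a hypothetical helper name:

    # standalone sketch of the fixed-depth traversal; proc_array is an
    # illustrative stand-in, not part of SQLAlchemy's public API
    def proc_array(arr, itemproc, dim, collection):
        if dim == 1 or (dim is None and
                        (not arr or not isinstance(arr[0], (list, tuple)))):
            return collection(itemproc(x) for x in arr) if itemproc else collection(arr)
        return collection(
            proc_array(x, itemproc,
                       dim - 1 if dim is not None else None,
                       collection)
            for x in arr)

    # with a fixed depth of 2, the item processor runs only on the innermost values
    print(proc_array([[1, 2], [3, 4]], str, 2, list))   # [['1', '2'], ['3', '4']]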

File test/dialect/test_postgresql.py

     def test_generate_multiple(self):
         """Test that the same enum twice only generates once
         for the create_all() call, without using checkfirst.
-        
+
         A 'memo' collection held by the DDL runner
         now handles this.
-        
+
         """
         metadata = self.metadata
 
     def setup_class(cls):
         global metadata, arrtable
         metadata = MetaData(testing.db)
-        arrtable = Table('arrtable', metadata, Column('id', Integer,
-                         primary_key=True), Column('intarr',
-                         postgresql.ARRAY(Integer)), Column('strarr',
-                         postgresql.ARRAY(Unicode()), nullable=False))
+
+        class ProcValue(TypeDecorator):
+            impl = postgresql.ARRAY(Integer, dimensions=2)
+
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
+                return [
+                    [x + 5 for x in v]
+                    for v in value
+                ]
+
+            def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
+                return [
+                    [x - 7 for x in v]
+                    for v in value
+                ]
+
+        arrtable = Table('arrtable', metadata, 
+                        Column('id', Integer, primary_key=True), 
+                        Column('intarr',postgresql.ARRAY(Integer)), 
+                         Column('strarr',postgresql.ARRAY(Unicode())),
+                        Column('dimarr', ProcValue)
+                    )
         metadata.create_all()
 
     def teardown(self):
         eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
         eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
 
-    @testing.fails_on('postgresql+pg8000',
-                      'pg8000 has poor support for PG arrays')
-    @testing.fails_on('postgresql+zxjdbc',
-                      'zxjdbc has no support for PG arrays')
-    def test_array_mutability(self):
-
-        class Foo(object):
-            pass
-
-        footable = Table('foo', metadata, 
-                        Column('id', Integer,primary_key=True), 
-                        Column('intarr', 
-                            postgresql.ARRAY(Integer, mutable=True), 
-                            nullable=True))
-        mapper(Foo, footable)
-        metadata.create_all()
-        sess = create_session()
-        foo = Foo()
-        foo.id = 1
-        foo.intarr = [1, 2, 3]
-        sess.add(foo)
-        sess.flush()
-        sess.expunge_all()
-        foo = sess.query(Foo).get(1)
-        eq_(foo.intarr, [1, 2, 3])
-        foo.intarr.append(4)
-        sess.flush()
-        sess.expunge_all()
-        foo = sess.query(Foo).get(1)
-        eq_(foo.intarr, [1, 2, 3, 4])
-        foo.intarr = []
-        sess.flush()
-        sess.expunge_all()
-        eq_(foo.intarr, [])
-        foo.intarr = None
-        sess.flush()
-        sess.expunge_all()
-        eq_(foo.intarr, None)
-
-        # Errors in r4217:
-
-        foo = Foo()
-        foo.id = 2
-        sess.add(foo)
-        sess.flush()
-
     @testing.fails_on('+zxjdbc',
                       "Can't infer the SQL type to use for an instance "
                       "of org.python.core.PyList.")
     @testing.provide_metadata
     def test_tuple_flag(self):
         metadata = self.metadata
-        assert_raises_message(
-            exc.ArgumentError, 
-            "mutable must be set to False if as_tuple is True.",
-            postgresql.ARRAY, Integer, mutable=True, 
-                as_tuple=True)
 
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True),
-            Column('data', postgresql.ARRAY(String(5), as_tuple=True, mutable=False)),
-            Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True, mutable=False)),
+            Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
+            Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True)),
         )
         metadata.create_all()
         testing.db.execute(t1.insert(), id=1, data=["1","2","3"], data2=[5.4, 5.6])
         testing.db.execute(t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0])
-        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], data2=[[5.4, 5.6], [1.0, 1.1]])
+        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], 
+                        data2=[[5.4, 5.6], [1.0, 1.1]])
 
         r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall()
         eq_(
             set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))])
         )
 
-
+    def test_dimension(self):
+        testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4,5, 6]])
+        eq_(
+            testing.db.scalar(select([arrtable.c.dimarr])),
+            [[-1, 0, 1], [2, 3, 4]]
+        )
 
 class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
     __only_on__ = 'postgresql'