Commits

Mike Bayer committed f1c95b4

- The "polymorphic discriminator" column may be part of a
primary key, and it will be populated with the correct
discriminator value. [ticket:1300]

Comments (0)

Files changed (4)

       mixin version of the AppenderQuery, which allows subclassing
       the AppenderMixin.
 
+    - The "polymorphic discriminator" column may be part of a 
+      primary key, and it will be populated with the correct 
+      discriminator value.  [ticket:1300]
+      
     - Fixed the evaluator not being able to evaluate IS NULL clauses.
 
     - Fixed the "set collection" function on "dynamic" relations to

lib/sqlalchemy/orm/mapper.py

                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col.key] = 1
-                        elif col in pks:
-                            value = mapper._get_state_attr_by_column(state, col)
-                            if value is not None:
-                                params[col.key] = value
                         elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col):
                             if self._should_log_debug:
                                 self._log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
                                  col.server_default is None) or
                                 value is not None):
                                 params[col.key] = value
+                        elif col in pks:
+                            value = mapper._get_state_attr_by_column(state, col)
+                            if value is not None:
+                                params[col.key] = value
                         else:
                             value = mapper._get_state_attr_by_column(state, col)
                             if ((col.default is None and

lib/sqlalchemy/sql/expression.py

     def shares_lineage(self, othercolumn):
         """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
 
-        return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0
+        return bool(self.proxy_set.intersection(othercolumn.proxy_set))
 
     def _make_proxy(self, selectable, name=None):
         """Create a new ``ColumnElement`` representing this

test/orm/inheritance/basic.py

 from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from testlib import fixtures
+from orm import _base, _fixtures
 
 class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
         # the optimized load needs to return "None" so regular full-row loading proceeds
         s1 = sess.query(Base).get(s1.id)
         assert s1.sub == 's1sub'
+
+class PKDiscriminatorTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        parents = Table('parents', metadata,
+                           Column('id', Integer, primary_key=True),
+                           Column('name', String(60)))
+                           
+        children = Table('children', metadata,
+                        Column('id', Integer, ForeignKey('parents.id'), primary_key=True),
+                        Column('type', Integer,primary_key=True),
+                        Column('name', String(60)))
+
+    @testing.resolve_artifact_names
+    def test_pk_as_discriminator(self):
+        class Parent(object):
+                def __init__(self, name=None):
+                    self.name = name
+
+        class Child(object):
+            def __init__(self, name=None):
+                self.name = name
+
+        class A(Child):
+            pass
+            
+        mapper(Parent, parents, properties={
+            'children': relation(Child, backref='parent'),
+        })
+        mapper(Child, children, polymorphic_on=children.c.type,
+            polymorphic_identity=1)
+            
+        mapper(A, inherits=Child, polymorphic_identity=2)
+
+        s = create_session()
+        p = Parent('p1')
+        a = A('a1')
+        p.children.append(a)
+        s.add(p)
+        s.flush()
+
+        assert a.id
+        assert a.type == 2
+        
         
 class DeleteOrphanTest(ORMTest):
     def define_tables(self, metadata):