Mike Bayer avatar Mike Bayer committed 7bda449

- repaired single table inheritance such that you
can single-table inherit from a joined-table inherting
mapper without issue [ticket:1036].

Comments (0)

Files changed (3)

       functions and other expressions.  (partial progress
       towards [ticket:610])
     
+    - repaired single table inheritance such that you 
+      can single-table inherit from a joined-table inherting
+      mapper without issue [ticket:1036].
+      
     - Fixed "concatenate tuple" bug which could occur with
       Query.order_by() if clause adaption had taken place.
       [ticket:1027]

lib/sqlalchemy/orm/mapper.py

             # inherit_condition is optional.
             if self.local_table is None:
                 self.local_table = self.inherits.local_table
+                self.mapped_table = self.inherits.mapped_table
                 self.single = True
-            if not self.local_table is self.inherits.local_table:
+            elif not self.local_table is self.inherits.local_table:
                 if self.concrete:
                     self.mapped_table = self.local_table
                     for mapper in self.iterate_to_root():
         for mapper in self.iterate_to_root():
             if mapper is base_mapper:
                 break
-            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
+            if not mapper.single:
+                allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
 
         return sql.and_(*allconds), param_names
 

test/orm/inheritance/single.py

 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
+from testlib.fixtures import Base
 
-
-class SingleInheritanceTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
-        metadata = MetaData(testing.db)
+class SingleInheritanceTest(ORMTest):
+    def define_tables(self, metadata):
         global employees_table
         employees_table = Table('employees', metadata,
             Column('employee_id', Integer, primary_key=True),
             Column('engineer_info', String(50)),
             Column('type', String(20))
         )
-        employees_table.create()
-    def tearDownAll(self):
-        employees_table.drop()
-    def testbasic(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
 
+    def test_single_inheritance(self):
+        class Employee(Base):
+            pass
         class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
+            pass
         class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
+            pass
         class JuniorEngineer(Engineer):
             pass
 
-        employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
-        manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager')
-        engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer')
-        junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer')
+        mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+        mapper(Manager, inherits=Employee, polymorphic_identity='manager')
+        mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
+        mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
 
         session = create_session()
 
-        m1 = Manager('Tom', 'knows how to manage things')
-        e1 = Engineer('Kurt', 'knows how to hack')
-        e2 = JuniorEngineer('Ed', 'oh that ed')
+        m1 = Manager(name='Tom', manager_data='knows how to manage things')
+        e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
+        e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
         session.save(m1)
         session.save(e1)
         session.save(e2)
         assert session.query(Manager).all() == [m1]
         assert session.query(JuniorEngineer).all() == [e2]
 
+class SingleOnJoinedTest(ORMTest):
+    def define_tables(self, metadata):
+        global persons_table, employees_table
+        
+        persons_table = Table('persons', metadata,
+           Column('person_id', Integer, primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(20), nullable=False)
+        )
+
+        employees_table = Table('employees', metadata,
+           Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True),
+           Column('employee_data', String(50)),
+           Column('manager_data', String(50)),
+        )
+    
+    def test_single_on_joined(self):
+        class Person(Base):
+            pass
+        class Employee(Person):
+            pass
+        class Manager(Employee):
+            pass
+        
+        mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person')
+        mapper(Employee, employees_table, inherits=Person,polymorphic_identity='engineer')
+        mapper(Manager, inherits=Employee,polymorphic_identity='manager')
+        
+        sess = create_session()
+        sess.save(Person(name='p1'))
+        sess.save(Employee(name='e1', employee_data='ed1'))
+        sess.save(Manager(name='m1', employee_data='ed2', manager_data='md1'))
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [
+            Person(name='p1'),
+            Employee(name='e1', employee_data='ed1'),
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+
+        self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [
+            Employee(name='e1', employee_data='ed1'),
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+
+        self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+        
+        def go():
+            self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
+                Person(name='p1'),
+                Employee(name='e1', employee_data='ed1'),
+                Manager(name='m1', employee_data='ed2', manager_data='md1')
+            ])
+        self.assert_sql_count(testing.db, go, 1)
+    
 if __name__ == '__main__':
     testenv.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.