Commits

Anonymous committed d21f1f6

add simple inherits

Comments (0)

Files changed (3)

 # -*- coding: utf-8 -*-
+import uuid
 import types
 from sqlalchemy.orm import class_mapper
-from sqlalchemy import Column, Integer
+from sqlalchemy import Column, Integer, ForeignKey
+from sqlalchemy.orm.properties import PropertyLoader
 
 
-def describe(mappers):
+def describe(mappers, methods=True):
     """
     """
 
     objects = []
     relations = []
-    inhirits = []
+    inherits = []
 
     for mapper in mappers:
 
             'methods': [],
         }
 
-        # Create the DummyClass subclass of mapper bases
-        # for detecting mapper own methods
-        DummyClass = type('Dummy%s' % mapper.class_.__name__,
-            mapper.class_.__bases__, {
-                '__tablename__': 'dummy_table_%s' % mapper.class_.__name__,
-                '__dummy_col': Column(Integer, primary_key=True)
-            }
-        )
+        if methods:
 
-        base_keys = DummyClass.__dict__.keys()
+            suffix = '%s' % str(uuid.uuid4())
 
-        for name, func in mapper.class_.__dict__.iteritems():
-            if name not in base_keys:
-                if isinstance(func, types.FunctionType):
-                    entry['methods'].append(name)
+            # Create the DummyClass subclass of mapper bases
+            # for detecting mapper own methods
+
+            params = {'__tablename__': 'dummy_table_%s' % suffix}
+
+            if mapper.inherits:
+                params['__mapper_args__'] = {'polymorphic_identity':
+                        mapper.inherits.class_.__tablename__}
+
+                # Get primary key
+                pk = [col for col in mapper.columns if col.primary_key]
+
+                # ForeignKey for inherited class
+                params['dummy_id_col'] = Column(pk[0].type,
+                        ForeignKey(pk[0]), primary_key=True)
+            else:
+                params['dummy_id_col'] = Column(Integer, primary_key=True)
+
+            DummyClass = type('Dummy%s' % suffix,
+                    mapper.class_.__bases__, params)
+
+            base_keys = DummyClass.__dict__.keys()
+
+            # Filter mapper methods
+            for name, func in mapper.class_.__dict__.iteritems():
+                if name not in base_keys:
+                    if isinstance(func, types.FunctionType):
+                        entry['methods'].append(name)
 
         objects.append(entry)
 
-    return objects, relations, inhirits
+        for loader in mapper.iterate_properties:
+            if isinstance(loader, PropertyLoader) and loader.mapper in mappers:
+                if hasattr(loader, 'reverse_property'):
+                    relations.add(frozenset([loader, loader.reverse_property]))
+                else:
+                    relations.add(frozenset([loader]))
+
+        if mapper.inherits:
+            inherits.append({
+                'child': mapper.class_.__name__,
+                'parent': mapper.inherits.class_.__name__,
+            })
+
+    return objects, relations, inherits
+# -*- coding: utf-8 -*-
+from sqlalchemy import Column, Integer, Unicode, ForeignKey
+from sqlalchemy.ext.declarative import declarative_base
+
+
+BASE = declarative_base()
+
+
+class User(BASE):
+    __tablename__ = 'user_table'
+
+    id = Column(Integer, primary_key=True)
+    name = Column(Unicode(50))
+
+    def login(self):
+        pass
+
+
+class Admin(User):
+    __tablename__ = 'admin_table'
+    __mapper_args__ = {'polymorphic_identity': 'user_table'}
+
+    id = Column(Integer, ForeignKey('user_table.id'), primary_key=True)
+    phone = Column(Unicode(50))
+
+    def permissions(self):
+        pass

tests/test_describe.py

 # -*- coding: utf-8 -*-
 import unittest
-
-from sqlalchemy import Column, Integer, Unicode
-from sqlalchemy.ext.declarative import declarative_base
-
 import sadisplay
+import model
 
 
 class TestDescribe(unittest.TestCase):
 
-    def setUp(self):
-        self.BASE = declarative_base()
-
     def test_single(self):
 
-        class User(self.BASE):
-            __tablename__ = 'user_table'
-
-            id = Column(Integer, primary_key=True)
-            name = Column(Unicode(50))
-
-            def login(self):
-                pass
-
-        objects, relations, inhirets = sadisplay.describe([User])
+        objects, relations, inherits = sadisplay.describe([model.User])
 
         assert len(objects) == 1
         assert relations == []
-        assert inhirets == []
+        assert inherits == []
         assert objects[0] == {
-                'name': User.__name__,
-                'attributes': [('Integer', 'id'), ('Unicode', 'name')],
-                'methods': ['login'],
+                'name': model.User.__name__,
+                'attributes': [('Integer', 'id'), ('Unicode', 'name'), ],
+                'methods': ['login', ],
             }
 
     def test_subclass(self):
 
-        class User(self.BASE):
-            __tablename__ = 'user_table'
-
-            id = Column(Integer, primary_key=True)
-
-            def login(self):
-                pass
-
-        class Admin(self.BASE):
-            __tablename__ = 'admin_table'
-
-            id = Column(Integer, primary_key=True)
-
-            def permissions(self):
-                pass
-
-        objects, relations, inhirets = sadisplay.describe([User, Admin])
+        objects, relations, inherits = sadisplay \
+                .describe([model.User, model.Admin])
 
         assert len(objects) == 2
-        assert relations == []
-        assert inhirets == []
-        assert objects[0] == {
-                'name': User.__name__,
-                'attributes': [('Integer', 'id'),],
-                'methods': ['login',],
+        assert len(inherits) == 1
+        assert objects[1] == {
+                'name': model.Admin.__name__,
+                'attributes': [('Integer', 'id'),
+                    ('Unicode', 'name'),
+                    ('Unicode', 'phone'), ],
+                'methods': ['permissions', ],
             }
 
-        assert objects[1] == {
-                'name': Admin.__name__,
-                'attributes': [('Integer', 'id'),],
-                'methods': ['permissions',],
+        assert inherits[0] == {
+                'child': model.Admin.__name__,
+                'parent': model.User.__name__,
             }
 
-
 if __name__ == '__main__':
     unittest.main()