Commits

Mike Bayer committed 27d1fa9

more hammering of defaults. ORM will properly execute defaults and post-fetch rows that contain passive defaults

Comments (0)

Files changed (7)

lib/sqlalchemy/ansisql.py

 
     def get_column_default_string(self, column):
         if isinstance(column.default, schema.PassiveDefault):
-            if not isinstance(column.default.arg, str):
-                arg = str(column.default.arg.compile(self.engine))
+            if isinstance(column.default.arg, str):
+                return repr(column.default.arg)
             else:
-                arg = column.default.arg
-            return arg
+                return str(column.default.arg.compile(self.engine))
         else:
             return None
 

lib/sqlalchemy/databases/information_schema.py

         coltype = coltype(*args)
         colargs= []
         if default is not None:
-            colargs.append(PassiveDefault(default))
+            colargs.append(PassiveDefault(sql.text(default, escape=False)))
         table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
 
     s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True)

lib/sqlalchemy/mapping/mapper.py

                             # matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
                             params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
                         else:
-                            # doing an INSERT? if the primary key values are not populated,
+                            # doing an INSERT, primary key col ? 
+                            # if the primary key values are not populated,
                             # leave them out of the INSERT altogether, since PostGres doesn't want
                             # them to be present for SERIAL to take effect.  A SQLEngine that uses
                             # explicit sequences will put them back in if they are needed
                                     params[col.key] = a[0]
                                     hasdata = True
                         else:
-                            # doing an INSERT ? add the attribute's value to the 
-                            # bind parameters
-                            params[col.key] = self._getattrbycolumn(obj, col)
+                            # doing an INSERT, non primary key col ? 
+                            # add the attribute's value to the 
+                            # bind parameters, unless its None and the column has a 
+                            # default.  if its None and theres no default, we still might
+                            # not want to put it in the col list but SQLIte doesnt seem to like that
+                            # if theres no columns at all
+                            value = self._getattrbycolumn(obj, col)
+                            if col.default is None or value is not None:
+                                params[col.key] = value
 
                 if not isinsert:
                     if hasdata:
                             clause.clauses.append(p == self._getattrbycolumn(obj, p))
                         row = table.select(clause).execute().fetchone()
                         for c in table.c:
-                            if self._getattrbycolumn(obj, col) is None:
-                                self._setattrbycolumn(obj, col, row[c])
+                            if self._getattrbycolumn(obj, c) is None:
+                                self._setattrbycolumn(obj, c, row[c])
                     self.extension.after_insert(self, obj)
                     
     def delete_obj(self, objects, uow):

lib/sqlalchemy/sql.py

     being specified as a bind parameter via the bindparam() method,
     since it provides more information about what it is, including an optional
     type, as well as providing comparison operations."""
-    def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
+    def __init__(self, text = "", engine=None, bindparams=None, typemap=None, escape=True):
         self.parens = False
         self._engine = engine
         self.id = id(self)
         def repl(m):
             self.bindparams[m.group(1)] = bindparam(m.group(1))
             return self.engine.bindtemplate % m.group(1)
-           
-        self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text)
+        
+        if escape: 
+            self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text)
+        else:
+            self.text = text
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b
         
         use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle')
         
+        use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite')
+
         if use_function_defaults:
             defval = func.current_date()
             deftype = Date
         else:
             defval = "3"
             deftype = Integer
+
+        if use_string_defaults:
+            deftype2 = String
+            defval2 = "im a default"
+        else:
+            deftype2 = Integer
+            defval2 = "15"
             
         users = Table('engine_users', testbase.db,
             Column('user_id', INT, primary_key = True),
             Column('test7', String),
             Column('test8', Binary),
             Column('test_passivedefault', deftype, PassiveDefault(defval)),
+            Column('test_passivedefault2', Integer, PassiveDefault("5")),
+            Column('test_passivedefault3', deftype2, PassiveDefault(defval2)),
             Column('test9', Binary(100)),
             mysql_engine='InnoDB'
         )

test/objectstore.py

         objectstore.clear()
         e2 = Entry.mapper.get(e.multi_id, 2)
         self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+
+class DefaultTest(AssertMixin):
+    def setUpAll(self):
+        #db.echo = 'debug'
+        use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
+
+        if use_string_defaults:
+            hohotype = String
+            self.hohoval = "im hoho"
+            self.althohoval = "im different hoho"
+        else:
+            hohotype = Integer
+            self.hohoval = 9
+            self.althohoval = 15
+        self.table = Table('default_test', db,
+        Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
+        Column('hoho', hohotype, PassiveDefault(str(self.hohoval))),
+        Column('counter', Integer, PassiveDefault("7")),
+        Column('foober', String, default="im foober")
+        )
+        self.table.create()
+    def tearDownAll(self):
+        self.table.drop()
+    def testbasic(self):
         
+        class Hoho(object):pass
+        assign_mapper(Hoho, self.table)
+        h1 = Hoho(hoho=self.althohoval)
+        h2 = Hoho(counter=12)
+        h3 = Hoho(hoho=self.althohoval, counter=12)
+        h4 = Hoho()
+        h5 = Hoho(foober='im the new foober')
+        objectstore.commit()
+        self.assert_(h1.hoho==self.althohoval)
+        self.assert_(h3.hoho==self.althohoval)
+        self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
+        self.assert_(h3.counter == h2.counter == 12)
+        self.assert_(h1.counter ==  h4.counter==h5.counter==7)
+        self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
+        self.assert_(h5.foober=='im the new foober')
+        objectstore.clear()
+        l = Hoho.mapper.select()
+        (h1, h2, h3, h4, h5) = l
+        self.assert_(h1.hoho==self.althohoval)
+        self.assert_(h3.hoho==self.althohoval)
+        self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
+        self.assert_(h3.counter == h2.counter == 12)
+        self.assert_(h1.counter ==  h4.counter==h5.counter==7)
+        self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
+        self.assert_(h5.foober=='im the new foober')
+            
 class SaveTest(AssertMixin):
 
     def setUpAll(self):
         f = select([func.count(1)], engine=db).scalar()
         if use_function_defaults:
             def1 = func.current_date()
-            def2 = "current_date"
+            def2 = text("current_date")
             deftype = Date
             ts = select([func.current_date()], engine=db).scalar()
         else:
         t.create()
         try:
             t.insert().execute()
+            self.assert_(t.engine.lastrow_has_defaults())
             t.insert().execute()
             t.insert().execute()