Commits

Mike Bayer  committed 7cf4a72

new MySQL types: MSEnum, MSTinyText, MSMediumText, MSLongText, etc.
more support for MS-specific length/precision params in numeric types
patch courtesy Mike Bernson

  • Participants
  • Parent commits 380a78e

Comments (0)

Files changed (2)

 - cursor() method on ConnectionFairy allows db-specific extension
 arguments to be propigated [ticket:221]
 - lazy load bind params properly propigate column type [ticket:225]
+- new MySQL types: MSEnum, MSTinyText, MSMediumText, MSLongText, etc.
+more support for MS-specific length/precision params in numeric types
+patch courtesy Mike Bernson
 
 0.2.3
 - overhaul to mapper compilation to be deferred.  this allows mappers

File lib/sqlalchemy/databases/mysql.py

     import MySQLdb as mysql
 except:
     mysql = None
-    
+
+def kw_colspec(self, spec):
+    if self.unsigned:
+        spec += ' UNSIGNED'
+    if self.zerofill:
+        spec += ' ZEROFILL'
+    return spec
+        
 class MSNumeric(sqltypes.Numeric):
+    def __init__(self, precision = 10, length = 2, **kw):
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
+        super(MSNumeric, self).__init__(precision, length)
     def get_col_spec(self):
-        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+        return kw_colspec(self, "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+class MSDecimal(MSNumeric):
+    def get_col_spec(self):
+        if self.precision is not None and self.length is not None:
+            return kw_colspec(self, "DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 class MSDouble(sqltypes.Numeric):
-    def __init__(self, precision = None, length = None):
+    def __init__(self, precision=10, length=2, **kw):
         if (precision is None and length is not None) or (precision is not None and length is None):
             raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
         super(MSDouble, self).__init__(precision, length)
     def get_col_spec(self):
         if self.precision is not None and self.length is not None:
             return "DOUBLE(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
         else:
-            return "DOUBLE"
+            return kw_colspec(self, "DOUBLE")
 class MSFloat(sqltypes.Float):
-    def __init__(self, precision = None):
+    def __init__(self, precision=10, length=None, **kw):
+        if length is not None:
+            self.length=length
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
         super(MSFloat, self).__init__(precision)
     def get_col_spec(self):
-        if self.precision is not None:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
+        if hasattr(self, 'length') and self.length is not None:
+            return kw_colspec(self, "FLOAT(%(precision)s,%(length)s)" % {'precision': self.precision, 'length' : self.length})
+        elif self.precision is not None:
+            return kw_colspec(self, "FLOAT(%(precision)s)" % {'precision': self.precision})
         else:
-            return "FLOAT"
+            return kw_colspec(self, "FLOAT")
+class MSBigInteger(sqltypes.Integer):
+    def __init__(self, length=None, **kw):
+        self.length = length
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
+        super(MSBigInteger, self).__init__()
+    def get_col_spec(self):
+        if self.length is not None:
+            return kw_colspec(self, "BIGINT(%(length)s)" % {'length': self.length})
+        else:
+            return kw_colspec(self, "BIGINT")
 class MSInteger(sqltypes.Integer):
+    def __init__(self, length=None, **kw):
+        self.length = length
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
+        super(MSInteger, self).__init__()
     def get_col_spec(self):
-        return "INTEGER"
+        if self.length is not None:
+            return kw_colspec(self, "INTEGER(%(length)s)" % {'length': self.length})
+        else:
+            return kw_colspec(self, "INTEGER")
 class MSSmallInteger(sqltypes.Smallinteger):
+    def __init__(self, length=None, **kw):
+        self.length = length
+        self.unsigned = 'unsigned' in kw
+        self.zerofill = 'zerofill' in kw
+        super(MSSmallInteger, self).__init__()
     def get_col_spec(self):
-        return "SMALLINT"
+        if self.length is not None:
+            return kw_colspec(self, "SMALLINT(%(length)s)" % {'length': self.length})
+        else:
+            return kw_colspec(self, "SMALLINT")
 class MSDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATETIME"
             return None
             
 class MSText(sqltypes.TEXT):
+    def __init__(self, **kw):
+        self.binary = 'binary' in kw
+        super(MSText, self).__init__()
     def get_col_spec(self):
         return "TEXT"
+class MSTinyText(sqltypes.TEXT):
+    def __init__(self, **kw):
+        self.binary = 'binary' in kw
+        super(MSTinyText, self).__init__()
+    def get_col_spec(self):
+        if self.binary:
+            return "TEXT BINARY"
+        else:
+           return "TEXT"
+class MSMediumText(sqltypes.TEXT):
+    def __init__(self, **kw):
+        self.binary = 'binary' in kw
+        super(MSMediumText, self).__init__()
+    def get_col_spec(self):
+        if self.binary:
+            return "MEDIUMTEXT BINARY"
+        else:
+            return "MEDIUMTEXT"
+class MSLongText(sqltypes.TEXT):
+    def __init__(self, **kw):
+        self.binary = 'binary' in kw
+        super(MSLongText, self).__init__()
+    def get_col_spec(self):
+        if self.binary:
+            return "LONGTEXT BINARY"
+        else:
+            return "LONGTEXT"
 class MSString(sqltypes.String):
     def __init__(self, length=None, *extra):
         sqltypes.String.__init__(self, length=length)
 class MSBinary(sqltypes.Binary):
     def get_col_spec(self):
         if self.length is not None and self.length <=255:
-            # the binary type seems to return a value that is null-padded
+            # the binary2G type seems to return a value that is null-padded
             return "BINARY(%d)" % self.length
         else:
             return "BLOB"
             return None
         else:
             return buffer(value)
+class MSEnum(sqltypes.String):
+    def __init__(self, *enums):
+        self.__enums_hidden = enums
+        length = 0
+        strip_enums = []
+        for a in enums:
+            if a[0:1] == '"' or a[0:1] == "'":
+                a = a[1:-1]
+            if len(a) > length:
+                length=len(a)
+            strip_enums.append(a)
+        self.enums = strip_enums
+        super(MSEnum, self).__init__(length)
+    def get_col_spec(self):
+        return "ENUM(%s)" % ",".join(self.__enums_hidden)
+        
 
 class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
         
 colspecs = {
+#    sqltypes.BIGinteger : MSInteger,
     sqltypes.Integer : MSInteger,
     sqltypes.Smallinteger : MSSmallInteger,
     sqltypes.Numeric : MSNumeric,
 }
 
 ischema_names = {
+    'bigint' : MSBigInteger,
     'int' : MSInteger,
     'smallint' : MSSmallInteger,
     'tinyint' : MSSmallInteger, 
     'varchar' : MSString,
     'char' : MSChar,
     'text' : MSText,
-    'decimal' : MSNumeric,
+    'tinytext' : MSTinyText,
+    'mediumtext': MSMediumText,
+    'longtext': MSLongText,
+    'decimal' : MSDecimal,
+    'numeric' : MSNumeric,
     'float' : MSFloat,
     'double' : MSDouble,
     'timestamp' : MSDateTime,
     'time' : MSTime,
     'binary' : MSBinary,
     'blob' : MSBinary,
+    'enum': MSEnum,
 }
 
 def engine(opts, **params):
 
             (name, type, nullable, primary_key, default) = (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4])
             
-            match = re.match(r'(\w+)(\(.*?\))?', type)
-            coltype = match.group(1)
+            match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
+            col_type = match.group(1)
             args = match.group(2)
-            
-            #print "coltype: " + repr(coltype) + " args: " + repr(args)
-            coltype = ischema_names.get(coltype, MSString)
+            extra_1 = match.group(3)
+            extra_2 = match.group(4)
+
+            #print "coltype: " + repr(col_type) + " args: " + repr(args) + "extras:" + repr(extra_1) + ' ' + repr(extra_2)
+            coltype = ischema_names.get(col_type, MSString)
+            kw = {}
+            if extra_1 is not None:
+                kw[extra_1] = True
+            if extra_2 is not None:
+                kw[extra_2] = True
+
             if args is not None:
-                args = re.findall(r'(\d+)', args)
-                #print "args! " +repr(args)
-                coltype = coltype(*[int(a) for a in args])
+                if col_type == 'enum':
+                    args= args[1:-1]
+                    argslist = args.split(',')
+                    coltype = coltype(*argslist, **kw)
+                else:
+                    argslist = re.findall(r'(\d+)', args)
+                    coltype = coltype(*[int(a) for a in argslist], **kw)
             
             arglist = []
             fkey = foreignkeyD.get(name)