Commits

Brian Kearns committed 71538a3

unify sqlite3 between default and py3k

  • Participants
  • Parent commits 6316975

Comments (0)

Files changed (1)

lib_pypy/_sqlite3.py

 import weakref
 from threading import _get_ident as _thread_get_ident
 
+if sys.version_info[0] >= 3:
+    StandardError = Exception
+    long = int
+    xrange = range
+    basestring = unicode = str
+    buffer = memoryview
+    BLOB_TYPE = bytes
+else:
+    BLOB_TYPE = buffer
+
 names = "sqlite3.dll libsqlite3.so.0 libsqlite3.so libsqlite3.dylib".split()
 for name in names:
     try:
 ##########################################
 
 # SQLite version information
-sqlite_version = sqlite.sqlite3_libversion()
+sqlite_version = str(sqlite.sqlite3_libversion().decode('ascii'))
 
 class Error(StandardError):
     pass
 def unicode_text_factory(x):
     return unicode(x, 'utf-8')
 
+if sys.version_info[0] < 3:
+    def OptimizedUnicode(s):
+        try:
+            val = unicode(s, "ascii").encode("ascii")
+        except UnicodeDecodeError:
+            val = unicode(s, "utf-8")
+        return val
+else:
+    OptimizedUnicode = unicode_text_factory
+
 
 class _StatementCache(object):
     def __init__(self, connection, maxcount):
     @_check_thread_wrap
     @_check_closed_wrap
     def __call__(self, sql):
-        if not isinstance(sql, (str, unicode)):
+        if not isinstance(sql, basestring):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
         return self._statement_cache.get(sql, self.row_factory)
 
     @_check_closed_wrap
     def create_collation(self, name, callback):
         name = name.upper()
-        if not all(c in string.uppercase + string.digits + '_' for c in name):
+        if not all(c in string.ascii_uppercase + string.digits + '_' for c in name):
             raise ProgrammingError("invalid character in collation name")
 
         if callback is None:
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
+    if sys.version_info[0] >= 3:
+        def __get_in_transaction(self):
+            return self._in_transaction
+        in_transaction = property(__get_in_transaction)
+
     def __get_total_changes(self):
         self._check_closed()
         return sqlite.sqlite3_total_changes(self._db)
         if val is None:
             self.commit()
         else:
-            self.__begin_statement = b"BEGIN " + val.encode('ascii')
+            self.__begin_statement = str("BEGIN " + val).encode('utf-8')
         self._isolation_level = val
     isolation_level = property(__get_isolation_level, __set_isolation_level)
 
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, (str, unicode)):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, (str, unicode)):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
     def __init__(self, connection, sql):
         self.__con = connection
 
-        if not isinstance(sql, (str, unicode)):
+        if not isinstance(sql, basestring):
             raise ValueError("sql must be a string")
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
 
             self.__row_cast_map.append(converter)
 
-    def __check_decodable(self, param):
-        if self.__con.text_factory in (unicode, OptimizedUnicode, unicode_text_factory):
-            for c in param:
-                if ord(c) & 0x80 != 0:
-                    raise self.__con.ProgrammingError(
-                        "You must not use 8-bit bytestrings unless "
-                        "you use a text_factory that can interpret "
-                        "8-bit bytestrings (like text_factory = str). "
-                        "It is highly recommended that you instead "
-                        "just switch your application to Unicode strings.")
+    if sys.version_info[0] < 3:
+        def __check_decodable(self, param):
+            if self.__con.text_factory in (unicode, OptimizedUnicode,
+                                           unicode_text_factory):
+                for c in param:
+                    if ord(c) & 0x80 != 0:
+                        raise self.__con.ProgrammingError(
+                            "You must not use 8-bit bytestrings unless "
+                            "you use a text_factory that can interpret "
+                            "8-bit bytestrings (like text_factory = str). "
+                            "It is highly recommended that you instead "
+                            "just switch your application to Unicode strings.")
 
     def __set_param(self, idx, param):
         cvt = converters.get(type(param))
 
         if param is None:
             rc = sqlite.sqlite3_bind_null(self._statement, idx)
-        elif type(param) in (bool, int, long):
+        elif isinstance(param, (bool, int, long)):
             if -2147483648 <= param <= 2147483647:
                 rc = sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
                 rc = sqlite.sqlite3_bind_int64(self._statement, idx, param)
-        elif type(param) is float:
+        elif isinstance(param, float):
             rc = sqlite.sqlite3_bind_double(self._statement, idx, param)
+        elif isinstance(param, unicode):
+            param = param.encode("utf-8")
+            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif isinstance(param, str):
             self.__check_decodable(param)
             rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
-        elif isinstance(param, unicode):
-            param = param.encode("utf-8")
-            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
-        elif type(param) is buffer:
+        elif isinstance(param, (buffer, bytes)):
             param = bytes(param)
             rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         else:
 
             converter = self.__row_cast_map[i]
             if converter is None:
-                if typ == SQLITE_INTEGER:
+                if typ == SQLITE_NULL:
+                    val = None
+                elif typ == SQLITE_INTEGER:
                     val = sqlite.sqlite3_column_int64(self._statement, i)
-                    if -sys.maxint-1 <= val <= sys.maxint:
-                        val = int(val)
                 elif typ == SQLITE_FLOAT:
                     val = sqlite.sqlite3_column_double(self._statement, i)
-                elif typ == SQLITE_BLOB:
-                    blob = sqlite.sqlite3_column_blob(self._statement, i)
-                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    val = buffer(string_at(blob, blob_len))
-                elif typ == SQLITE_NULL:
-                    val = None
                 elif typ == SQLITE_TEXT:
                     text = sqlite.sqlite3_column_text(self._statement, i)
                     text_len = sqlite.sqlite3_column_bytes(self._statement, i)
                     val = string_at(text, text_len)
                     val = self.__con.text_factory(val)
+                elif typ == SQLITE_BLOB:
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    val = BLOB_TYPE(string_at(blob, blob_len))
             else:
                 blob = sqlite.sqlite3_column_blob(self._statement, i)
                 if not blob:
     _params = []
     for i in range(nargs):
         typ = sqlite.sqlite3_value_type(params[i])
-        if typ == SQLITE_INTEGER:
+        if typ == SQLITE_NULL:
+            val = None
+        elif typ == SQLITE_INTEGER:
             val = sqlite.sqlite3_value_int64(params[i])
-            if -sys.maxint-1 <= val <= sys.maxint:
-                val = int(val)
         elif typ == SQLITE_FLOAT:
             val = sqlite.sqlite3_value_double(params[i])
+        elif typ == SQLITE_TEXT:
+            val = sqlite.sqlite3_value_text(params[i])
+            val = val.decode('utf-8')
         elif typ == SQLITE_BLOB:
             blob = sqlite.sqlite3_value_blob(params[i])
             blob_len = sqlite.sqlite3_value_bytes(params[i])
-            val = buffer(string_at(blob, blob_len))
-        elif typ == SQLITE_NULL:
-            val = None
-        elif typ == SQLITE_TEXT:
-            val = sqlite.sqlite3_value_text(params[i])
-            val = val.decode('utf-8')
+            val = BLOB_TYPE(string_at(blob, blob_len))
         else:
             raise NotImplementedError
         _params.append(val)
         sqlite.sqlite3_result_null(con)
     elif isinstance(val, (bool, int, long)):
         sqlite.sqlite3_result_int64(con, int(val))
-    elif isinstance(val, str):
-        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, float):
+        sqlite.sqlite3_result_double(con, val)
     elif isinstance(val, unicode):
         val = val.encode('utf-8')
         sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
-    elif isinstance(val, float):
-        sqlite.sqlite3_result_double(con, val)
-    elif isinstance(val, buffer):
+    elif isinstance(val, str):
+        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, (buffer, bytes)):
         sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT)
     else:
         raise NotImplementedError
             microseconds = int(timepart_full[1])
         else:
             microseconds = 0
-        return datetime.datetime(year, month, day,
-                                 hours, minutes, seconds, microseconds)
+        return datetime.datetime(year, month, day, hours, minutes, seconds,
+                                 microseconds)
 
     register_adapter(datetime.date, adapt_date)
     register_adapter(datetime.datetime, adapt_datetime)
     return val
 
 register_adapters_and_converters()
-
-
-def OptimizedUnicode(s):
-    try:
-        val = unicode(s, "ascii").encode("ascii")
-    except UnicodeDecodeError:
-        val = unicode(s, "utf-8")
-    return val