Commits

Brian Kearns committed 36a3c5c Merge

merge default

Comments (0)

Files changed (2)

lib_pypy/_sqlite3.py

     sqlite.sqlite3_enable_load_extension.argtypes = [c_void_p, c_int]
     sqlite.sqlite3_enable_load_extension.restype = c_int
 
-_DML, _DQL, _DDL = range(3)
-
 ##########################################
 # END Wrapped SQLite C API and constants
 ##########################################
             if len(self.cache) > self.maxcount:
                 self.cache.popitem(0)
 
-        if stat.in_use:
+        if stat._in_use:
             stat = Statement(self.connection, sql)
-        stat.set_row_factory(row_factory)
+        stat._row_factory = row_factory
         return stat
 
 
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.finalize()
+                obj._finalize()
 
         if self._db:
             ret = sqlite.sqlite3_close(self._db)
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.reset()
+                obj._reset()
 
         statement = c_void_p()
         next_char = c_char_p()
         for statement in self.__statements:
             obj = statement()
             if obj is not None:
-                obj.reset()
+                obj._reset()
 
         for cursor_ref in self._cursors:
             cursor = cursor_ref()
             except ValueError:
                 pass
         if self.__statement:
-            self.__statement.reset()
+            self.__statement._reset()
 
     def close(self):
         self.__connection._check_thread()
         self.__connection._check_closed()
         if self.__statement:
-            self.__statement.reset()
+            self.__statement._reset()
             self.__statement = None
         self.__closed = True
 
         self.__connection._check_thread()
         self.__connection._check_closed()
 
+    def __check_cursor_wrap(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            self.__check_cursor()
+            return func(self, *args, **kwargs)
+        return wrapper
+
+    @__check_cursor_wrap
     def execute(self, sql, params=None):
-        self.__check_cursor()
         self.__locked = True
         try:
             self.__description = None
                 sql, self.row_factory)
 
             if self.__connection._isolation_level is not None:
-                if self.__statement.kind == _DDL:
+                if self.__statement._kind == Statement._DDL:
                     if self.__connection._in_transaction:
                         self.__connection.commit()
-                elif self.__statement.kind == _DML:
+                elif self.__statement._kind == Statement._DML:
                     if not self.__connection._in_transaction:
                         self.__connection._begin()
 
-            self.__statement.set_params(params)
+            self.__statement._set_params(params)
 
             # Actually execute the SQL statement
-            ret = sqlite.sqlite3_step(self.__statement.statement)
+            ret = sqlite.sqlite3_step(self.__statement._statement)
             if ret not in (SQLITE_DONE, SQLITE_ROW):
-                self.__statement.reset()
+                self.__statement._reset()
                 self.__connection._in_transaction = \
                         not sqlite.sqlite3_get_autocommit(self.__connection._db)
                 raise self.__connection._get_exception(ret)
 
-            if self.__statement.kind == _DML:
-                self.__statement.reset()
+            if self.__statement._kind == Statement._DML:
+                self.__statement._reset()
 
-            if self.__statement.kind == _DQL and ret == SQLITE_ROW:
+            if self.__statement._kind == Statement._DQL and ret == SQLITE_ROW:
                 self.__statement._build_row_cast_map()
                 self.__statement._readahead(self)
             else:
-                self.__statement.item = None
-                self.__statement.exhausted = True
+                self.__statement._item = None
+                self.__statement._exhausted = True
 
             self.__rowcount = -1
-            if self.__statement.kind == _DML:
+            if self.__statement._kind == Statement._DML:
                 self.__rowcount = sqlite.sqlite3_changes(self.__connection._db)
         finally:
             self.__locked = False
 
         return self
 
+    @__check_cursor_wrap
     def executemany(self, sql, many_params):
-        self.__check_cursor()
         self.__locked = True
         try:
             self.__description = None
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
 
-            if self.__statement.kind == _DML:
+            if self.__statement._kind == Statement._DML:
                 if self.__connection._isolation_level is not None:
                     if not self.__connection._in_transaction:
                         self.__connection._begin()
 
             self.__rowcount = 0
             for params in many_params:
-                self.__statement.set_params(params)
-                ret = sqlite.sqlite3_step(self.__statement.statement)
+                self.__statement._set_params(params)
+                ret = sqlite.sqlite3_step(self.__statement._statement)
                 if ret != SQLITE_DONE:
-                    self.__statement.reset()
+                    self.__statement._reset()
                     self.__connection._in_transaction = \
                             not sqlite.sqlite3_get_autocommit(self.__connection._db)
                     raise self.__connection._get_exception(ret)
                 self.__rowcount += sqlite.sqlite3_changes(self.__connection._db)
-            self.__statement.reset()
+            self.__statement._reset()
         finally:
             self.__locked = False
 
             return None
 
         try:
-            return self.__statement.next(self)
+            return self.__statement._next(self)
         except StopIteration:
             return None
 
 
 
 class Statement(object):
-    statement = None
+    _DML, _DQL, _DDL = range(3)
+
+    _statement = None
 
     def __init__(self, connection, sql):
+        self.__con = connection
+
         if not isinstance(sql, str):
             raise ValueError("sql must be a string")
-        self.con = connection
-        self.sql = sql  # DEBUG ONLY
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
-            self.kind = _DML
+            self._kind = Statement._DML
         elif first_word in ("SELECT", "PRAGMA"):
-            self.kind = _DQL
+            self._kind = Statement._DQL
         else:
-            self.kind = _DDL
-        self.exhausted = False
-        self.in_use = False
-        #
-        # set by set_row_factory
-        self.row_factory = None
+            self._kind = Statement._DDL
 
-        self.statement = c_void_p()
+        self._in_use = False
+        self._exhausted = False
+        self._row_factory = None
+
+        self._statement = c_void_p()
         next_char = c_char_p()
         sql_char = sql
-        ret = sqlite.sqlite3_prepare_v2(self.con._db, sql_char, -1, byref(self.statement), byref(next_char))
-        if ret == SQLITE_OK and self.statement.value is None:
+        ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char))
+        if ret == SQLITE_OK and self._statement.value is None:
             # an empty statement, we work around that, as it's the least trouble
-            ret = sqlite.sqlite3_prepare_v2(self.con._db, "select 42", -1, byref(self.statement), byref(next_char))
-            self.kind = _DQL
+            ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char))
+            self._kind = Statement._DQL
 
         if ret != SQLITE_OK:
-            raise self.con._get_exception(ret)
-        self.con._remember_statement(self)
+            raise self.__con._get_exception(ret)
+        self.__con._remember_statement(self)
         next_char = next_char.value.decode('utf-8')
         if _check_remaining_sql(next_char):
             raise Warning("One and only one statement required: %r" %
                           (next_char,))
-        # sql_char should remain alive until here
 
-        self._build_row_cast_map()
+    def __del__(self):
+        if self._statement:
+            sqlite.sqlite3_finalize(self._statement)
 
-    def set_row_factory(self, row_factory):
-        self.row_factory = row_factory
+    def _finalize(self):
+        sqlite.sqlite3_finalize(self._statement)
+        self._statement = None
+        self._in_use = False
+
+    def _reset(self):
+        ret = sqlite.sqlite3_reset(self._statement)
+        self._in_use = False
+        self._exhausted = False
+        return ret
 
     def _build_row_cast_map(self):
-        self.row_cast_map = []
-        for i in range(sqlite.sqlite3_column_count(self.statement)):
+        self.__row_cast_map = []
+        for i in range(sqlite.sqlite3_column_count(self._statement)):
             converter = None
 
-            if self.con._detect_types & PARSE_COLNAMES:
-                colname = sqlite.sqlite3_column_name(self.statement, i)
+            if self.__con._detect_types & PARSE_COLNAMES:
+                colname = sqlite.sqlite3_column_name(self._statement, i)
                 if colname is not None:
                     colname = colname.decode('utf-8')
                     type_start = -1
                             key = colname[type_start:pos]
                             converter = converters[key.upper()]
 
-            if converter is None and self.con._detect_types & PARSE_DECLTYPES:
-                decltype = sqlite.sqlite3_column_decltype(self.statement, i)
+            if converter is None and self.__con._detect_types & PARSE_DECLTYPES:
+                decltype = sqlite.sqlite3_column_decltype(self._statement, i)
                 if decltype is not None:
                     decltype = decltype.split()[0]      # if multiple words, use first, eg. "INTEGER NOT NULL" => "INTEGER"
                     decltype = decltype.decode('utf-8')
                         decltype = decltype[:decltype.index('(')]
                     converter = converters.get(decltype.upper(), None)
 
-            self.row_cast_map.append(converter)
+            self.__row_cast_map.append(converter)
 
-    def set_param(self, idx, param):
+    def __set_param(self, idx, param):
         cvt = converters.get(type(param))
         if cvt is not None:
             cvt = param = cvt(param)
         param = adapt(param)
 
         if param is None:
-            sqlite.sqlite3_bind_null(self.statement, idx)
+            sqlite.sqlite3_bind_null(self._statement, idx)
         elif type(param) in (bool, int):
             if -2147483648 <= param <= 2147483647:
-                sqlite.sqlite3_bind_int(self.statement, idx, param)
+                sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
-                sqlite.sqlite3_bind_int64(self.statement, idx, param)
+                sqlite.sqlite3_bind_int64(self._statement, idx, param)
         elif type(param) is float:
-            sqlite.sqlite3_bind_double(self.statement, idx, param)
+            sqlite.sqlite3_bind_double(self._statement, idx, param)
         elif isinstance(param, str):
             param = param.encode('utf-8')
-            sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT)
+            sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif type(param) in (bytes, memoryview):
             param = bytes(param)
-            sqlite.sqlite3_bind_blob(self.statement, idx, param, len(param), SQLITE_TRANSIENT)
+            sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         else:
             raise InterfaceError("parameter type %s is not supported" %
                                  type(param))
 
-    def set_params(self, params):
-        ret = sqlite.sqlite3_reset(self.statement)
+    def _set_params(self, params):
+        ret = sqlite.sqlite3_reset(self._statement)
         if ret != SQLITE_OK:
-            raise self.con._get_exception(ret)
-        self.mark_dirty()
+            raise self.__con._get_exception(ret)
+        self._in_use = True
 
         if params is None:
-            if sqlite.sqlite3_bind_parameter_count(self.statement) != 0:
+            if sqlite.sqlite3_bind_parameter_count(self._statement) != 0:
                 raise ProgrammingError("wrong number of arguments")
             return
 
             params_type = list
 
         if params_type == list:
-            if len(params) != sqlite.sqlite3_bind_parameter_count(self.statement):
+            if len(params) != sqlite.sqlite3_bind_parameter_count(self._statement):
                 raise ProgrammingError("wrong number of arguments")
 
             for i in range(len(params)):
-                self.set_param(i+1, params[i])
+                self.__set_param(i+1, params[i])
         else:
-            for idx in range(1, sqlite.sqlite3_bind_parameter_count(self.statement) + 1):
-                param_name = sqlite.sqlite3_bind_parameter_name(self.statement, idx)
+            for idx in range(1, sqlite.sqlite3_bind_parameter_count(self._statement) + 1):
+                param_name = sqlite.sqlite3_bind_parameter_name(self._statement, idx)
                 if param_name is None:
                     raise ProgrammingError("need named parameters")
                 param_name = param_name[1:].decode('utf-8')
                     param = params[param_name]
                 except KeyError:
                     raise ProgrammingError("missing parameter %r" % param_name)
-                self.set_param(idx, param)
+                self.__set_param(idx, param)
 
-    def next(self, cursor):
-        self.con._check_closed()
-        self.con._check_thread()
-        if self.exhausted:
+    def _next(self, cursor):
+        self.__con._check_closed()
+        self.__con._check_thread()
+        if self._exhausted:
             raise StopIteration
-        item = self.item
+        item = self._item
 
-        ret = sqlite.sqlite3_step(self.statement)
+        ret = sqlite.sqlite3_step(self._statement)
         if ret == SQLITE_DONE:
-            self.exhausted = True
-            self.item = None
+            self._exhausted = True
+            self._item = None
         elif ret != SQLITE_ROW:
-            exc = self.con._get_exception(ret)
-            sqlite.sqlite3_reset(self.statement)
+            exc = self.__con._get_exception(ret)
+            sqlite.sqlite3_reset(self._statement)
             raise exc
 
         self._readahead(cursor)
         return item
 
     def _readahead(self, cursor):
-        self.column_count = sqlite.sqlite3_column_count(self.statement)
+        self.column_count = sqlite.sqlite3_column_count(self._statement)
         row = []
         for i in range(self.column_count):
-            typ = sqlite.sqlite3_column_type(self.statement, i)
+            typ = sqlite.sqlite3_column_type(self._statement, i)
 
-            converter = self.row_cast_map[i]
+            converter = self.__row_cast_map[i]
             if converter is None:
                 if typ == SQLITE_INTEGER:
-                    val = sqlite.sqlite3_column_int64(self.statement, i)
+                    val = sqlite.sqlite3_column_int64(self._statement, i)
                     if -sys.maxsize-1 <= val <= sys.maxsize:
                         val = int(val)
                 elif typ == SQLITE_FLOAT:
-                    val = sqlite.sqlite3_column_double(self.statement, i)
+                    val = sqlite.sqlite3_column_double(self._statement, i)
                 elif typ == SQLITE_BLOB:
-                    blob_len = sqlite.sqlite3_column_bytes(self.statement, i)
-                    blob = sqlite.sqlite3_column_blob(self.statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
                     val = bytes(string_at(blob, blob_len))
                 elif typ == SQLITE_NULL:
                     val = None
                 elif typ == SQLITE_TEXT:
-                    text_len = sqlite.sqlite3_column_bytes(self.statement, i)
-                    text = sqlite.sqlite3_column_text(self.statement, i)
+                    text_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    text = sqlite.sqlite3_column_text(self._statement, i)
                     val = string_at(text, text_len)
-                    val = self.con.text_factory(val)
+                    val = self.__con.text_factory(val)
             else:
-                blob = sqlite.sqlite3_column_blob(self.statement, i)
+                blob = sqlite.sqlite3_column_blob(self._statement, i)
                 if not blob:
                     val = None
                 else:
-                    blob_len = sqlite.sqlite3_column_bytes(self.statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
                     val = string_at(blob, blob_len)
                     val = converter(val)
             row.append(val)
 
         row = tuple(row)
-        if self.row_factory is not None:
-            row = self.row_factory(cursor, row)
-        self.item = row
-
-    def reset(self):
-        self.row_cast_map = None
-        ret = sqlite.sqlite3_reset(self.statement)
-        self.in_use = False
-        self.exhausted = False
-        return ret
-
-    def finalize(self):
-        sqlite.sqlite3_finalize(self.statement)
-        self.statement = None
-        self.in_use = False
-
-    def mark_dirty(self):
-        self.in_use = True
-
-    def __del__(self):
-        if self.statement:
-            sqlite.sqlite3_finalize(self.statement)
+        if self._row_factory is not None:
+            row = self._row_factory(cursor, row)
+        self._item = row
 
     def _get_description(self):
-        if self.kind == _DML:
+        if self._kind == Statement._DML:
             return None
         desc = []
-        for i in range(sqlite.sqlite3_column_count(self.statement)):
-            col_name = sqlite.sqlite3_column_name(self.statement, i)
+        for i in range(sqlite.sqlite3_column_count(self._statement)):
+            col_name = sqlite.sqlite3_column_name(self._statement, i)
             name = col_name.decode('utf-8').split("[")[0].strip()
             desc.append((name, None, None, None, None, None, None))
         return desc

pypy/module/test_lib_pypy/test_sqlite3.py

      cur.close()
      con.close()
      pytest.raises(_sqlite3.ProgrammingError, "cur.close()")
+     # raises ProgrammingError because should check closed before check args
+     pytest.raises(_sqlite3.ProgrammingError, "cur.execute(1,2,3,4,5)")
+     pytest.raises(_sqlite3.ProgrammingError, "cur.executemany(1,2,3,4,5)")
 
 @pytest.mark.skipif("not hasattr(sys, 'pypy_translation_info')")
 def test_cursor_del():
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.