Commits

Brian Kearns  committed 878a364 Merge

merge default

  • Participants
  • Parent commits 6938c45, 51cce46
  • Branches py3k

Comments (0)

Files changed (1)

File lib_pypy/_sqlite3.py

     return str(x, 'utf-8')
 
 
-class StatementCache(object):
+class _StatementCache(object):
     def __init__(self, connection, maxcount):
         self.connection = connection
         self.maxcount = maxcount
             self.cache[sql] = stat
             if len(self.cache) > self.maxcount:
                 self.cache.popitem(0)
-        #
+
         if stat.in_use:
             stat = Statement(self.connection, sql)
         stat.set_row_factory(row_factory)
         self._cursors = []
         self.__statements = []
         self.__statement_counter = 0
-        self._statement_cache = StatementCache(self, cached_statements)
+        self._statement_cache = _StatementCache(self, cached_statements)
 
         self.__func_cache = {}
         self.__aggregates = {}
 
     def _check_closed_wrap(func):
         @wraps(func)
-        def _check_closed_func(self, *args, **kwargs):
+        def wrapper(self, *args, **kwargs):
             self._check_closed()
             return func(self, *args, **kwargs)
-        return _check_closed_func
+        return wrapper
 
     def _check_thread(self):
         try:
 
     def _check_thread_wrap(func):
         @wraps(func)
-        def _check_thread_func(self, *args, **kwargs):
+        def wrapper(self, *args, **kwargs):
             self._check_thread()
             return func(self, *args, **kwargs)
-        return _check_thread_func
+        return wrapper
 
     def _get_exception(self, error_code=None):
         if error_code is None:
     def __call__(self, sql):
         if not isinstance(sql, str):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
-        statement = self._statement_cache.get(sql, self.row_factory)
-        return statement
+        return self._statement_cache.get(sql, self.row_factory)
 
     def cursor(self, factory=None):
         self._check_thread()
                 raise OperationalError("Error enabling load extension")
 
 
-class _CursorLock(object):
-    def __init__(self, cursor):
-        self.cursor = cursor
-
-    def __enter__(self):
-        self.cursor._check_closed()
-        if self.cursor._locked:
-            raise ProgrammingError("Recursive use of cursors not allowed.")
-        self.cursor._locked = True
-
-    def __exit__(self, *args):
-        self.cursor._locked = False
-
-
 class Cursor(object):
     __initialized = False
     __connection = None
 
         self.arraysize = 1
         self.row_factory = None
-        self._locked = False
         self._reset = False
+        self.__locked = False
         self.__closed = False
         self.__description = None
         self.__rowcount = -1
             self.__statement = None
         self.__closed = True
 
-    def _check_closed(self):
+    def __check_cursor(self):
         if not self.__initialized:
             raise ProgrammingError("Base Cursor.__init__ not called.")
         if self.__closed:
             raise ProgrammingError("Cannot operate on a closed cursor.")
+        if self.__locked:
+            raise ProgrammingError("Recursive use of cursors not allowed.")
         self.__connection._check_thread()
         self.__connection._check_closed()
 
     def execute(self, sql, params=None):
-        with _CursorLock(self):
+        self.__check_cursor()
+        self.__locked = True
+        try:
             self.__description = None
             self._reset = False
             self.__statement = self.__connection._statement_cache.get(
             self.__rowcount = -1
             if self.__statement.kind == _DML:
                 self.__rowcount = sqlite.sqlite3_changes(self.__connection._db)
+        finally:
+            self.__locked = False
 
         return self
 
     def executemany(self, sql, many_params):
-        with _CursorLock(self):
+        self.__check_cursor()
+        self.__locked = True
+        try:
             self.__description = None
             self._reset = False
             self.__statement = self.__connection._statement_cache.get(
                     raise self.__connection._get_exception(ret)
                 self.__rowcount += sqlite.sqlite3_changes(self.__connection._db)
             self.__statement.reset()
+        finally:
+            self.__locked = False
 
         return self
 
         self._reset = False
         if type(sql) is str:
             sql = sql.encode("utf-8")
-        self._check_closed()
+        self.__check_cursor()
         statement = c_void_p()
         c_sql = c_char_p(sql)
 
                 break
         return self
 
-    def _check_reset(self):
+    def __check_reset(self):
         if self._reset:
-            raise self.__connection.InterfaceError("Cursor needed to be reset because "
-                                                 "of commit/rollback and can "
-                                                 "no longer be fetched from.")
+            raise self.__connection.InterfaceError(
+                    "Cursor needed to be reset because of commit/rollback "
+                    "and can no longer be fetched from.")
 
     # do all statements
     def fetchone(self):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
 
         if self.__statement is None:
             return None
             return None
 
     def fetchmany(self, size=None):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
         if self.__statement is None:
             return []
         if size is None:
         return lst
 
     def fetchall(self):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
         if self.__statement is None:
             return []
         return list(self)