Commits

Brian Kearns  committed a54b01b Merge

merge default

  • Participants
  • Parent commits 4ed015c, 8f3d8c5
  • Branches py3k

Comments (0)

Files changed (4)

File lib_pypy/_sqlite3.py

 from ctypes import POINTER, byref, string_at, CFUNCTYPE, cast
 from ctypes import sizeof, c_ssize_t
 from collections import OrderedDict
+from functools import wraps
 import datetime
 import sys
 import weakref
 PARSE_COLNAMES = 1
 PARSE_DECLTYPES = 2
 
+DML, DQL, DDL = range(3)
 
 ##########################################
 # BEGIN Wrapped SQLite C API and constants
         if self.db:
             sqlite.sqlite3_close(self.db)
 
+    def close(self):
+        self._check_thread()
+
+        for statement in self.statements:
+            obj = statement()
+            if obj is not None:
+                obj.finalize()
+
+        if self.db:
+            ret = sqlite.sqlite3_close(self.db)
+            if ret != SQLITE_OK:
+                raise self._get_exception(ret)
+            self.db.value = 0
+
+    def _check_closed(self):
+        if self.db is None:
+            raise ProgrammingError("Base Connection.__init__ not called.")
+        if not self.db:
+            raise ProgrammingError("Cannot operate on a closed database.")
+
+    def _check_closed_wrap(func):
+        @wraps(func)
+        def _check_closed_func(self, *args, **kwargs):
+            self._check_closed()
+            return func(self, *args, **kwargs)
+        return _check_closed_func
+
+    def _check_thread(self):
+        if not hasattr(self, 'thread_ident'):
+            return
+        if self.thread_ident != thread_get_ident():
+            raise ProgrammingError(
+                "SQLite objects created in a thread can only be used in that same thread."
+                "The object was created in thread id %d and this is thread id %d",
+                self.thread_ident, thread_get_ident())
+
+    def _check_thread_wrap(func):
+        @wraps(func)
+        def _check_thread_func(self, *args, **kwargs):
+            self._check_thread()
+            return func(self, *args, **kwargs)
+        return _check_thread_func
+
     def _get_exception(self, error_code=None):
         if error_code is None:
             error_code = sqlite.sqlite3_errcode(self.db)
         if self.statement_counter % 100 == 0:
             self.statements = [ref for ref in self.statements if ref() is not None]
 
-    def _check_thread(self):
-        if not hasattr(self, 'thread_ident'):
-            return
-        if self.thread_ident != thread_get_ident():
-            raise ProgrammingError(
-                "SQLite objects created in a thread can only be used in that same thread."
-                "The object was created in thread id %d and this is thread id %d",
-                self.thread_ident, thread_get_ident())
-
-    def _reset_cursors(self):
-        for cursor_ref in self.cursors:
-            cursor = cursor_ref()
-            if cursor:
-                cursor.reset = True
+    @_check_thread_wrap
+    @_check_closed_wrap
+    def __call__(self, sql):
+        if not isinstance(sql, str):
+            raise Warning("SQL is of wrong type. Must be string or unicode.")
+        statement = self.statement_cache.get(sql, self.row_factory)
+        return statement
 
     def cursor(self, factory=None):
         self._check_thread()
             cur.row_factory = self.row_factory
         return cur
 
+    def execute(self, *args):
+        cur = self.cursor()
+        return cur.execute(*args)
+
     def executemany(self, *args):
-        self._check_closed()
-        cur = Cursor(self)
-        if self.row_factory is not None:
-            cur.row_factory = self.row_factory
+        cur = self.cursor()
         return cur.executemany(*args)
 
-    def execute(self, *args):
-        self._check_closed()
-        cur = Cursor(self)
-        if self.row_factory is not None:
-            cur.row_factory = self.row_factory
-        return cur.execute(*args)
-
     def executescript(self, *args):
-        self._check_closed()
-        cur = Cursor(self)
-        if self.row_factory is not None:
-            cur.row_factory = self.row_factory
+        cur = self.cursor()
         return cur.executescript(*args)
 
-    def __call__(self, sql):
-        self._check_closed()
-        if not isinstance(sql, str):
-            raise Warning("SQL is of wrong type. Must be string or unicode.")
-        statement = self.statement_cache.get(sql, self.row_factory)
-        return statement
-
-    def _get_isolation_level(self):
-        return self._isolation_level
-
-    def _set_isolation_level(self, val):
-        if val is None:
-            self.commit()
-        if isinstance(val, str):
-            val = str(val)
-        self._isolation_level = val
-    isolation_level = property(_get_isolation_level, _set_isolation_level)
+    def iterdump(self):
+        from sqlite3.dump import _iterdump
+        return _iterdump(self)
 
     def _begin(self):
-        self._check_closed()
         if self._isolation_level is None:
             return
         if sqlite.sqlite3_get_autocommit(self.db):
             if obj is not None:
                 obj.reset()
 
+        for cursor_ref in self.cursors:
+            cursor = cursor_ref()
+            if cursor:
+                cursor.reset = True
+
         try:
             sql = "ROLLBACK"
             statement = c_void_p()
                 raise self._get_exception(ret)
         finally:
             sqlite.sqlite3_finalize(statement)
-            self._reset_cursors()
-
-    def _check_closed(self):
-        if self.db is None:
-            raise ProgrammingError("Base Connection.__init__ not called.")
-        if not self.db:
-            raise ProgrammingError("Cannot operate on a closed database.")
 
     def __enter__(self):
         return self
         else:
             self.rollback()
 
-    def _get_total_changes(self):
-        self._check_closed()
-        return sqlite.sqlite3_total_changes(self.db)
-    total_changes = property(_get_total_changes)
-
-    def close(self):
-        self._check_thread()
-
-        for statement in self.statements:
-            obj = statement()
-            if obj is not None:
-                obj.finalize()
-
-        if self.db:
-            ret = sqlite.sqlite3_close(self.db)
-            if ret != SQLITE_OK:
-                raise self._get_exception(ret)
-            self.db.value = 0
-
-    def create_collation(self, name, callback):
-        self._check_thread()
-        self._check_closed()
-        name = name.upper()
-        if not name.replace('_', '').isalnum():
-            raise ProgrammingError("invalid character in collation name")
-
-        if callback is None:
-            del self._collations[name]
-            c_collation_callback = cast(None, COLLATION)
-        else:
-            if not callable(callback):
-                raise TypeError("parameter must be callable")
-
-            def collation_callback(context, len1, str1, len2, str2):
-                text1 = string_at(str1, len1)
-                text2 = string_at(str2, len2)
-
-                return callback(text1, text2)
-
-            c_collation_callback = COLLATION(collation_callback)
-            self._collations[name] = c_collation_callback
-
-        ret = sqlite.sqlite3_create_collation(self.db, name,
-                                              SQLITE_UTF8,
-                                              None,
-                                              c_collation_callback)
-        if ret != SQLITE_OK:
-            raise self._get_exception(ret)
-
-    def set_progress_handler(self, callable, nsteps):
-        self._check_thread()
-        self._check_closed()
-        if callable is None:
-            c_progress_handler = cast(None, PROGRESS)
-        else:
-            try:
-                c_progress_handler, _ = self.func_cache[callable]
-            except KeyError:
-                def progress_handler(userdata):
-                    try:
-                        ret = callable()
-                        return bool(ret)
-                    except Exception:
-                        # abort query if error occurred
-                        return 1
-                c_progress_handler = PROGRESS(progress_handler)
-
-                self.func_cache[callable] = c_progress_handler, progress_handler
-        ret = sqlite.sqlite3_progress_handler(self.db, nsteps,
-                                              c_progress_handler,
-                                              None)
-        if ret != SQLITE_OK:
-            raise self._get_exception(ret)
-
-    def set_authorizer(self, callback):
-        self._check_thread()
-        self._check_closed()
-
-        try:
-            c_authorizer, _ = self.func_cache[callback]
-        except KeyError:
-            def authorizer(userdata, action, arg1, arg2, dbname, source):
-                try:
-                    return int(callback(action, arg1, arg2, dbname, source))
-                except Exception:
-                    return SQLITE_DENY
-            c_authorizer = AUTHORIZER(authorizer)
-
-            self.func_cache[callback] = c_authorizer, authorizer
-
-        ret = sqlite.sqlite3_set_authorizer(self.db,
-                                            c_authorizer,
-                                            None)
-        if ret != SQLITE_OK:
-            raise self._get_exception(ret)
-
+    @_check_thread_wrap
+    @_check_closed_wrap
     def create_function(self, name, num_args, callback):
-        self._check_thread()
-        self._check_closed()
         try:
             c_closure, _ = self.func_cache[callback]
         except KeyError:
         if ret != SQLITE_OK:
             raise self.OperationalError("Error creating function")
 
+    @_check_thread_wrap
+    @_check_closed_wrap
     def create_aggregate(self, name, num_args, cls):
-        self._check_thread()
-        self._check_closed()
-
         try:
             c_step_callback, c_final_callback, _, _ = self._aggregates[cls]
         except KeyError:
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
-    def iterdump(self):
-        from sqlite3.dump import _iterdump
-        return _iterdump(self)
+    @_check_thread_wrap
+    @_check_closed_wrap
+    def create_collation(self, name, callback):
+        name = name.upper()
+        if not name.replace('_', '').isalnum():
+            raise ProgrammingError("invalid character in collation name")
+
+        if callback is None:
+            del self._collations[name]
+            c_collation_callback = cast(None, COLLATION)
+        else:
+            if not callable(callback):
+                raise TypeError("parameter must be callable")
+
+            def collation_callback(context, len1, str1, len2, str2):
+                text1 = string_at(str1, len1)
+                text2 = string_at(str2, len2)
+
+                return callback(text1, text2)
+
+            c_collation_callback = COLLATION(collation_callback)
+            self._collations[name] = c_collation_callback
+
+        ret = sqlite.sqlite3_create_collation(self.db, name,
+                                              SQLITE_UTF8,
+                                              None,
+                                              c_collation_callback)
+        if ret != SQLITE_OK:
+            raise self._get_exception(ret)
+
+    @_check_thread_wrap
+    @_check_closed_wrap
+    def set_authorizer(self, callback):
+        try:
+            c_authorizer, _ = self.func_cache[callback]
+        except KeyError:
+            def authorizer(userdata, action, arg1, arg2, dbname, source):
+                try:
+                    return int(callback(action, arg1, arg2, dbname, source))
+                except Exception:
+                    return SQLITE_DENY
+            c_authorizer = AUTHORIZER(authorizer)
+
+            self.func_cache[callback] = c_authorizer, authorizer
+
+        ret = sqlite.sqlite3_set_authorizer(self.db,
+                                            c_authorizer,
+                                            None)
+        if ret != SQLITE_OK:
+            raise self._get_exception(ret)
+
+    @_check_thread_wrap
+    @_check_closed_wrap
+    def set_progress_handler(self, callable, nsteps):
+        if callable is None:
+            c_progress_handler = cast(None, PROGRESS)
+        else:
+            try:
+                c_progress_handler, _ = self.func_cache[callable]
+            except KeyError:
+                def progress_handler(userdata):
+                    try:
+                        ret = callable()
+                        return bool(ret)
+                    except Exception:
+                        # abort query if error occurred
+                        return 1
+                c_progress_handler = PROGRESS(progress_handler)
+
+                self.func_cache[callable] = c_progress_handler, progress_handler
+        ret = sqlite.sqlite3_progress_handler(self.db, nsteps,
+                                              c_progress_handler,
+                                              None)
+        if ret != SQLITE_OK:
+            raise self._get_exception(ret)
+
+    def _get_isolation_level(self):
+        return self._isolation_level
+
+    def _get_total_changes(self):
+        self._check_closed()
+        return sqlite.sqlite3_total_changes(self.db)
+    total_changes = property(_get_total_changes)
+
+    def _set_isolation_level(self, val):
+        if val is None:
+            self.commit()
+        if isinstance(val, str):
+            val = str(val)
+        self._isolation_level = val
+    isolation_level = property(_get_isolation_level, _set_isolation_level)
 
     if HAS_LOAD_EXTENSION:
+        @_check_thread_wrap
+        @_check_closed_wrap
         def enable_load_extension(self, enabled):
-            self._check_thread()
-            self._check_closed()
-
             rc = sqlite.sqlite3_enable_load_extension(self.db, int(enabled))
             if rc != SQLITE_OK:
                 raise OperationalError("Error enabling load extension")
 
-DML, DQL, DDL = range(3)
-
 
 class CursorLock(object):
     def __init__(self, cursor):

File pypy/interpreter/app_main.py

     # app-level, very different from rpython.rlib.objectmodel.we_are_translated
     return hasattr(sys, 'pypy_translation_info')
 
-if 'nt' in sys.builtin_module_names:
-    IS_WINDOWS = True
-else:
-    IS_WINDOWS = False
+IS_WINDOWS = 'nt' in sys.builtin_module_names
 
 def setup_and_fix_paths(ignore_environment=False, **extra):
     import os

File pypy/module/__builtin__/functional.py

     return w_start, w_stop, w_step
     
 min_jitdriver = jit.JitDriver(name='min',
-        greens=['w_type'], reds='auto')
+        greens=['has_key', 'has_item', 'w_type'], reds='auto')
 max_jitdriver = jit.JitDriver(name='max',
-        greens=['w_type'], reds='auto')
+        greens=['has_key', 'has_item', 'w_type'], reds='auto')
 
 def make_min_max(unroll):
     @specialize.arg(2)
 
         w_iter = space.iter(w_sequence)
         w_type = space.type(w_iter)
+        has_key = w_key is not None
+        has_item = False
         w_max_item = None
         w_max_val = None
         while True:
             if not unroll:
-                jitdriver.jit_merge_point(w_type=w_type)
+                jitdriver.jit_merge_point(has_key=has_key, has_item=has_item, w_type=w_type)
             try:
                 w_item = space.next(w_iter)
             except OperationError, e:
                 if not e.match(space, space.w_StopIteration):
                     raise
                 break
-            if w_key is not None:
+            if has_key:
                 w_compare_with = space.call_function(w_key, w_item)
             else:
                 w_compare_with = w_item
-            if w_max_item is None or \
+            if not has_item or \
                     space.is_true(compare(w_compare_with, w_max_val)):
+                has_item = True
                 w_max_item = w_item
                 w_max_val = w_compare_with
         if w_max_item is None:

File pypy/module/test_lib_pypy/test_sqlite3.py

     e = pytest.raises(_sqlite3.ProgrammingError, "cur.execute('select 1')")
     assert '__init__' in e.value.message
 
+def test_connection_after_close():
+    con = _sqlite3.connect(':memory:')
+    pytest.raises(TypeError, "con()")
+    con.close()
+    # raises ProgrammingError because should check closed before check args
+    pytest.raises(_sqlite3.ProgrammingError, "con()")
+
 def test_cursor_after_close():
      con = _sqlite3.connect(':memory:')
      cur = con.execute('select 1')
 @pytest.mark.skipif("not hasattr(sys, 'pypy_translation_info')")
 def test_connection_del(tmpdir):
     """For issue1325."""
+    import os
     import gc
     try:
         import resource
 
     limit = resource.getrlimit(resource.RLIMIT_NOFILE)
     try:
-        resource.setrlimit(resource.RLIMIT_NOFILE, (min(10, limit[0]), limit[1]))
+        fds = 0
+        while True:
+            fds += 1
+            resource.setrlimit(resource.RLIMIT_NOFILE, (fds, limit[1]))
+            try:
+                for p in os.pipe(): os.close(p)
+            except OSError:
+                assert fds < 100
+            else:
+                break
         def open_many(cleanup):
             con = []
-            for i in range(20):
+            for i in range(3):
                 con.append(_sqlite3.connect(str(tmpdir.join('test.db'))))
                 if cleanup:
                     con[i] = None