Commits

desmaj committed 10cbafd

added the MS-TDS spec as a reference; added a simple sqlalchemy dialect; added commit and rollback to the connection; added a query message; changed the cursor.execute signature to match the dbapi spec

Comments (0)

Files changed (11)

 
 def connect(host, port, username, password, database='master', tdsversion=TDS71):
     return TDSConnection(host, port, username, password, database, tdsversion)
+
+paramstyle = 'qmark'
                  database='master', tdsversion=common.TDS71):
         self.tds = TDS(host, port, tdsversion)
         self.tds.login(host, username, password, database)
+        self.tds.query('SET IMPLICIT_TRANSACTIONS ON')
     
     def cursor(self):
         return cursor.Cursor(self)
     
+    def commit(self):
+        self.tds.query('COMMIT')
+    
+    def rollback(self):
+        self.tds.query('IF @@TRANCOUNT > 0 ROLLBACK')
         self._context = None
     
     @property
+    def description(self):
+        return [(col.name, col.type.code, None, None, None, None, None)
+                for col in self._context.columns
+                if col.name != u'ROWSTAT']
+    
+    @property
     def rowcount(self):
         return self._context.rowcount
     
-    def execute(self, sql, *params):
+    def execute(self, sql, params=None):
+        if params is None:
+            params = []
+        
         if self._context and self._context.cursor_id:
             self.close()
         
             mac_bytes.append(0)
         return list(reversed(mac_bytes))
 
+class QueryMessage(object):
+    type = 0x01
+    
+    def __init__(self, text):
+        self.text = text
+    
+    def pack(self, tds):
+        message = ''
+        message += self.text.encode('utf_16_le')
+        return message
+
 class RPCArgument(object):
     
     def __init__(self, name, value, output=False):

ptds/sqlalchemy/__init__.py

Empty file added.

ptds/sqlalchemy/dialect.py

+# mssql/ptds.py
+# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Support for the ptds dialect.
+
+This dialect supports ptds 1.0 and greater.
+
+ptds is available at:
+
+    http://ptds.sourceforge.net/
+
+Connecting
+^^^^^^^^^^
+
+Sample connect string::
+
+    mssql+ptds://<username>:<password>@<freetds_name>
+
+Adding "?charset=utf8" or similar will cause ptds to return
+strings as Python unicode objects.   This can potentially improve 
+performance in some scenarios as decoding of strings is 
+handled natively.
+
+Limitations
+^^^^^^^^^^^
+
+ptds inherits a lot of limitations from FreeTDS, including:
+
+* no support for multibyte schema identifiers
+* poor support for large decimals
+* poor support for binary fields
+* poor support for VARCHAR/CHAR fields over 255 characters
+
+Please consult the ptds documentation for further information.
+ 
+"""
+from sqlalchemy.dialects.mssql.base import MSDialect
+from sqlalchemy import types as sqltypes, util, processors
+import re
+
+class MSDialect_ptds(MSDialect):
+    supports_sane_rowcount = False
+    driver = 'ptds'
+
+    @classmethod
+    def dbapi(cls):
+        return __import__('ptds')
+        
+    def __init__(self, **params):
+        super(MSDialect_ptds, self).__init__(**params)
+        self.use_scope_identity = True
+
+    def _get_server_version_info(self, connection):
+        vers = connection.scalar("select @@version")
+        m = re.match(
+            r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
+        if m:
+            return tuple(int(x) for x in m.group(1, 2, 3, 4))
+        else:
+            return None
+
+    def is_disconnect(self, e, connection, cursor):
+        for msg in (
+            "Error 10054",
+            "Not connected to any MS SQL server",
+            "Connection is closed"
+        ):
+            if msg in str(e):
+                return True
+        else:
+            return False
+
+dialect = MSDialect_ptds
                         [self._hex_string(ord(char)) for char in version],
                         text)
 
+class TDSQueryContext(TDSContext):
+    pass
+
 class TDSCursorContext(TDSContext):
     
     def handle_Result(self, columns):
         context = TDSLoginContext(self)
         self._receive(context)
     
+    def query(self, text):
+        message = messages.QueryMessage(text)
+        self._send(message)
+        context = TDSQueryContext(self)
+        self._receive(context)
+    
     def rpc(self, context, name, *args):
         message = messages.RPCMessage(name, *args)
         self._send(message)

ptds/tests/functional/test_cursor.py

     def test_select_with_one_param(self):
         c = self.get_connection().cursor()
         c.execute(u'select * from sys.types where name = ?',
-                  u"image")
+                  [u"image"])
         rows = c.fetchall()
         c.close()
         assert 1 == len(rows)
         c = self.get_connection().cursor()
         c.execute(u'select * from sys.types'
                   ' where name = ? and system_type_id = ?',
-                  u"image", 34)
+                  [u"image", 34])
         rows = c.fetchall()
         c.close()
         assert 1 == len(rows)

ptds/tests/functional/test_sqlalchemy_dialect.py

+try:
+    import sqlalchemy.dialects.mssql
+    from ptds.sqlalchemy import dialect
+    sqlalchemy.dialects.mssql.ptds = dialect
+    
+    from ptds.tests import config
+    def test_create_engine():
+        host = config.get('connection', 'host')
+        port = config.get('connection', 'port')
+        username = config.get('connection', 'username')
+        password = config.get('connection', 'password')
+        database = config.get('connection', 'database')
+        
+        uri = 'mssql+ptds://%s:%s@%s:%s/%s' % (
+            username, password, host, port, database)
+        engine = sqlalchemy.create_engine(uri)
+        engine.execute("select 4 as Four")
+
+except ImportError:
+    pass

ptds/tests/functional/test_types.py

         tablename = self.tablename(typespec)
         c = connection.cursor()
         c.execute(u'INSERT INTO %s (col_1) values (?)' % tablename,
-                  testvalue)
+                  [testvalue])
         c.execute(u'SELECT col_1 FROM %s' % tablename)
         row = c.fetchone()
         _assert_testvalue(testvalue, row.col_1)
         tablename = self.tablename(typespec)
         c = connection.cursor()
         c.execute(u'INSERT INTO %s (col_1) values (?)' % tablename,
-                  nonevalue)
+                  [nonevalue])
         c.execute(u'SELECT col_1 FROM %s' % tablename)
         row = c.fetchone()
         assert row.col_1 is None, row.col_1
             self._test_insert_and_select_NULL(connection, typespec, nonevalue)
         finally:
             self._drop_table(connection, typespec)
+            connection.rollback()
         
     @staticmethod
     def _assert_space_padded(testvalue, actual):

spec/[MS-TDS].pdf

Binary file added.