Commits

Owen Nelson  committed 983d910

[#45908471 #45908479 #45908481] fleshing out the driver/result classes and added a small script to help test the driver API

  • Participants
  • Parent commits 10a65b5
  • Branches omn

Comments (0)

Files changed (2)

File cloudstudio/driver.py

+import logging
 from collections import namedtuple
 from PySide import QtSql
 from google.storage.speckle.python.api import rdbms_googleapi as dbi
 
 
+LOG = logging.getLogger()
+
 # When a cursor has executed a statement, it will have a description
 # property with a list of values describing the data it contains.
-ResultDescription = namedtuple(
-    'ResultDescription',
+ColumnDescription = namedtuple(
+    'ColumnDescription',
     'name type_code display_size internal_size precision scale null_ok')
 
 
-class CloudSqlRecord(QtSql.QSqlRecord):
-    """ A row of data returned by a query
-    http://srinikom.github.com/pyside-docs/PySide/QtSql/QSqlRecord.html
-    """
-    # TODO: MASSIVE AMOUNTS OF IMPLEMENTATION
-    def __init__(self, other):
-        self.setReadOnly(True)
-        self.setGenerated(True)
+connections = {}
 
 
 class CloudSqlResult(QtSql.QSqlResult):
     http://srinikom.github.com/pyside-docs/PySide/QtSql/QSqlResult.html
     for interface details.
     """
-    _driver = None
+    bindingSyntax = QtSql.QSqlResult.PositionalBinding
     _cursor = None
+    _description = None
     _idx = None
     _row = None
+    _err = None
+    _query = None
 
     def __init__(self, db):
-        self._driver = db
+        super(CloudSqlResult, self).__init__(db)
         self.setForwardOnly(True)
+        self._cursor = connections[self.driver()].cursor()
 
-    @classmethod
-    def bindingSyntax(cls):
-        # we're using an odbc-style interface
-        cls.PositionalBinding
+    @property
+    def cursor(self):
+        return self._cursor
 
     def data(self, i):
-        return self._row[i]
+        return self._row.value(i)
 
     def exec_(self):
         try:
-            self._cursor = self.driver().cursor()
-            self._cursor.execute(self._query)
-        except:  # TODO: handle error messages
+            self.cursor.execute(self._query)
+            self._description = map(ColumnDescription._make, self.cursor.description)
+            # self.setActive(True)
+        except Exception as exc:
+            self.setLastError(str(exc))
             return False
         return True
 
     def fetchLast(self):
         raise NotImplementedError
 
-    def fetchNext(self):
-        self._row = self._cursor.fetchone()
-        self._idx = 0 if self._idx is None else self._idx + 1
-        return self._row
-
     def fetchPrevious(self):
         raise NotImplementedError
 
+    def fetchNext(self):
+        try:
+            row = self.cursor.fetchone()
+            if row:
+                rec = QtSql.QSqlRecord()
+                for i, col in enumerate(self._description):
+                    field = QtSql.QSqlField(col.name)
+                    field.setValue(row[i])
+                    rec.append(field)
+                self._idx = 0 if self._idx is None else self._idx + 1
+                self._row = rec
+            return True
+        except Exception as exc:
+            self.setLastError(str(exc))
+        return False
+
     def handle(self):
-        pass
+        raise NotImplementedError
 
-    def isNull(self, i):
-        return self._row[i] is None
+    # def isNull(self, i):
+    #     return self._row.value(i) is None
 
     def lastInsertId(self):
         raise NotImplementedError
 
     def numRowsAffected(self):
-        return self._cursor.rowcount if not self.isSelect() else -1
+        return self.cursor.rowcount if not self.isSelect() else -1
 
-    def prepare(self, query):
-        raise NotImplementedError
+    # def prepare(self, query):
+    #     raise NotImplementedError
 
     def record(self):
-        return self._row
+        self.fetchNext()
+        record = self._row or QtSql.QSqlRecord()
+        return record
 
     def at(self):
         return self._idx
 
     def reset(self, sqlquery):
-        self._idx = None
-        self._row = None
+        self._idx = self._description = self._err = self._query = None
+        self._last_query = self._row = None
         self.setQuery(sqlquery)
-
         self.exec_()
+        return True
 
-    def savePrepare(self, sqlquery):
-        pass
-
-    def setLastError(self, e):
-        pass
+    # def savePrepare(self, sqlquery):
+    #     raise NotImplementedError
 
     def setQuery(self, query):
         self._last_query, self._query = self._query, query
             (self._query or '').strip().lower().startswith('select'))
 
     def size(self):
-        return self._cursor.rowcount if self.isSelect() else -1
+        return self.cursor.rowcount if self.isSelect() else -1
 
     def clear(self):
-        pass
-
-    def detachFromResultSet(self):
-        pass
+        self._description = None
+        self._idx = None
+        self._row = None
+        self.setActive(False)
 
-    def driver(self):
-        return self._driver
+    # def detachFromResultSet(self):
+    #     raise NotImplementedError
 
-    def execBatch(self, arrayBind=False):
-        pass
+    # def execBatch(self, arrayBind=False):
+    #     raise NotImplementedError
 
     def executedQuery(self):
         return self._query or ''
 
-    def hasOutValues(self):
-        pass
+    # def hasOutValues(self):
+    #     pass
 
-    def isValid(self):
-        return self._row is not None
-
-    def lastError(self):
-        pass
+    # def isValid(self):
+    #     return self._row is not None
 
     def lastQuery(self):
         return self._last_query
 
     def nextResult(self):
-        return self._row
-
-    def numericalPrecisionPolicy(self):
-        pass
+        # I think this might be supposed to return True when there is
+        # known to be another result in the data set (but I'm not sure).
+        return self._idx < (self.size() - 1)
 
-    def resetBindCount(self):
-        pass
-
-    def setNumericalPrecisionPolicy(self, policy):
-        pass
+    # def resetBindCount(self):
+    #     raise NotImplementedError
 
 
 class CloudSqlDriver(QtSql.QSqlDriver):
     http://srinikom.github.com/pyside-docs/PySide/QtSql/QSqlDriver.html#PySide.QtSql.QSqlDriver
     for details on expected return values.
     """
-    _conn = None
-    _cursor = None
+
+    @property
+    def cursor(self):
+        return connections[self].cursor()
+
+    def __init__(self, *args, **kwargs):
+        self._instance = kwargs.pop('instance')
+        self._dbname = kwargs.pop('dbname')
+        self._user = kwargs.pop('user', '')
+        self._password = kwargs.pop('password', '')
+
+        # note that connections is from the module scope
+        connections[self] = dbi.connect(
+            '',  # host
+            self._instance,
+            database=self._dbname,
+            user=self._user,
+            password=self._password)
+
+        super(CloudSqlDriver, self).__init__(*args, **kwargs)
 
     def beginTransaction(self):
         # Python db api does not provide an interface for explicitly
         pass
 
     def close(self):
-        if self._conn is not None:
-            self._conn.close()
-            self._conn = self._cursor = None
+        if connections[self] is not None:
+            connections[self].close()
+            connections[self] = connections[self] = None
+        self.setOpen(False)
 
     def commitTransaction(self):
-        self._conn.commit()
+        connections[self].commit()
 
     def createResult(self):
         return CloudSqlResult(self)
     def escapeIdentifier(self, identifier, type):
         return '`{}`'.format(identifier)
 
-    def formatValue(self, field, trimStrings=False):
-        pass
-
-    def hasFeature(self, f):
-        pass
+    # def formatValue(self, field, trimStrings=False):
+    #     raise NotImplementedError
 
-    def isOpen(self):
-        return self._cursor and self._cursor._open
+    # def hasFeature(self, f):
+    #     raise NotImplementedError
 
-    def open(self, db, instance, host="", user="", password=""):
-        self._conn = dbi.connect(
-            host, instance, database=db, user=user, password=password)
-        self._cursor = self._conn.cursor()
+    def open(self, *args, **kwargs):
+        try:
+            self.cursor
+            self.setOpen(True)
+            return True
+        except Exception as exc:
+            self.setOpenError(str(exc))
+            return False
 
     def primaryIndex(self, tableName):
-        self._cursor.execute("show keys from {} where Key_name = 'PRIMARY'".format(tableName))
-        pk_cols = tuple(col[4] for col in self._cursor.fetchall())
+        self.cursor.execute("show keys from {} where Key_name = 'PRIMARY'".format(tableName))
+        pk_cols = tuple(col[4] for col in self.cursor.fetchall())
         return QtSql.QSqlIndex(name=','.join(pk_cols))
 
     def record(self, tableName):
-        self._cursor.execute('select * from {} limit 0'.format(tableName))
-        # TODO: return QSqlRecord, not tuple
-        return tuple(col[0] for col in self._cursor.description)
+        self.cursor.execute('select * from {} limit 0'.format(tableName))
+        rec = QtSql.QSqlRecord()
+        for col in self.cursor.description:
+            rec.append(QtSql.QSqlField(col[0]))
+        return rec
 
     def rollbackTransaction(self):
-        self._conn.rollback()
-
-    def setLastError(self, e):
-        pass
-
-    def setOpen(self, o):
-        pass
-
-    def setOpenError(self, e):
-        pass
+        connections[self].rollback()
 
     def sqlStatement(self, type, tableName, rec, preparedStatement):
-        pass
+        raise NotImplementedError
 
     def tables(self, tableType):
-        self._cursor.execute('show tables')
-        return tuple(row[0] for row in self._cursor.fetchall())
+        self.cursor.execute('show tables')
+        return tuple(row[0] for row in self.cursor.fetchall())
 
 
 class CloudSqlDriverCreatorBase(QtSql.QSqlDriverCreatorBase):
         return CloudSqlDriver()
 
 
-_driver = CloudSqlDriverCreatorBase()
-QtSql.QSqlDatabase.registerSqlDriver("CLOUDSQL", _driver)
-db = QtSql.QSqlDatabase.addDatabase("CLOUDSQL")
+def setupCloudSql():
+    _driver = CloudSqlDriverCreatorBase()
+    QtSql.QSqlDatabase.registerSqlDriver("CLOUDSQL", _driver)

File model-test.py

+#!/usr/bin/env python
+
+from PySide import QtCore, QtGui, QtSql
+from cloudstudio.driver import CloudSqlDriver
+
+
+def createConnection():
+    driver = CloudSqlDriver(
+        dbname='omn',
+        instance='qualitydistribution.com:qdi-mysql01:dev')
+    db = QtSql.QSqlDatabase.addDatabase(driver)
+    if not db.open():
+        QtGui.QMessageBox.critical(
+            None,
+            QtGui.qApp.tr("Cannot open database"),
+            QtGui.qApp.tr("Unable to establish a database connection.\n"
+                          "This example needs SQLite support. Please read "
+                          "the Qt SQL driver documentation for information "
+                          "how to build it.\n\nClick Cancel to exit."),
+            QtGui.QMessageBox.Cancel, QtGui.QMessageBox.NoButton)
+        return False
+    return True
+
+
+class CustomSqlModel(QtSql.QSqlQueryModel):
+    def data(self, index, role):
+        value = super(CustomSqlModel, self).data(index, role)
+        if value is not None and role == QtCore.Qt.DisplayRole:
+            if index.column() == 0:
+                return '#%d' % value
+            elif index.column() == 2:
+                return value.upper()
+
+        if role == QtCore.Qt.TextColorRole and index.column() == 1:
+            return QtGui.QColor(QtCore.Qt.blue)
+
+        return value
+
+
+def initializeModel(model):
+    print 'init model'
+    model.setQuery('show tables')
+    print 'set query'
+    rec = model.record()
+    print 'building record'
+    for idx in xrange(rec.count()):
+        model.setHeaderData(idx, QtCore.Qt.Horizontal, rec.fieldName(idx))
+
+offset = 0
+views = []
+
+
+def createView(title, model):
+    global offset, views
+
+    view = QtGui.QTableView()
+    views.append(view)
+    view.setModel(model)
+    view.setWindowTitle(title)
+    view.move(100 + offset, 100 + offset)
+    offset += 20
+    view.show()
+
+
+if __name__ == '__main__':
+
+    import sys
+    app = QtGui.QApplication(sys.argv)
+    if not createConnection():
+        sys.exit(1)
+    print 'got connection'
+    plainModel = QtSql.QSqlQueryModel()
+    # customModel = CustomSqlModel()
+    initializeModel(plainModel)
+    # initializeModel(customModel)
+    createView("Plain Query Model", plainModel)
+    # createView("Custom Query Model", customModel)
+    sys.exit(app.exec_())