Source

roktar / roktar / backends / sqlite_backend.py

Full commit
import os
import sqlite3
from cPickle import loads, dumps
import time
from uuid import uuid4

from _interface import RoktarBackend
from roktar import errors


def _entity_factory(cursor, row):
    """Turn a result row into a dict keyed by column name.

    Meant to be installed as a connection's ``row_factory``; column names
    are taken from the first item of each ``cursor.description`` entry.
    """
    mapping = {}
    for position, column in enumerate(cursor.description):
        mapping[column[0]] = row[position]
    return mapping


# DDL for a fresh database. SQLiteBackend.init runs it with executescript()
# when sqlite_master contains no schema SQL yet.
#   Entities - one row per entity; cached_value holds the pickled data dict
#              written by create_entity.
#   Fields   - one row per (entity, field name); the value is written to the
#              column chosen via _field_name_type_map. NOTE(review):
#              text_value_for_indexing is not written by the code visible
#              here - confirm whether another component fills it.
#   Types    - field name -> FIELD_* type string; used to reject a value of a
#              different type for an already-known field name.
#   Changes  - per-field change log; old_value/new_value hold pickled values
#              and parent_uuid is set from the previously written change.
schema_definition = """
CREATE TABLE Entities (
    entity_id INTEGER PRIMARY KEY,
    time_created INTEGER,
    cached_value BLOB
);

CREATE TABLE Fields (
    field_id INTEGER PRIMARY KEY,
    entity_id INTEGER,
    name TEXT,
    integer_value INTEGER,
    real_value REAL,
    text_value TEXT,
    text_value_for_indexing TEXT
);

CREATE TABLE Types (
    name TEXT PRIMARY KEY,
    type TEXT
);

CREATE TABLE Changes (
    id INTEGER PRIMARY KEY,
    change_uuid TEXT,
    parent_uuid TEXT,
    change_time INTEGER,
    entity_id INTEGER,
    field_name TEXT,
    old_value BLOB,
    new_value BLOB
);

CREATE INDEX Entities_time_created ON Entities (time_created ASC);

CREATE INDEX Fields_entity_id ON Fields (entity_id ASC);
CREATE INDEX Fields_integer_value ON Fields (integer_value ASC);
CREATE INDEX Fields_real_value ON Fields (real_value ASC);
CREATE INDEX Fields_text_value1 ON Fields (text_value ASC);
CREATE INDEX Fields_text_value2 ON Fields (text_value_for_indexing ASC);

CREATE INDEX Changes_change_uuid ON Changes (change_uuid ASC);
CREATE INDEX Changes_entity_id ON Changes (entity_id ASC);
"""


# Logical field types. The type chosen for a field name is recorded in the
# Types table.
FIELD_INTEGER = "INTEGER"
FIELD_REAL = "REAL"
FIELD_TEXT = "TEXT"
FIELD_PICKLE = "PICKLE"

# Column of the Fields table that receives a value of each logical type.
# Pickled values have no dedicated column and share text_value with text.
_field_name_type_map = dict([
    (FIELD_INTEGER, "integer_value"),
    (FIELD_REAL, "real_value"),
    (FIELD_TEXT, "text_value"),
    (FIELD_PICKLE, "text_value"),
])


class BrokerMixin:
    """Small SQL helpers shared by the SQLite backend.

    The host class must set ``self.db_connection`` (a ``sqlite3``
    connection) and ``self.cursor`` (a cursor on that connection). Table and
    column names are pasted into the SQL text, so they must come from code,
    never from user input; values are always bound as ``?`` parameters.
    """

    def commit(self):
        """Commit the pending transaction of ``self.db_connection``."""
        self.db_connection.commit()

    def insert(self, table_name, **kwargs):
        """Insert one row into ``table_name``; ``kwargs`` map column -> value.

        The statement runs on the shared ``self.cursor``. When no columns
        are given, a row of column defaults is inserted.
        """
        if not kwargs:
            # "INSERT INTO t () VALUES ()" is not valid SQLite syntax.
            self.cursor.execute("INSERT INTO %s DEFAULT VALUES" % table_name)
            return
        # A single ordered list drives both the column list and the bound
        # values, and a real list is passed because sqlite3 does not accept
        # dict views (what .values() returns on Python 3) as parameters.
        columns = list(kwargs)
        self.cursor.execute(
            "INSERT INTO %s (%s) VALUES (%s)" % (
                table_name,
                ",".join(columns),
                ",".join("?" * len(columns))
            ), [kwargs[column] for column in columns]
        )

    def get_last_entity_id(self):
        """Return the largest ``entity_id`` in ``Entities`` (None if empty)."""
        return self.db_connection.execute(
            "SELECT MAX(entity_id) AS _max_id FROM Entities"
        ).fetchone()["_max_id"]

    def get_last_change_uuid(self):
        """Return the ``change_uuid`` of the ``Changes`` row with highest id.

        Returns None when the table is empty.
        """
        # Return the uuid itself, not the integer row id: the result is
        # stored as a parent_uuid, which must be comparable with the
        # change_uuid values other rows carry.
        row = self.db_connection.execute(
            "SELECT change_uuid FROM Changes ORDER BY id DESC LIMIT 1"
        ).fetchone()
        return row["change_uuid"] if row is not None else None

    def select(self, table_name, **conditions):
        """Run ``SELECT *`` on ``table_name``, ANDing ``column=value`` filters.

        Returns the shared ``self.cursor`` after executing; the next query
        issued through it replaces these results, so consume them first.
        Note that ``column=?`` never matches NULL, so passing None finds
        nothing.
        """
        if not conditions:
            return self.cursor.execute("SELECT * FROM %s" % table_name)
        columns = list(conditions)
        where_clause = " AND ".join("%s=?" % column for column in columns)
        return self.cursor.execute(
            "SELECT * FROM %s WHERE %s" % (table_name, where_clause),
            [conditions[column] for column in columns]
        )

    def select_one(self, table_name, key_name, identity):
        """Return the first row of ``table_name`` whose ``key_name`` equals
        ``identity``, or None when there is no such row."""
        return self.select(table_name, **{key_name: identity}).fetchone()


class SQLiteBackend(RoktarBackend, BrokerMixin):
    """Roktar backend storing entities in a single SQLite file.

    Each entity is kept as a pickled snapshot in ``Entities.cached_value``
    and as one ``Fields`` row per key; every field write also appends a row
    to ``Changes``. Stored values are unpickled on read, so the database
    file must only ever be written by trusted code.
    """

    def init(self):
        """Open ``entities.sqlite`` in the configured data folder.

        Creates the schema when the database has none yet, then applies the
        connection PRAGMAs and commits.
        """
        self.db_connection = sqlite3.connect(
            os.path.join(self._configuration.data_folder, "entities.sqlite")
        )
        self.db_connection.row_factory = _entity_factory
        self.cursor = cursor = self.db_connection.cursor()

        # Any stored table/index SQL means the schema already exists; entries
        # whose sql is NULL or empty (e.g. automatic indexes) do not count.
        has_schema = any(
            row["sql"]
            for row in cursor.execute("SELECT sql FROM sqlite_master")
        )
        if not has_schema:
            cursor.executescript(schema_definition)

        # NOTE(review): synchronous = OFF skips syncing to disk, so an OS
        # crash or power loss may lose or corrupt recent commits.
        cursor.executescript("""
            PRAGMA synchronous = OFF;
            PRAGMA temp_store = MEMORY;
            PRAGMA count_changes = OFF;
            PRAGMA locking_mode = EXCLUSIVE;
        """)

        self.commit()

    def close(self):
        """Close the underlying database connection."""
        self.db_connection.close()

    @staticmethod
    def _field_type_of(value):
        """Return the FIELD_* constant under which ``value`` is stored."""
        # bool is a subclass of int, so booleans become INTEGER fields.
        # NOTE(review): on Python 2 a ``long`` is not an ``int`` and is
        # stored as FIELD_PICKLE; kept as-is so existing Types stay valid.
        if isinstance(value, int):
            return FIELD_INTEGER
        if isinstance(value, float):
            return FIELD_REAL
        if isinstance(value, basestring):
            return FIELD_TEXT
        return FIELD_PICKLE

    def create_entity(self, data, author=None, data_for_indexing=None):
        """Store the dict ``data`` as a new entity and return its id.

        Writes the pickled snapshot to ``Entities``, one ``Fields`` row and
        one ``Changes`` row per key, and registers the type of every field
        name not yet in ``Types``. All rows are committed together; if
        anything fails, the rows written by this call are rolled back.

        Raises:
            errors.WrongFieldTypeError: if a value's type differs from the
                type already recorded for that field name.

        NOTE(review): ``author`` and ``data_for_indexing`` are accepted but
        unused, and ``Fields.text_value_for_indexing`` is never written -
        confirm whether that is intended.
        """
        # Check every field type before writing anything, so a mismatch
        # raises without leaving a half-written entity behind.
        fields = []
        new_types = []
        for key, value in data.iteritems():
            field_type = self._field_type_of(value)
            known = self.select("Types", name=key).fetchone()
            if known is None:
                new_types.append((key, field_type))
            elif known["type"] and known["type"] != field_type:
                raise errors.WrongFieldTypeError(key)
            fields.append((key, value, field_type))

        timestamp = int(time.time())
        committed = False
        try:
            # Pickles bound to BLOB columns are wrapped in sqlite3.Binary: on
            # Python 2 with the default text_factory, sqlite3 refuses to bind
            # a str holding non-ASCII bytes, which pickling unicode can yield.
            self.insert(
                "Entities",
                time_created=timestamp,
                cached_value=sqlite3.Binary(dumps(data))
            )
            entity_id = self.get_last_entity_id()

            for name, field_type in new_types:
                self.insert("Types", name=name, type=field_type)

            parent_change = self.get_last_change_uuid()
            for key, value, field_type in fields:
                row = dict(entity_id=entity_id, name=key)
                # NOTE(review): pickled field values still go to the TEXT
                # column as plain str (unchanged, since other readers may
                # compare text_value as text); they share the Python 2
                # non-ASCII limitation described above.
                row[_field_name_type_map[field_type]] = (
                    dumps(value) if field_type == FIELD_PICKLE else value
                )
                self.insert("Fields", **row)

                change_uuid = uuid4().hex
                self.insert(
                    "Changes",
                    change_uuid=change_uuid,
                    parent_uuid=parent_change,
                    change_time=timestamp,
                    entity_id=entity_id,
                    field_name=key,
                    old_value=sqlite3.Binary(dumps(None)),
                    new_value=sqlite3.Binary(dumps(value))
                )
                parent_change = change_uuid

            self.commit()
            committed = True
        finally:
            if not committed:
                # Discard the uncommitted rows so a later commit() cannot
                # persist a partial entity.
                self.db_connection.rollback()

        return entity_id

    def get_entity(self, entity_id):
        """Return the unpickled snapshot stored for ``entity_id``.

        Raises:
            KeyError: if no entity with that id exists.
        """
        data = self.select_one("Entities", "entity_id", entity_id)
        if data is None:
            raise KeyError(entity_id)
        # bytes() yields the raw pickle both for BLOB values and, on
        # Python 2 (where bytes is str), for rows stored as ASCII text.
        return loads(bytes(data["cached_value"]))

    def get_changes(self, entity_id):
        """Return a cursor over the ``Changes`` rows of ``entity_id``.

        Rows are ordered by ``id`` (insertion order). A dedicated cursor is
        used so later queries through the shared ``self.cursor`` do not
        discard these results.
        """
        return self.db_connection.execute(
            "SELECT * FROM Changes WHERE entity_id=? ORDER BY id",
            (entity_id,)
        )

    def get_all_entities(self):
        """Return a generator of ``(entity_id, data)`` for every entity.

        The query runs on its own cursor, so queries issued through the
        shared ``self.cursor`` during iteration (e.g. by ``get_entity``) do
        not reset it.
        """
        rows = self.db_connection.execute(
            "SELECT entity_id, cached_value FROM Entities"
        )
        return (
            (row["entity_id"], loads(bytes(row["cached_value"])))
            for row in rows
        )