# htsql/src/htsql/core/tr/

# Copyright (c) 2006-2013, Prometheus Research, LLC

# NOTE(review): several import lines had lost their module path
# (`from import X`) and `parse` was not imported at all; the paths below
# were restored from the package layout -- verify against the repository.
from ..util import listof
from ..context import context
from ..domain import ListDomain, RecordDomain, Profile, Product
from ..syn.syntax import Syntax
from ..syn.parse import parse
from .bind import bind
from .binding import Binding
from .encode import encode
from .flow import OrderedFlow
from .rewrite import rewrite
from .compile import compile
from .assemble import assemble
from .reduce import reduce
from .dump import serialize
from .plan import Plan, Statement
from ..connect import transaction, scramble, unscramble
from ..error import PermissionError

class RowStream(object):
    """
    A fully fetched result set of a statement with a read position.

    `rows`
        The list of rows, already converted to native values.

    `substreams`
        A list of `RowStream` objects, one per substatement.

    `top`
        Index of the current row in `rows`.

    `last_top`, `last_key`
        Start position and key of the most recent `slice()` call; used
        to replay a slice when the same key is requested again.
    """

    # NOTE(review): this class was restored from text with missing tokens
    # (`@classmethod`, `statement.domains`, `rows.append`, `self.top`,
    # `else`/`break` in `slice`, `substream.close()`); confirm against
    # the upstream source.

    @classmethod
    def open(cls, statement, cursor, input=None):
        """
        Execute `statement` on `cursor` and return the populated stream.

        `input` supplies the values for the statement placeholders and
        must be given whenever `statement.placeholders` is not empty.
        Substatements are executed recursively on the same cursor.
        """
        # One output converter per column of the result set.
        converts = [unscramble(domain)
                    for domain in statement.domains]
        sql = statement.sql.encode('utf-8')
        parameters = None
        if statement.placeholders:
            assert input is not None
            parameters = {}
            for index in sorted(statement.placeholders):
                domain = statement.placeholders[index]
                convert = scramble(domain)
                value = convert(input[index])
                # Placeholders are named by their 1-based position.
                parameters[str(index+1)] = value
        # Do not pass a parameter set unless the statement has
        # placeholders.
        if parameters is None:
            cursor.execute(sql)
        else:
            cursor.execute(sql, parameters)
        rows = []
        for row in cursor:
            row = tuple(convert(item)
                        for item, convert in zip(row, converts))
            rows.append(row)
        substreams = [cls.open(substatement, cursor)
                      for substatement in statement.substatements]
        return cls(rows, substreams)

    def __init__(self, rows, substreams):
        assert isinstance(rows, list)
        assert isinstance(substreams, listof(RowStream))
        self.rows = rows
        self.substreams = substreams
        self.top = 0
        self.last_top = None
        self.last_key = None

    def __iter__(self):
        """
        Iterate over all rows from the start, keeping `top` pointed at
        the row being produced.
        """
        self.top = 0
        for row in self.rows:
            yield row
            # Advance only after the consumer is done with the row, so
            # that `get()` called in between sees the current row.
            self.top += 1

    def get(self, stencil):
        """
        Return the values of the current row at the `stencil` indexes.
        """
        return tuple(self.rows[self.top][index]
                     for index in stencil)

    def slice(self, stencil, key):
        """
        Generate the run of rows at the current position that match `key`.

        A row matches when its values at the `stencil` indexes equal
        `key`.  An empty `key` matches all the remaining rows (and
        requires an empty `stencil`).  When called again with the same
        key as the previous call, the previously produced run is
        replayed and the position is restored afterwards.
        """
        if key != self.last_key:
            # A new key: remember where the run starts for a replay.
            self.last_top = self.top
            self.last_key = key
            if key != ():
                while self.top < len(self.rows):
                    row = self.rows[self.top]
                    if key != tuple(row[index] for index in stencil):
                        break
                    yield row
                    self.top += 1
            else:
                assert not stencil
                while self.top < len(self.rows):
                    yield self.rows[self.top]
                    self.top += 1
        else:
            # The same key again: replay rows `last_top .. top-1`, moving
            # `top` along so that `get()` sees the replayed row.
            top = self.top
            self.top = self.last_top
            for idx in range(self.last_top, top):
                self.top = idx
                yield self.rows[idx]
            self.top = top

    def close(self):
        """
        Verify that all rows were consumed; close the substreams.
        """
        assert self.top == len(self.rows)
        for substream in self.substreams:
            substream.close()
class FetchPipe(object):
    """
    A callable that executes a plan and returns the result as a `Product`.

    `plan`
        A `Plan` instance; its `profile`, `statement` and `compose`
        attributes are copied to the pipe.
    """

    def __init__(self, plan):
        assert isinstance(plan, Plan)
        self.plan = plan
        self.profile = plan.profile
        self.statement = plan.statement
        self.compose = plan.compose

    def __call__(self, input=None):
        """
        Execute the plan and return `Product(meta, data)`.

        `input` is passed to `RowStream.open()` as the source of
        placeholder values.  When the plan has no statement, no query is
        executed and `data` is ``None``.

        Raises `PermissionError` if the plan has a statement and the
        current environment does not permit reading.
        """
        meta = self.profile.clone(plan=self.plan)
        data = None
        if self.statement:
            if not context.env.can_read:
                raise PermissionError("No read permissions")
            stream = None
            # All rows are fetched inside the transaction; composing the
            # output value happens after the transaction block exits.
            with transaction() as connection:
                cursor = connection.cursor()
                stream = RowStream.open(self.statement, cursor, input)
            data = self.compose(None, stream)
            # NOTE(review): `RowStream.close()` is not called here; it
            # only asserts that every row was consumed -- confirm whether
            # a `stream.close()` call belongs after `compose`.
        return Product(meta, data)

def translate(syntax, environment=None, limit=None):
    """
    Translate a query to a `FetchPipe` that executes it.

    `syntax`
        The query: a string (parsed first), a `Syntax` node (bound
        first) or an already built `Binding`.

    `environment`
        Passed to `bind()`; not used when `syntax` is a `Binding`.

    `limit`
        If not ``None``, the expression is patched with `safe_patch()`
        before rewriting.
    """
    assert isinstance(syntax, (Syntax, Binding, unicode, str))
    if isinstance(syntax, (str, unicode)):
        syntax = parse(syntax)
    # Bind only when needed; a `Binding` argument is used as is.
    if not isinstance(syntax, Binding):
        binding = bind(syntax, environment=environment)
    else:
        binding = syntax
    expression = encode(binding)
    if limit is not None:
        expression = safe_patch(expression, limit)
    expression = rewrite(expression)
    term = compile(expression)
    frame = assemble(term)
    frame = reduce(frame)
    plan = serialize(frame)
    return FetchPipe(plan)

def safe_patch(expression, limit):
    """
    Return `expression` with `limit` applied to its segment flow.

    The expression is returned unchanged when it has no segment, when
    an `OrderedFlow` with a limit not exceeding `limit` is found while
    walking the `base` chain up to the nearest axis flow, or when that
    axis flow is the root.  Otherwise the segment flow is replaced:
    an `OrderedFlow` is cloned with the new limit, any other flow is
    wrapped in a new `OrderedFlow`.
    """
    segment = expression.segment
    if segment is None:
        return expression
    flow = segment.flow
    while not flow.is_axis:
        # A limit that is already at least as strict makes the patch
        # unnecessary.
        if (isinstance(flow, OrderedFlow) and flow.limit is not None
                                          and flow.limit <= limit):
            return expression
        flow = flow.base
    if flow.is_root:
        return expression
    if isinstance(segment.flow, OrderedFlow):
        # Reuse the existing ordered flow, overriding only its limit.
        flow = segment.flow.clone(limit=limit)
    else:
        flow = OrderedFlow(segment.flow, [], limit, None, segment.binding)
    segment = segment.clone(flow=flow)
    expression = expression.clone(segment=segment)
    return expression