# htsql / src / htsql / core / cmd /

# Copyright (c) 2006-2013, Prometheus Research, LLC

from ..adapter import adapt, Utility
from ..util import listof
from ..context import context
from ..domain import ListDomain, RecordDomain, Profile, Product
from .command import FetchCmd, SkipCmd, SQLCmd
from .act import (analyze, Act, ProduceAction, SafeProduceAction,
                  AnalyzeAction, RenderAction)
from ..tr.bind import bind
from ..tr.binding import Binding
from ..tr.encode import encode
from ..tr.flow import OrderedFlow
from ..tr.rewrite import rewrite
from ..tr.compile import compile
from ..tr.assemble import assemble
from ..tr.reduce import reduce
from ..tr.dump import serialize
from ..tr.plan import Plan, Statement
from ..tr.decorate import decorate_void
from ..connect import transaction, scramble, unscramble
from ..error import PermissionError

class RowStream(object):
    """Buffered result rows of an executed SQL statement.

    A stream holds every row of its statement in ``rows`` together with one
    nested stream per substatement in ``substreams``.  Consumers walk the
    rows through a shared cursor position, ``top``, which ``__iter__``,
    ``get`` and ``slice`` read and advance; ``close`` checks that every row
    (recursively, including substreams) has been consumed.
    """

    @classmethod
    def open(cls, statement, cursor, input=None):
        """Execute `statement` on `cursor` and buffer all of its rows.

        `statement`
            Supplies ``sql``, ``domains`` (one per output column),
            ``placeholders`` (a mapping of input index to domain) and
            ``substatements``.
        `cursor`
            A DB-API cursor; it is reused for every substatement.
        `input`
            Sequence of placeholder values; must be given when the
            statement has placeholders.

        Returns a new stream with the converted rows and one substream for
        each substatement.
        """
        # Converters from raw database values to the output column domains.
        converts = [unscramble(domain)
                    for domain in statement.domains]
        sql = statement.sql.encode('utf-8')
        parameters = None
        if statement.placeholders:
            assert input is not None
            parameters = {}
            for index in sorted(statement.placeholders):
                domain = statement.placeholders[index]
                convert = scramble(domain)
                value = convert(input[index])
                # Placeholder names are 1-based strings.
                parameters[str(index+1)] = value
        # Only hand parameters to the driver when the query actually has
        # placeholders.
        if parameters is None:
            cursor.execute(sql)
        else:
            cursor.execute(sql, parameters)
        rows = []
        for row in cursor:
            row = tuple(convert(item)
                        for item, convert in zip(row, converts))
            rows.append(row)
        # The main result set is fully read above, so the same cursor can
        # now be reused for the substatements.
        substreams = [cls.open(substatement, cursor)
                      for substatement in statement.substatements]
        return cls(rows, substreams)

    def __init__(self, rows, substreams):
        assert isinstance(rows, list)
        assert isinstance(substreams, listof(RowStream))
        self.rows = rows
        self.substreams = substreams
        # Index of the current row.
        self.top = 0
        # Start position and key of the most recent `slice()` call, used to
        # replay it when the same key is requested again.
        self.last_top = None
        self.last_key = None

    def __iter__(self):
        """Yield every row from the start, keeping ``top`` on the yielded row."""
        self.top = 0
        for row in self.rows:
            yield row
            self.top += 1

    def get(self, stencil):
        """Return the columns listed in `stencil` from the current row."""
        return tuple(self.rows[self.top][index]
                     for index in stencil)

    def slice(self, stencil, key):
        """Yield the run of rows at the cursor whose `stencil` columns equal `key`.

        An empty `key` (which requires an empty `stencil`) yields all the
        remaining rows.  Asking again for the same `key` as the previous
        call replays the rows of that previous slice and then restores the
        cursor.
        """
        if key != self.last_key:
            self.last_top = self.top
            self.last_key = key
            if key != ():
                while self.top < len(self.rows):
                    row = self.rows[self.top]
                    if key != tuple(row[index] for index in stencil):
                        break
                    yield row
                    self.top += 1
            else:
                assert not stencil
                while self.top < len(self.rows):
                    yield self.rows[self.top]
                    self.top += 1
        else:
            # Same key as last time: rewind to where that slice started,
            # replay it, then put the cursor back.
            top = self.top
            self.top = self.last_top
            for idx in range(self.last_top, top):
                self.top = idx
                yield self.rows[idx]
            self.top = top

    def close(self):
        """Check that this stream and all substreams were fully consumed."""
        assert self.top == len(self.rows)
        for substream in self.substreams:
            substream.close()

class FetchPipe(object):
    """Executable form of a translated fetch plan.

    `plan`
        A :class:`Plan`; its ``profile``, ``statement`` and ``compose``
        attributes are cached on the pipe.
    """

    def __init__(self, plan):
        assert isinstance(plan, Plan)
        self.plan = plan
        self.profile = plan.profile
        self.statement = plan.statement
        self.compose = plan.compose

    def __call__(self, input=None):
        """Run the plan and return a :class:`Product`.

        `input`
            Placeholder values forwarded to the SQL statement, if any.

        When the plan has no SQL statement, the product carries ``None`` as
        data.  Raises :exc:`PermissionError` when the current environment
        has no read permission.
        """
        meta = self.profile.clone(plan=self.plan)
        data = None
        if self.statement:
            if not context.env.can_read:
                raise PermissionError("No read permissions")
            stream = None
            with transaction() as connection:
                cursor = connection.cursor()
                stream = RowStream.open(self.statement, cursor, input)
            # All rows are buffered in the stream, so composing the output
            # does not need the transaction to stay open.
            data = self.compose(None, stream)
            # Verify that composition consumed every fetched row.
            stream.close()
        return Product(meta, data)

class BuildFetch(Utility):
    """Translate a query into an executable :class:`FetchPipe`.

    `syntax`
        Either a parsed syntax node or an already bound :class:`Binding`.
    `environment`
        Passed to ``bind()`` when `syntax` still needs binding.
    `limit`
        Optional cap on the number of top-level rows (see `safe_patch`).
    """

    def __init__(self, syntax, environment=None, limit=None):
        self.syntax = syntax
        self.environment = environment
        self.limit = limit

    def __call__(self):
        """Run the translation pipeline and wrap the plan in a pipe.

        The stages are bind, encode, optional limit patch, rewrite,
        compile, assemble, reduce and serialize.
        """
        if not isinstance(self.syntax, Binding):
            binding = bind(self.syntax, environment=self.environment)
        else:
            # Already bound: use it as is.
            binding = self.syntax
        expression = encode(binding)
        if self.limit is not None:
            expression = self.safe_patch(expression, self.limit)
        expression = rewrite(expression)
        term = compile(expression)
        frame = assemble(term)
        frame = reduce(frame)
        plan = serialize(frame)
        return FetchPipe(plan)

    def safe_patch(self, expression, limit):
        """Return `expression` with its segment's rows capped at `limit`.

        The expression is returned unchanged when:
        - it has no segment;
        - an ordered flow between the segment and its axis already has a
          limit no greater than `limit`;
        - the axis reached is the root flow.

        Otherwise the segment's flow gets the new limit.  An existing
        :class:`OrderedFlow` is cloned with the limit; any other flow is
        wrapped in a new one.
        """
        segment = expression.segment
        if segment is None:
            return expression
        flow = segment.flow
        while not flow.is_axis:
            if (isinstance(flow, OrderedFlow) and flow.limit is not None
                                              and flow.limit <= limit):
                # A tighter (or equal) limit is already in place.
                return expression
            flow = flow.base
        if flow.is_root:
            return expression
        if isinstance(segment.flow, OrderedFlow):
            flow = segment.flow.clone(limit=limit)
        else:
            # NOTE(review): positional arguments presumably are
            # (base, order, limit, offset, binding) -- confirm against
            # OrderedFlow's constructor.
            flow = OrderedFlow(segment.flow, [], limit, None, segment.binding)
        segment = segment.clone(flow=flow)
        expression = expression.clone(segment=segment)
        return expression

class ProduceFetch(Act):
    """Produce the result of a fetch command by building and running its pipe."""

    adapt(FetchCmd, ProduceAction)

    def __call__(self):
        action = self.action
        # Only the safe flavour of the produce action carries a row cut.
        if isinstance(action, SafeProduceAction):
            row_cap = action.cut
        else:
            row_cap = None
        runner = build_fetch(self.command.syntax, action.environment, row_cap)
        return runner()

class AnalyzeFetch(Act):
    """Analyze a fetch command: translate it and return the resulting plan."""

    adapt(FetchCmd, AnalyzeAction)

    def __call__(self):
        # Only translation happens here; the pipe itself is never run.
        return build_fetch(self.command.syntax,
                           self.action.environment).plan

class ProduceSkip(Act):
    """Produce the result of a skip command: a void profile with no data."""

    adapt(SkipCmd, ProduceAction)

    def __call__(self):
        return Product(decorate_void(), None)

class RenderSQL(Act):
    """Render the SQL generated for a query as a ``text/plain`` response.

    Returns a ``(status, headers, body)`` triple.  The body is a list of
    UTF-8 encoded chunks.  It holds the SQL of the top statement and of
    every substatement in breadth-first order, separated by newlines.
    """

    adapt(SQLCmd, RenderAction)

    def __call__(self):
        plan = analyze(self.command.feed)
        status = '200 OK'
        headers = [('Content-Type', 'text/plain; charset=UTF-8')]
        body = []
        if plan.statement:
            # Breadth-first walk over the statement tree.
            queue = [plan.statement]
            while queue:
                statement = queue.pop(0)
                if body:
                    # Separate consecutive statements; bytes, to match the
                    # encoded SQL chunks.
                    body.append(b"\n")
                body.append(statement.sql.encode('utf-8'))
                queue.extend(statement.substatements)
        return (status, headers, body)

# Function-style entry point for the `BuildFetch` utility, called as
# `build_fetch(syntax, environment[, limit])` by `ProduceFetch` and
# `AnalyzeFetch`; presumably `__invoke__` instantiates `BuildFetch` with the
# arguments and calls it, yielding a `FetchPipe` -- confirm in ..adapter.
build_fetch = BuildFetch.__invoke__