# flowrate/flowrate/variables.py

import ast
import calendar
import datetime
import decimal
import operator


class Visitor(object):
    """Dispatching visitor for Python ``ast`` nodes.

    ``visit(node)`` calls ``self.visit_<ClassName>(node)`` for the node's
    class name.  When no such method exists, a strict visitor raises
    NotImplementedError; a non-strict one falls back to ``visit_all``,
    which recurses into every child node.
    """

    def __init__(self, strict=True):
        # strict: raise for node types lacking a visit_<ClassName> method
        # instead of silently recursing into their children.
        self.strict = strict

    def visit(self, node):
        """Visit a node and return whatever its visitor method returns.

        Any exception raised while visiting is re-raised, tagged with a
        ``lineno`` attribute taken from the innermost visited node that has
        one (unless the exception already carries a ``lineno``).
        """
        try:
            clsname = node.__class__.__name__
            visitor = getattr(self, 'visit_' + clsname, None)
            if visitor is not None:
                return visitor(node)
            if self.strict:
                # NotImplementedError is the builtin; there is no
                # NotImplementedException in Python (that would be a NameError).
                raise NotImplementedError(clsname)
            return self.visit_all(node)
        except Exception as exc:
            # The exception propagates out through every enclosing visit()
            # call; the innermost one with a line number sets it, and the
            # outer ones leave it alone so it points at the precise node.
            if hasattr(node, "lineno") and not hasattr(exc, "lineno"):
                exc.lineno = node.lineno
            raise

    def visit_all(self, node):
        """Visit every child AST node of ``node``, in field order.

        Return values of the child visits are discarded.
        """
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)


class ReferenceFinder(Visitor):
    """Visitor for finding variable references in Python code.

    Use the find(source) method of an instance of this class to find
    the set of names which the given source references; that is,
    if you update a variable "foo" you know you need to recalculate
    any "source" where `'foo' in ReferenceFinder().find(source)`.

    Every ``ast.Name`` node is collected regardless of its context, so
    names that are assigned to (e.g. ``x`` in ``x = y``) are included too.
    """

    def __init__(self):
        # We know we're only interested in Name; every other node type
        # just recurses into its children, hence non-strict.
        Visitor.__init__(self, strict=False)
        self.variables = set()

    def find(self, source, mode='exec'):
        """Return the set of variable names referred to in the given source.

        source: a string of Python code, or an already-parsed ast.AST node.
        mode: compile() mode used when source is a string.

        Raises TypeError for any other kind of source; compile() errors
        (e.g. SyntaxError) propagate unchanged.  Any exception raised while
        visiting gets a ``source`` attribute set to the given source before
        being re-raised.
        """
        self.variables = set()

        if isinstance(source, ast.AST):
            tree = source
        elif isinstance(source, (str, type(u""))):
            # str/unicode on Python 2 (exactly what basestring covered),
            # str on Python 3.
            tree = compile(source, '<string>', mode, ast.PyCF_ONLY_AST)
        else:
            raise TypeError("source must be a string of code or an AST node.")

        try:
            self.visit(tree)
        except Exception as exc:
            exc.source = source
            raise

        return self.variables

    def visit_Name(self, node):
        """Record the identifier of a Name node."""
        self.variables.add(node.id)


# --------------------------------- Builtins --------------------------------- #

def days(n):
    """Return a ``datetime.timedelta`` of ``n`` days.

    ``n`` is passed straight through as timedelta's ``days`` argument.
    """
    span = datetime.timedelta(days=n)
    return span


class months(object):
    """A calendar-month offset to add to or subtract from a date/datetime.

    ``d + months(n)`` and ``d - months(n)`` move ``d`` by ``n`` whole
    calendar months via ``d.replace``, so all other fields are kept.  When
    the day of month does not exist in the target month it is clamped to
    that month's last day, e.g. Jan 31 + months(1) -> Feb 28/29.  ``n`` may
    be zero or negative.
    """

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return 'months(%r)' % (self.n,)

    @staticmethod
    def _shift(other, n):
        """Return ``other`` moved by ``n`` calendar months (n may be < 0)."""
        # Count months as one running total so divmod handles the year
        # carry/borrow in both directions (Python's divmod floors, so
        # negative offsets land in the correct earlier year).
        year, month0 = divmod(other.year * 12 + (other.month - 1) + n, 12)
        month = month0 + 1
        last_day = calendar.monthrange(year, month)[1]
        return other.replace(year=year, month=month,
                             day=min(other.day, last_day))

    def __radd__(self, other):
        if not isinstance(other, (datetime.datetime, datetime.date)):
            raise TypeError("months may only be added to dates or datetimes")
        return self._shift(other, self.n)

    def __rsub__(self, other):
        if not isinstance(other, (datetime.datetime, datetime.date)):
            raise TypeError(
                "months may only be subtracted from dates or datetimes")
        return self._shift(other, -self.n)

def eom(d):
    """Return ``d`` with its day set to the last day of its month.

    All other fields of ``d`` are kept (it is built with ``d.replace``).
    """
    _, days_in_month = calendar.monthrange(d.year, d.month)
    return d.replace(day=days_in_month)

def pmt(interest_rate, number_of_payments, present_value):
    """Return the payment amount for a loan using a constant schedule.

    interest_rate: rate per payment period, as a fraction (0.01 == 1%).
    number_of_payments: number of equal payments (N).
    present_value: amount borrowed (pv).

    Uses the annuity formula i*pv / (1 - (1+i)**-N).  With a zero interest
    rate that formula divides by zero, so the loan is instead split into
    N equal payments, pv / N.
    """
    i, N, pv = interest_rate, number_of_payments, present_value
    if i == 0:
        # truediv so integer arguments don't floor-divide on Python 2,
        # while Decimal arguments still yield a Decimal.
        return operator.truediv(pv, N)
    return (i * pv) / (1 - ((1 + i) ** -N))


# Module-level sentinel: a unique object that compares equal only to itself.
# Nothing in this part of the file uses it; presumably it marks "no value"
# distinctly from None -- verify against callers before relying on that.
missing = object()

class Expression(object):
    """A source string paired with the variable names it references.

    Attributes:
        source: the source exactly as given.
        references: the set returned by ReferenceFinder().find(source).
    """

    def __init__(self, source):
        finder = ReferenceFinder()
        self.source = source
        self.references = finder.find(source)


class Environment(object):
    """A set of named expressions that are evaluated and kept up to date.

    Each variable is bound to a source string (see bind()); its computed
    value is stored in ``self.locals`` so other expressions can refer to it
    by name.  Recalculating a variable (calc()) also recalculates every
    variable that depends on it, directly or indirectly, in dependency order.

    NOTE(review): expressions are run with eval(), which can execute
    arbitrary code -- only bind source from trusted users.
    """

    # Names available to every expression, alongside the variable values.
    globals = {
        'datetime': datetime,
        'decimal': decimal,
        'days': days,
        'months': months,
        'eom': eom,
        'pmt': pmt,
        }

    def __init__(self):
        # Maintain one dict with references to Expressions...
        self.variables = {}
        # ...and another with references to their values, for use with eval().
        self.locals = {}

    def eval(self, source, extra_locals=None):
        """Evaluate ``source`` against the globals and current variable values.

        extra_locals, if non-empty, extends/overrides the variable values for
        this evaluation only; self.locals itself is not updated with them.

        Any exception is re-raised with (source, extra_locals) appended to
        its args, to make failures traceable to the offending expression.
        """
        if extra_locals:
            # Layer the extras over a copy so self.locals stays untouched.
            scope = self.locals.copy()
            scope.update(extra_locals)
        else:
            scope = self.locals

        try:
            return eval(source, self.globals, scope)
        except Exception as e:
            e.args += (source, extra_locals)
            raise

    def bind(self, name, source):
        """Register ``source`` under ``name`` (parsed for references, not evaluated)."""
        self.variables[name] = Expression(source)

    def calc(self, name):
        """Recalc the variable. Recalc dependents and return their names.

        Returns a list of the (direct and indirect) dependents that were
        recalculated after ``name``.
        """
        source = self.variables[name].source
        self.locals[name] = self.eval(source)

        # Materialize the referrers: calc_all() both iterates and
        # membership-tests this collection, and it is returned to the
        # caller, so a one-shot generator would be consumed along the way.
        names = list(self.get_referrers(name, recursive=True))
        self.calc_all(names)

        return names

    def get_referrers(self, name, recursive=False):
        """Yield names of variables which refer to the given name.

        If 'recursive' is True, then also yield all names which refer to the
        original referrers, and so on.  Each name is yielded at most once, so
        shared ("diamond") and circular dependencies cause neither duplicates
        nor infinite recursion.
        """
        seen = set()
        targets = [name]
        while targets:
            target = targets.pop()
            for k, expr in self.variables.items():
                if k in seen or target not in expr.references:
                    continue
                seen.add(k)
                yield k
                if recursive:
                    # Look for variables referring to this referrer next.
                    targets.append(k)

    def calc_all(self, names=None):
        """Calculate the given expressions in topological order.

        If 'names' is given, it is an iterable of names (from self.variables)
        to calculate. If None, all names in self.variables are calculated.
        Use this to pass in a portion of the graph of expressions; any
        referenced name not passed in is assumed to be already calculated.

        Raises ValueError("Circular dependencies found.", remaining) when some
        expressions cannot be ordered; remaining is a list of
        (name, unresolved_references) pairs.
        """
        if names is None:
            names = list(self.variables)
        else:
            # Accept any iterable, including a one-shot generator.
            names = list(names)
        wanted = set(names)

        # For each name, the references still waiting to be calculated.
        # References outside 'wanted' are assumed already calculated, so
        # they are dropped here; an empty set means "ready to calculate".
        pending = dict(
            (n, set(r for r in self.variables[n].references if r in wanted))
            for n in names)

        # Reverse edges: for each name, the wanted names that refer to it.
        dependents = dict((n, []) for n in pending)
        for n, refs in pending.items():
            for r in refs:
                dependents[r].append(n)

        # Kahn's algorithm: start with all the ones which wait on no others.
        ready = [n for n, refs in pending.items() if not refs]

        while ready:
            n = ready.pop()

            # Eval and store result.
            self.locals[n] = self.eval(self.variables[n].source)

            # Release every var waiting on "n"; queue those now unblocked.
            for k in dependents[n]:
                refs = pending[k]
                refs.discard(n)
                if not refs:
                    ready.append(k)

        # Error if there are any circular definitions.
        remaining = [(k, list(refs)) for k, refs in pending.items() if refs]
        if remaining:
            raise ValueError("Circular dependencies found.", remaining)

# Shared module-level Environment instance, created at import time.
environment = Environment()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.