# Robert Brewer / flowrate
#
# flowrate / flowrate / variables.py

import ast
import calendar
import datetime
import decimal

class Visitor(object):
    """Dispatching visitor for Python ``ast`` nodes.

    ``visit(node)`` looks up a method named ``visit_<NodeClassName>`` on the
    instance and calls it.  When no such method exists, the ``strict`` flag
    decides what happens: a strict visitor raises ``NotImplementedError``,
    a lenient one descends into the node's children via ``visit_all``.
    """

    def __init__(self, strict=True):
        self.strict = strict

    def visit(self, node):
        """Visit a node.

        Returns whatever the matching ``visit_<NodeClassName>`` handler
        returns, or the result of ``visit_all`` when a lenient visitor has
        no handler for the node.

        Any exception escaping the visit is re-raised.  On the way out it is
        tagged with a ``lineno`` attribute taken from the innermost node that
        has one.

        Raises:
            NotImplementedError: no handler exists for the node's class and
                ``self.strict`` is true.
        """
        try:
            clsname = node.__class__.__name__
            handler = getattr(self, 'visit_' + clsname, None)
            if handler is not None:
                return handler(node)
            if self.strict:
                raise NotImplementedError(clsname)
            return self.visit_all(node)
        except Exception as exc:
            # Tag the exception only if it carries no lineno yet: one that
            # is already there was attached in a nested call, and must not
            # be overwritten with the line of an outer node.
            if hasattr(node, "lineno") and not hasattr(exc, "lineno"):
                exc.lineno = node.lineno
            raise

    def visit_all(self, node):
        """Visit all child nodes of the given node."""
        for _field, value in ast.iter_fields(node):
            if isinstance(value, list):
                # A field may hold a mix of nodes and plain values;
                # only the AST nodes are visited.
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)

class ReferenceFinder(Visitor):
    """Visitor for finding variable references in Python code.

    Use the find(source) method of an instance of this class to find
    the set of names which the given source references; that is,
    if you update a variable "foo" you know you need to recalculate
    any "source" where `'foo' in ReferenceFinder().find(source)`.
    """

    # Text types accepted by find(): ``str`` plus ``unicode`` on Python 2
    # (on Python 3 both entries are ``str``).
    _string_types = (str, type(u""))

    def __init__(self):
        # We know we're only interested in Name. All the others can visit_all.
        super(ReferenceFinder, self).__init__(strict=False)

    def find(self, source, mode='exec'):
        """Return the set of variable names referred to in the given source.

        ``source`` is either a string of Python code, which is parsed with
        ``compile`` in the given ``mode``, or an already-built ``ast.AST``
        node, which is walked as-is.  The result is also kept on
        ``self.variables`` and is reset on every call.

        Raises:
            TypeError: ``source`` is neither a string nor an AST node.

        Any exception raised while walking the tree is re-raised with the
        offending ``source`` attached as ``exc.source``.
        """
        self.variables = set()

        if isinstance(source, ast.AST):
            tree = source
        elif isinstance(source, self._string_types):
            tree = compile(source, '<string>', mode, ast.PyCF_ONLY_AST)
        else:
            raise TypeError("source must be a string of code or an AST node.")

        try:
            self.visit(tree)
        except Exception as exc:
            exc.source = source
            raise

        return self.variables

    def visit_Name(self, node):
        """Record the identifier of a Name node."""
        self.variables.add(node.id)

# --------------------------------- Builtins --------------------------------- #

def days(n):
    """Return a ``datetime.timedelta`` spanning ``n`` days."""
    # The first positional argument of timedelta is the day count.
    return datetime.timedelta(n)

class months(object):
    """A number of calendar months to add to or subtract from a date.

    Instances only work as the right-hand operand: ``date + months(n)`` and
    ``date - months(n)``.  The year rolls over as needed, and ``n`` may be
    negative or larger than 12.  The day of the month is left untouched.
    """

    def __init__(self, n):
        self.n = n

    @staticmethod
    def _shift(moment, delta):
        """Return ``moment`` moved by ``delta`` calendar months (any sign).

        Raises ``ValueError`` (from ``replace``) when the day of ``moment``
        does not exist in the target month, e.g. the 31st moved into a
        30-day month.
        """
        # Count months from year 0 so that a single divmod normalises any
        # overflow or underflow, whatever the sign or size of delta.
        year, month_index = divmod(
            moment.year * 12 + (moment.month - 1) + delta, 12)
        return moment.replace(year=year, month=month_index + 1)

    def __radd__(self, other):
        """Return ``other`` moved forward by ``self.n`` months."""
        if not isinstance(other, (datetime.datetime, datetime.date)):
            raise TypeError("months may only be added to dates or datetimes")
        return self._shift(other, self.n)

    def __rsub__(self, other):
        """Return ``other`` moved back by ``self.n`` months."""
        if not isinstance(other, (datetime.datetime, datetime.date)):
            raise TypeError(
                "months may only be subtracted from dates or datetimes")
        return self._shift(other, -self.n)

def eom(d):
    """Return ``d`` moved to the last day of its own month."""
    # monthrange gives (weekday of the 1st, number of days in the month).
    _first_weekday, last_day = calendar.monthrange(d.year, d.month)
    return d.replace(day=last_day)

def pmt(interest_rate, number_of_payments, present_value):
    """Return the payment amount for a loan using a constant schedule.

    ``interest_rate`` is the rate per payment period, ``number_of_payments``
    the count of periods and ``present_value`` the amount borrowed.  With a
    zero rate the payment is simply the principal split evenly across the
    payments.
    """
    i, N, pv = interest_rate, number_of_payments, present_value
    if i == 0:
        # The annuity formula divides by zero at a zero rate.  Adding the
        # (zero) rate to pv keeps the result in the rate's numeric type.
        return (pv + i) / N
    return (i * pv) / (1 - ((1 + i) ** -N))

# Unique sentinel object. NOTE(review): nothing visible in this file uses
# it -- presumably a "no value supplied" marker; confirm before removing.
missing = object()

class Expression(object):
    """A piece of source paired with the names it refers to.

    ``source`` holds the value passed in unchanged; ``references`` holds
    whatever ``ReferenceFinder().find(source)`` returns for it.
    """

    def __init__(self, source):
        finder = ReferenceFinder()
        self.source = source
        self.references = finder.find(source)

class Environment(object):
    """A set of named expressions and their most recently computed values.

    ``bind`` registers an expression under a name, ``calc`` evaluates one
    name and everything that depends on it, and ``calc_all`` evaluates a
    group of names in dependency order.
    """

    # Names made available to every evaluated expression.
    globals = {
        'datetime': datetime,
        'decimal': decimal,
        'days': days,
        'months': months,
        'eom': eom,
        'pmt': pmt,
    }

    def __init__(self):
        # Maintain one dict with references to Expressions...
        self.variables = {}
        # ...and another with references to their values, for use with eval().
        self.locals = {}

    def eval(self, source, extra_locals=None):
        """Evaluate ``source`` and return the result.

        The expression sees the class-level ``globals`` plus the values
        computed so far (``self.locals``).  When ``extra_locals`` is given,
        its entries are layered over a copy of ``self.locals`` for this
        call only.

        Any exception is re-raised with ``(source, extra_locals)`` appended
        to its ``args``.

        NOTE(review): this runs arbitrary Python via the builtin ``eval``;
        only pass source from trusted origins.
        """
        if extra_locals:
            scope = self.locals.copy()
            scope.update(extra_locals)
        else:
            scope = self.locals

        try:
            return eval(source, self.globals, scope)
        except Exception as e:
            e.args += (source, extra_locals)
            raise

    def bind(self, name, source):
        """Register the given source."""
        self.variables[name] = Expression(source)

    def calc(self, name):
        """Recalc the variable. Recalc dependents and return their names."""
        source = self.variables[name].source
        self.locals[name] = self.eval(source)

        # Re-calc any other variables which depend on this one (topologically).
        # calc_all needs a real list, so drain the generator first.
        names = list(self.get_referrers(name, recursive=True))
        self.calc_all(names)

        return names

    def get_referrers(self, name, recursive=False):
        """Yield names of variables which refer to the given name.

        If 'recursive' is True, then recurse and yield all names which refer
        to the original referrers, and so on.  Each name is yielded at most
        once, so circular references terminate instead of recursing forever.
        """
        seen = set()
        pending = [name]
        while pending:
            target = pending.pop()
            for k, expr in self.variables.items():
                if k not in seen and target in expr.references:
                    seen.add(k)
                    yield k
                    if recursive:
                        pending.append(k)

    def calc_all(self, names=None):
        """Calculate the given expressions in topological order.

        If 'names' is given, it MUST be a list of names (from self.variables)
        to calculate. If None, all names in self.variables are calculated.
        Use this to pass in a portion of the graph of expressions.

        Raises:
            ValueError: some of the names depend on each other circularly;
                the second argument lists the (name, unresolved references)
                pairs that could not be calculated.
        """
        if names is None:
            names = self.variables.keys()
        names = list(names)
        wanted = set(names)

        # Make a copy of the references so we can mutate them as we traverse.
        # However, omit any name which isn't in our 'names' argument;
        # we rely on emptying the 'refs' list to know when a node is ready
        # to calculate, and we assume that any name not passed in is
        # already calculated.
        allrefs = dict(
            (name, [r for r in self.variables[name].references if r in wanted])
            for name in names)

        # Start with all the ones which reference no others.
        Q = [name for name, refs in allrefs.items() if not refs]

        while Q:
            name = Q.pop()

            # Eval and store result.
            self.locals[name] = self.eval(self.variables[name].source)

            # Mark all vars which reference "name".
            for k, refs in allrefs.items():
                if name in refs:
                    refs.remove(name)
                    # Add (to the end!) if all references have been evaluated.
                    if not refs:
                        Q.append(k)

        # Error if there are any circular definitions.
        remaining = [(k, refs) for k, refs in allrefs.items() if refs]
        if remaining:
            raise ValueError("Circular dependencies found.", remaining)

# Module-level Environment instance, created when this module is imported.
# NOTE(review): presumably the shared default environment for callers -- confirm.
environment = Environment()