1. Matthew Turk
  2. analysis_plugins


analysis_plugins / analysis_plugins.py

import abc
import matplotlib;matplotlib.use("Agg")
import pylab
from input_parameters import input_parameter_types, NoDefaultValue
from yt.mods import *
from global_registry import GlobalSharedRegistry

def iterable(obj):
    """Return True if *obj* supports ``len()``, False otherwise.

    Grabbed from Python Cookbook / matplotlib.cbook.  Note that this tests
    for a length rather than for iteration support: generators, for
    instance, are reported as not iterable.
    """
    try:
        len(obj)
    except Exception:
        # Any failure of len() means "not sized"; KeyboardInterrupt and
        # SystemExit are deliberately not swallowed.
        return False
    return True

def ptype_to_default(ptype):
    """Split a parameter type spec into ``(type_name, default)``.

    A spec is either a bare type name, which has no default, or a
    ``(type_name, default)`` tuple, which is returned unchanged.
    """
    # Use the builtin rather than types.TupleType: the ``types`` module is
    # not imported by this file and TupleType does not exist on Python 3.
    if isinstance(ptype, tuple):
        return ptype
    # A NoDefaultValue marker is used instead of None because None can be a
    # legitimate default (e.g. ("field_or_none", None)).
    return ptype, NoDefaultValue()

class AnalysisModule(object):
    """Base class for a single, chainable analysis step.

    Subclasses declare:

    * ``parameters``: a list of ``(name, ptype)`` pairs, where ``ptype`` is a
      parameter type name or a ``(type name, default)`` tuple;
    * ``_module_name``: a string naming the module;
    * ``_execute()``: does the work and fills ``self.result``.

    Defining such a subclass validates its parameter types and registers it
    in ``analysis_modules`` under ``_module_name``.
    """
    _skip_add = False

    class __metaclass__(abc.ABCMeta):
        # Python 2 metaclass hook: runs for every class statement that
        # derives from AnalysisModule.
        def __new__(mcls, name, b, d):
            cls = abc.ABCMeta.__new__(mcls, name, b, d)
            module_name = getattr(cls, "_module_name", None)
            # The abstract base has neither a string name nor a sized
            # ``parameters``, so only concrete subclasses are registered.
            if isinstance(module_name, str) and iterable(cls.parameters):
                # Validate every declared type before registering, so a
                # module with a bad declaration never lands in the registry.
                for pname, ptype in cls.parameters:
                    ptype = ptype_to_default(ptype)[0]
                    if ptype not in input_parameter_types:
                        raise KeyError(ptype)
                analysis_modules[module_name] = cls
            return cls

    def __init__(self, *args, **kwargs):
        """Create the module, taking parameter values from upstream modules.

        Positional arguments that are AnalysisModule instances are executed
        (unless they already have been).  Their ``result`` dictionaries then
        serve as parameter sources.  Other positional arguments are
        ignored.  Keyword arguments are applied on top, overriding any
        upstream value of the same name.  Names that are not declared in
        ``parameters`` are ignored.
        """
        # We use a list because we want it to be ordered.  Otherwise using a
        # dict would be fine.
        self.pnames = [pname for pname, ptype in self.parameters]
        self.result = {}
        supplied = {}
        for upstream in args:
            if not isinstance(upstream, AnalysisModule):
                continue
            if not upstream.executed:
                upstream.execute()
            supplied.update(upstream.result)
        supplied.update(kwargs)
        # Set parameters in declaration order.  Each parameter type is built
        # with the values of the parameters already set (see
        # _get_ptype_instance), so later ones may depend on earlier ones.
        known = [(k, v) for k, v in supplied.items() if k in self.pnames]
        known.sort(key=lambda item: self.pnames.index(item[0]))
        for pname, value in known:
            self._set_parameter(pname, value)

    def _get_ptype_instance(self, pname, value = None):
        """Instantiate the input parameter type declared for *pname*.

        The type object is constructed with *value* and a dict of the
        parameters currently set on this module.
        """
        ptype = ptype_to_default(self.parameters[self.pnames.index(pname)][1])[0]
        current = dict((name, getattr(self, name))
                       for name in self.pnames if hasattr(self, name))
        return input_parameter_types[ptype](value, current)

    def _set_parameter(self, name, value):
        """Validate *value* for parameter *name* and store it as an attribute.

        Raises KeyError if *name* is not a declared parameter and
        RuntimeError if the parameter type rejects *value*.
        """
        if name not in self.pnames:
            raise KeyError(name)
        inst = self._get_ptype_instance(name, value)
        if not inst.validate():
            raise RuntimeError("Invalid value %r for parameter %r"
                               % (value, name))
        # Okay, we know what KIND of parameter is and we have verified the
        # VALUE is correct, so we go ahead and set the thing.  Note that we
        # pull 'value' from the parameter type, because we allow it to cast or
        # modify the value in some way.
        setattr(self, name, inst.value)

    def prompt_for_parameters(self, only_unset = True):
        """Interactively read parameter values from standard input.

        only_unset: when true, parameters that already have a value are
        skipped.

        An empty answer accepts the declared default, if there is one.
        Answering "?" lists the options offered by the parameter type.
        Values the type rejects are asked for again.
        """
        for pname, ptype in self.parameters:
            if only_unset and hasattr(self, pname):
                continue
            ptype, default_val = ptype_to_default(ptype)
            while True:
                value = raw_input("Parameter name: %s (%s: %s).  Value? " % (
                        pname, ptype, default_val))
                if value.strip() == "?":
                    opts = self._get_ptype_instance(pname).options()
                    if opts is not None:
                        for opt in opts:
                            print(opt)
                    else:
                        print("Sorry, no suggestions available.")
                    continue
                if value.strip() == "" and \
                        not isinstance(default_val, NoDefaultValue):
                    value = default_val
                try:
                    self._set_parameter(pname, value)
                except RuntimeError:
                    print("Try again.")
                else:
                    break

    # Set to True once execute() has run; __init__ checks it so upstream
    # modules are not run twice.
    executed = False

    def execute(self):
        """Run the analysis and return whatever ``_execute`` returns.

        ``_execute`` fills ``self.result``; the module is then marked as
        executed.
        """
        rv = self._execute()
        self.executed = True
        return rv

    @abc.abstractmethod
    def _execute(self):
        """Perform the analysis.

        This function MUST set a result dictionary, even if it returns
        something.
        """
        raise NotImplementedError

    # We mandate that subclasses implement "parameters"
    @abc.abstractproperty
    def parameters(self):
        """Ordered list of ``(name, ptype)`` pairs declared by the subclass."""
        raise NotImplementedError

    @abc.abstractproperty
    def _module_name(self):
        """Registry key for the subclass.

        This must be a string or the analysis module won't get added to the
        registry.
        """
        raise NotImplementedError

    def __str__(self):
        """Summarize the module name, parameter values/defaults and results."""
        lines = [" Analysis Module: %s" % (self._module_name,),
                 " --- Parameters"]
        for pname, ptype in self.parameters:
            ptype, shown = ptype_to_default(ptype)
            if hasattr(self, pname):
                shown = getattr(self, pname)
            lines.append("   %s: %s (%s)" % (pname, ptype, shown))
        if self.result:
            lines.append(" --- Results")
            for key in sorted(self.result):
                lines.append("   %s: %s" % (key, self.result[key]))
        return "\n".join(lines)

    __repr__ = __str__

class AnalysisModuleList(GlobalSharedRegistry):
    """Shared registry of AnalysisModule subclasses, keyed by module name."""

    # NOTE(review): presumably GlobalSharedRegistry uses valid_type to
    # restrict what may be stored -- confirm against global_registry.
    valid_type = AnalysisModule

# The AnalysisModule metaclass stores each concrete subclass here under its
# _module_name.
analysis_modules = AnalysisModuleList()

# A couple example modules ...

class MaxValueLocation(AnalysisModule):
    """Find the maximum of a field in a parameter file.

    Results: 'value' (the maximum) and 'center' (the position reported by
    the hierarchy's find_max).
    """
    _module_name = "max_value_location"
    parameters = [
        ("pf", "parameter_file"),
        ("field", ("field", "Density")),
    ]

    def _execute(self):
        peak_value, peak_position = self.pf.h.find_max(self.field)
        self.result.update(value=peak_value, center=peak_position)
        return peak_value, peak_position

class Sphere(AnalysisModule):
    """Build a sphere data object from a parameter file's hierarchy.

    Result: 'source', keyed so that it can feed the "source" parameter of
    modules such as Extrema and Profile1D.
    """
    _module_name = "sphere"
    parameters = [
        ("center", "position"),
        ("radius", ("float", 0.1)),
        ("pf", "parameter_file"),
    ]

    def _execute(self):
        region = self.pf.h.sphere(self.center, self.radius)
        self.result["source"] = region
        return region

class Extrema(AnalysisModule):
    """Compute the extrema of a field over a data source.

    Results: 'lower_bound' and 'upper_bound', named to match the Profile1D
    parameters of the same names.
    """
    _module_name = "extrema"
    parameters = [
        ("source", "AMRData"),
        ("field", "field"),
    ]

    def _execute(self):
        # [0] selects the entry for the single requested field -- presumably
        # the quantity returns one (min, max) pair per field; confirm in yt.
        bounds = self.source.quantities["Extrema"](self.field)[0]
        self.result["lower_bound"] = bounds[0]
        self.result["upper_bound"] = bounds[1]
        return bounds

class Profile1D(AnalysisModule):
    """Bin one field against another over a data source.

    Results: 'profile' (the BinnedProfile1D) plus 'x' (bin field values) and
    'y' (profile field values), which match LinePlot's x/y parameters.
    """
    _module_name = "profile1d"
    parameters = [
        ("source", "AMRData"),
        ("n_bins", ("int", 64)),
        ("bin_field", "field"),
        ("lower_bound", "float"),
        ("upper_bound", "float"),
        ("profile_field", "field"),
        ("weight_field", ("field_or_none", None)),
    ]

    def _execute(self):
        binned = BinnedProfile1D(self.source, self.n_bins, self.bin_field,
                                 self.lower_bound, self.upper_bound)
        # weight_field may be None (its type is "field_or_none").
        binned.add_fields(self.profile_field, self.weight_field)
        out = self.result
        out['profile'] = binned
        out['x'] = binned[self.bin_field]
        out['y'] = binned[self.profile_field]
        return binned

class LinePlot(AnalysisModule):
    """Plot y against x and save the figure to ``filename``.

    The axes are log-scaled according to ``log_x`` / ``log_y`` and labelled
    with ``x_name`` / ``y_name``.  Result: 'filename' (the saved image path).
    """
    # NOTE(review): the log_x/log_y defaults are the *string* "True"; they are
    # tested for truthiness below, so presumably the "bool" parameter type
    # converts strings -- confirm in input_parameters.
    parameters = [("x", "array"),
                  ("y", "array"),
                  ("x_name", ("string", "Unknown")),
                  ("y_name", ("string", "Unknown")),
                  ("log_x", ("bool", "True")),
                  ("log_y", ("bool", "True")),
                  ("filename", ("string", "temp.png"))]
    _module_name = "line_plot"

    def _execute(self):
        # One plotting function per (log_x, log_y) combination, including
        # the plain linear case.
        plotters = {(True, True): pylab.loglog,
                    (True, False): pylab.semilogx,
                    (False, True): pylab.semilogy,
                    (False, False): pylab.plot}
        plot = plotters[(bool(self.log_x), bool(self.log_y))]
        # Draw on a fresh figure so repeated runs don't overlay each other,
        # and close it afterwards so figures don't pile up in pylab's state.
        fig = pylab.figure()
        try:
            plot(self.x, self.y, '-x')
            pylab.xlabel(self.x_name)
            pylab.ylabel(self.y_name)
            pylab.savefig(self.filename)
        finally:
            pylab.close(fig)
        # _execute must always populate self.result.
        self.result['filename'] = self.filename
        return self.filename

if __name__ == "__main__":
    # Small end-to-end demo.  It locates the Density maximum, builds a sphere
    # from that result, profiles within the sphere and plots the profile.
    print(analysis_modules.keys())
    print(input_parameter_types.keys())
    ta = analysis_modules["max_value_location"](
            pf = "RD0005-mine/RedshiftOutput0005", field="Density")
    sp = analysis_modules["sphere"](ta, pf = ta.pf, radius=0.1)
    ex = analysis_modules["extrema"](sp, field="Density")
    # The bounds come from the extrema of ex.field, so bin on that same field.
    p = analysis_modules["profile1d"](sp, ex, bin_field = ex.field)
    # profile_field is supplied by no upstream result; ask for it (and any
    # other unset parameter) before p.profile_field is read below.
    p.prompt_for_parameters()
    lp = analysis_modules["line_plot"](p,
        x_name = p.bin_field, y_name = p.profile_field)
    # Constructing lp only runs its upstream modules; run the plot itself.
    lp.execute()

    print(p.result)