Commits

Petar Marić committed 68f7260

Added optimized integration support, along with smart caching and integral parent/child relationships

  • Participants
  • Parent commits 18f90c6

Comments (0)

Files changed (4)

File beam_integrals/integrals.py

+import itertools
 import re
+from sympy import Float, mpmath
+from . import DEFAULT_DECIMAL_PRECISION
+from .characteristic_equation_solvers import find_best_root
 from .utils import FriendlyNameFromClassMixin, PluginMount
 
 
 class BaseIntegral(FriendlyNameFromClassMixin):
     _id_re = re.compile('I(\d+)')
     
+    used_variables = ('m', 't', 'v', 'n')
+    
     __metaclass__ = PluginMount
     
     def __str__(self):
             self._id = int(self._id_re.match(self.name).group(1))
         
         return self._id
+    
+    @staticmethod
+    def _contribute_to_plugins(_plugins):
+        _plugins.child_id_to_parent_id = dict(
+            (id, _plugins.class_to_id[cls.__base__])
+            for cls, id in _plugins.class_to_id.items() #@ReservedAssignment
+            if cls.__base__ in _plugins.classes
+        )
+    
+    @classmethod
+    def parent_id(cls):
+        return cls.plugins.child_id_to_parent_id.get(cls.plugins.class_to_id[cls]) 
+    
+    @classmethod
+    def has_parent(cls):
+        return cls.parent_id() is not None
+    
+    def iterate_over_used_variables(self, max_mode, start_mode=1):
+        modes = range(start_mode, max_mode+1)
+        get_modes = lambda var: modes if var in self.used_variables else (None,)
+        
+        var_modes = (get_modes(var) for var in BaseIntegral.used_variables)
+        return itertools.product(*var_modes)
+    
+    def cache_key(self, m, t, v, n, max_mode):
+        d = locals()
+        return sum(
+            d[var] * max_mode**idx
+            for idx, var in enumerate(self.used_variables)
+        )
+    
+    def integrand(self, beam_type, m, t, v, n, decimal_precision=DEFAULT_DECIMAL_PRECISION):
+        def resolve_mu_m(func, *args, **kwargs):
+            def wrapper(mode):
+                return func(mode, *args, **kwargs).subs('mu_m',
+                    find_best_root(beam_type, mode, decimal_precision)
+                )
+            
+            return wrapper
+        
+        Y_m = resolve_mu_m(beam_type.Y_m)
+        dY_m = resolve_mu_m(beam_type.Y_m_derivative_from_cache, order=1)
+        ddY_m = resolve_mu_m(beam_type.Y_m_derivative_from_cache, order=2)
+        
+        return self._integrand(Y_m, dY_m, ddY_m, m, t, v, n)
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n):
+        raise NotImplementedError
+    
+    def __call__(self, beam_type, m, t, v, n, decimal_precision=DEFAULT_DECIMAL_PRECISION):
+        return self.integrand(beam_type, m, t, v, n, decimal_precision)
 
 
-class I1(BaseIntegral):
-    pass
+class BaseIntegralWithSymetricVariables(BaseIntegral):
+    def cache_key(self, m, t, v, n, max_mode):
+        d = locals()
+        values = sorted(d[var] for var in self.used_variables)
+        return sum(
+            val * max_mode**idx
+            for idx, val in enumerate(values)
+        )
 
 
-class I2(BaseIntegral):
-    pass
+class I1(BaseIntegralWithSymetricVariables):
+    used_variables = ('m', 'n')
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n): #@UnusedVariable
+        return Y_m(m) * Y_m(n)
+
+class I21(I1): pass
+
+
+class I2(BaseIntegralWithSymetricVariables):
+    used_variables = ('m', 'n')
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n): #@UnusedVariable
+        return dY_m(m) * dY_m(n)
+
+class I4(I2): pass
+class I6(I2): pass
+class I8(I2): pass
+class I25(I2): pass
 
 
 class I3(BaseIntegral):
-    pass
+    used_variables = ('m', 'n')
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n): #@UnusedVariable
+        return ddY_m(m) * Y_m(n)
 
-
-class I4(BaseIntegral):
-    pass
+class I22(I3): pass
 
 
 class I5(BaseIntegral):
-    pass
+    used_variables = ('m', 'n')
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n): #@UnusedVariable
+        return Y_m(m) * ddY_m(n)
 
+class I23(I5): pass
 
-class I6(BaseIntegral):
-    pass
 
+class I7(BaseIntegralWithSymetricVariables):
+    used_variables = ('m', 'n')
+    
+    def _integrand(self, Y_m, dY_m, ddY_m, m, t, v, n): #@UnusedVariable
+        return ddY_m(m) * ddY_m(n)
 
-class I7(BaseIntegral):
-    pass
+class I24(I7): pass
 
 
-class I8(BaseIntegral):
-    pass
-
-
-class I21(BaseIntegral):
-    pass
-
-
-class I22(BaseIntegral):
-    pass
-
-
-class I23(BaseIntegral):
-    pass
-
-
-class I24(BaseIntegral):
-    pass
-
-
-class I25(BaseIntegral):
-    pass
+def integrate(integral, beam_type, a, m=None, t=None, v=None, n=None, decimal_precision=DEFAULT_DECIMAL_PRECISION, **kwargs):
+    cached_subs = integral(beam_type, m, t, v, n, decimal_precision).subs('a', a)
+    f = lambda y: cached_subs.evalf(n=decimal_precision, subs={'y': y})
+    
+    with mpmath.workdps(decimal_precision):
+        result = mpmath.quad(f, (0., a), **kwargs)
+        
+        # If not converted to `sympy.Float` precision will be lost after the
+        # original `mpmath` context is restored
+        if isinstance(result, tuple): # Integration error included
+            return tuple(Float(x, decimal_precision) for x in result)
+        else:
+            return Float(result, decimal_precision)

File tests/integrals/test_cache_key.py

+from itertools import chain, combinations
+from beam_integrals.integrals import BaseIntegral, BaseIntegralWithSymetricVariables
+from tests.tools import assert_in, assert_not_in #@UnresolvedImport
+
+
+MAX_MODE = 10 # Lower than defaults to speed up tests
+
+
+def gen_integral_variable_combinations(min_len=1):
+    all_vars = BaseIntegral.used_variables
+    return chain.from_iterable(
+        combinations(all_vars, r)
+        for r in range(min_len, len(all_vars)+1)
+    )
+
+def setup():
+    global plain_integral, integral_with_symetric_variables
+    
+    class I9001(BaseIntegral):
+        pass
+    plain_integral = I9001()
+    
+    class I9002(BaseIntegralWithSymetricVariables):
+        pass
+    integral_with_symetric_variables = I9002()
+
+def teardown():
+    type(plain_integral)._unregister_plugin()
+    type(integral_with_symetric_variables)._unregister_plugin()
+
+def test_plain_integral():
+    for used_variables in gen_integral_variable_combinations():
+        yield check_plain_integral, used_variables
+
+def check_plain_integral(used_variables):
+    integral = plain_integral
+    integral.used_variables = used_variables
+    
+    keys_seen = set()
+    for variables in integral.iterate_over_used_variables(max_mode=MAX_MODE):
+        key = integral.cache_key(*variables, max_mode=MAX_MODE)
+        assert_not_in(key, keys_seen) # Should be a cache miss
+        keys_seen.add(key)
+
+def test_integral_with_symetric_variables():
+    for used_variables in gen_integral_variable_combinations(min_len=2):
+        yield check_integral_with_symetric_variables, used_variables
+
+def check_integral_with_symetric_variables(used_variables):
+    integral = integral_with_symetric_variables
+    integral.used_variables = used_variables
+    
+    keys_seen = set()
+    sorted_variables_seen = set()
+    for variables in integral.iterate_over_used_variables(max_mode=MAX_MODE):
+        key = integral.cache_key(*variables, max_mode=MAX_MODE)
+        sorted_variables = tuple(sorted(variables))
+        if sorted_variables in sorted_variables_seen:
+            assert_in(key, keys_seen) # Should be a cache hit
+        else:
+            assert_not_in(key, keys_seen) # Should be a cache miss
+            keys_seen.add(key)
+            sorted_variables_seen.add(sorted_variables)

File tests/integrals/test_integration.py

+from nose.plugins.skip import SkipTest
+import shutil
+from sympy import Abs
+import tempfile
+from beam_integrals import a, mu_m
+from beam_integrals import characteristic_equation_solvers as ces
+from beam_integrals.beam_types import BaseBeamType
+from beam_integrals.integrals import BaseIntegral, integrate
+import tests
+from tests.tools import assert_almost_equal, assert_equal, assert_less_equal, assert_is #@UnresolvedImport
+
+
+START_MODE = tests.MAX_MODE - 1 # Higher than defaults to speed up tests
+
+INTEGRAL_CLOSED_FORMS_FOR_SIMPLY_SUPPORTED_BEAM = {
+    1:   a/2.,
+    2:  (a/2.) * (mu_m/a)**2,
+    3: -(a/2.) * (mu_m/a)**2,
+    5: -(a/2.) * (mu_m/a)**2,
+    7:  (a/2.) * (mu_m/a)**4,
+}
+
+
+def setup():
+    global cache_keys_seen, integral_cache, disk_cache_dir, _old_best_roots_cache
+    
+    cache_keys_seen = set()
+    integral_cache = {}
+    
+    _old_best_roots_cache = ces.best_roots_cache
+    disk_cache_dir = tempfile.mkdtemp()
+    ces.best_roots_cache = ces.BestRootsCache(disk_cache_dir)
+    ces.best_roots_cache.regenerate(tests.MAX_MODE, tests.DECIMAL_PRECISION)
+
+def teardown():
+    cache_keys_seen.clear()
+    ces.best_roots_cache = _old_best_roots_cache
+    shutil.rmtree(disk_cache_dir)
+
+def test_integrate():
+    for integral_id in BaseIntegral.plugins.valid_ids: #@UndefinedVariable
+        integral = BaseIntegral.coerce(integral_id) #@UndefinedVariable
+        
+        # Skip integrals with parents, as they behave the same
+        if integral.has_parent():
+            continue
+        
+        integral_cache[integral_id] = {}
+        
+        for beam_type_id in BaseBeamType.plugins.valid_ids: #@UndefinedVariable
+            # Clear out `cache_keys_seen` before testing a new
+            # beam_type/integral combination
+            cache_keys_seen.clear()
+            
+            for m, t, v, n in integral.iterate_over_used_variables(start_mode=START_MODE, max_mode=tests.MAX_MODE):
+                yield check_integrate, integral_id, beam_type_id, m, t, v, n
+
+def test_integrate_options():
+    integral = BaseIntegral.coerce(1) #@UndefinedVariable
+    beam_type = BaseBeamType.coerce(1) #@UndefinedVariable
+    
+    def base_integrate(**kwargs):
+        return integrate(
+            integral, beam_type,
+            a=1.,
+            m=1, n=1,
+            decimal_precision=tests.DECIMAL_PRECISION,
+            **kwargs
+        )
+    
+    # When called with `error=True` `integrate` should return a `(result, error)` tuple
+    result_with_error_info = base_integrate(error=True)
+    assert_is(type(result_with_error_info), tuple)
+    
+    # `integrate` should return the same `result` regardless of the `error` flag
+    assert_equal(base_integrate(), result_with_error_info[0])
+
+def check_integrate(integral_id, beam_type_id, m, t, v, n):
+    integral = BaseIntegral.coerce(integral_id) #@UndefinedVariable
+    
+    cache_key = integral.cache_key(m, t, v, n, max_mode=tests.MAX_MODE)
+    if cache_key in cache_keys_seen: # Skip cached integrals
+        raise SkipTest
+    
+    cache_keys_seen.add(cache_key)
+    beam_type = BaseBeamType.coerce(beam_type_id) #@UndefinedVariable
+    
+    result, error = integrate(
+        integral, beam_type,
+        a=1.,
+        m=m, t=t, v=v, n=n,
+        decimal_precision=tests.DECIMAL_PRECISION,
+        error=True
+    )
+    
+    # CHEAT: Cache the integration results to speed up `test_simply_supported_beam_closed_form`
+    if beam_type_id == 1 and integral_id in INTEGRAL_CLOSED_FORMS_FOR_SIMPLY_SUPPORTED_BEAM:
+        integral_cache[integral_id][(m, t, v, n)] = result
+    
+    assert_less_equal(error, tests.MAX_ERROR_TOLERANCE)
+
+def test_simply_supported_beam_closed_form():
+    for integral_id in INTEGRAL_CLOSED_FORMS_FOR_SIMPLY_SUPPORTED_BEAM:
+        integral = BaseIntegral.coerce(integral_id) #@UndefinedVariable
+        
+        # Clear out `cache_keys_seen` before testing a new integral
+        cache_keys_seen.clear()
+        
+        for m, t, v, n in integral.iterate_over_used_variables(start_mode=START_MODE, max_mode=tests.MAX_MODE):
+            yield check_simply_supported_beam_closed_form, integral_id, m, t, v, n
+
+def check_simply_supported_beam_closed_form(integral_id, m, t, v, n):
+    integral = BaseIntegral.coerce(integral_id) #@UndefinedVariable
+    
+    cache_key = integral.cache_key(m, t, v, n, max_mode=tests.MAX_MODE)
+    if cache_key in cache_keys_seen: # Skip cached integrals
+        raise SkipTest
+    
+    cache_keys_seen.add(cache_key)
+    beam_type = BaseBeamType.coerce(1) #@UndefinedVariable
+    
+    # SPEED HACK: Integration already done in `test_integrate`
+    result = integral_cache[integral_id][(m, t, v, n)]
+    
+    if m != n:
+        # The result should be 0, as per D.D. Milasinovic
+        assert_less_equal(Abs(result), tests.MAX_ERROR_TOLERANCE)
+    else:
+        closed_form = INTEGRAL_CLOSED_FORMS_FOR_SIMPLY_SUPPORTED_BEAM[integral_id]
+        closed_form_result = closed_form.evalf(n=tests.DECIMAL_PRECISION, subs={
+            'a': 1.,
+            'mu_m': ces.find_best_root(beam_type, m, tests.DECIMAL_PRECISION)
+        })
+        assert_almost_equal(result, closed_form_result, delta=tests.MAX_ERROR_TOLERANCE)

File tests/integrals/test_parent_child_relationship.py

+from nose.tools import eq_
+from beam_integrals.integrals import BaseIntegral
+from tests.tools import assert_in #@UnresolvedImport
+
+
+def test_parent_id():
+    for id in BaseIntegral.plugins.valid_ids: #@ReservedAssignment @UndefinedVariable
+        yield check_parent_id, id
+
+def check_parent_id(id): #@ReservedAssignment
+    cls = BaseIntegral.plugins.id_to_class[id] #@UndefinedVariable
+    if cls.parent_id():
+        parent_cls = BaseIntegral.plugins.id_to_class[cls.parent_id()] #@UndefinedVariable
+        # Can't use `issubclass()` for this check as `issubclass(cls, cls)`
+        # returns `True` and need to make sure `cls` can't be its own parent
+        # integral
+        assert_in(parent_cls, cls.__bases__) 
+    else:
+        # Make sure a parentless integral isn't inheriting any other integrals
+        superclasses = set(cls.__mro__[1:]) # `cls.__mro__[0] is cls`
+        assert superclasses.isdisjoint(BaseIntegral.plugins.classes) #@UndefinedVariable
+
+def test_has_parent():
+    for id in BaseIntegral.plugins.valid_ids: #@ReservedAssignment @UndefinedVariable
+        yield check_has_parent, id
+
+def check_has_parent(id): #@ReservedAssignment
+    cls = BaseIntegral.plugins.id_to_class[id] #@UndefinedVariable
+    eq_(cls.has_parent(), cls.parent_id() is not None)