Commits

Edward Catmur  committed 0cd3db3

Catch recursive assertions within __repr__; test.

If assertion introspection is turned on (--assert=reinterp, --assert=rewrite)
then pytest will call str() on intermediates in failing assert statements.
If the failing assert statement occurs within an object's __repr__, and the
object or a bound method is an intermediate in the assert, then this will
result in infinite recursion. Use a thread-local to detect when our
assertion.util.saferepr is called within itself, indicating a recursive
asserting __repr__.

  • Participants
  • Parent commits 3bd27c4

Comments (1)

  1. Edward Catmur author

    This was originally discovered in Pandas; https://github.com/pydata/pandas/blob/maintenance/0.7.x/pandas/core/frame.py#L3097

      def _apply_standard(self, func, axis, ignore_failures=False):
            try:
    
                assert(not self._is_mixed_type)  # maybe a hack for now
                ...
            except Exception:    # catches AssertionError
                pass
    
            # code for mixed-type DataFrame follows
    

    The latest version of pandas doesn’t have the bug, because the assert has been changed to throw an exception instead. It’s still using exceptions for intra-method flow control, where an if statement would be just as good and far clearer: https://github.com/pydata/pandas/blob/master/pandas/core/frame.py#L4108

    This was the original triggering test:

    import pandas as pd
    def test_should_not_throw_recursion_error_for_columns_with_different_types():
        df = pd.DataFrame({
                'foo': [1.5],
                'bar': [2]
                })
        df_str = df.to_string()
        print "got as string"
        print df_str
    

    PyTest assertion introspection http://pytest.org/latest/assert.html#assert-details kicks in and PyTest decides to print out the objects in the expression that failed the assertion – this means calling repr(self), which calls self.to_string, which hits the assertion again, and so on.

    Exception RuntimeError: 'maximum recursion depth exceeded in __subclasscheck__' in <type 'exceptions.RuntimeError'> ignored
    Exception RuntimeError: 'maximum recursion depth exceeded while calling a Python object' in <type 'exceptions.AttributeError'> ignored
    Exception RuntimeError: 'maximum recursion depth exceeded while calling a Python object' in <type 'exceptions.AttributeError'> ignored
    …
    

Files changed (5)

File _pytest/assertion/newinterpret.py

                 result = self.frame.eval(co)
             except Exception:
                 raise Failure()
-            explanation = self.frame.repr(result)
+            explanation = util.saferepr(result)
             return explanation, result
         elif _is_ast_stmt(node):
             mod = ast.Module([node])
         except Exception:
             raise Failure(explanation)
         pattern = "%s\n{%s = %s\n}"
-        rep = self.frame.repr(result)
+        rep = util.saferepr(result)
         explanation = pattern % (rep, rep, explanation)
         return explanation, result
 
             result = self.frame.eval(co, __exprinfo_expr=source_result)
         except Exception:
             raise Failure(explanation)
-        explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
-                                              self.frame.repr(result),
+        explanation = "%s\n{%s = %s.%s\n}" % (util.saferepr(result),
+                                              util.saferepr(result),
                                               source_explanation, attr.attr)
         # Check if the attr is from an instance.
         source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
         except Exception:
             from_instance = None
         if from_instance is None or self.frame.is_true(from_instance):
-            rep = self.frame.repr(result)
+            rep = util.saferepr(result)
             pattern = "%s\n{%s = %s\n}"
             explanation = pattern % (rep, rep, explanation)
         return explanation, result

File _pytest/assertion/oldinterpret.py

 import py
 import sys, inspect
 from compiler import parse, ast, pycodegen
-from _pytest.assertion.util import format_explanation, BuiltinAssertionError
+from _pytest.assertion.util import format_explanation, BuiltinAssertionError, saferepr
 
 passthroughex = py.builtin._sysex
 
         except:
             raise Failure(self)
         self.result = result
-        self.explanation = self.explanation or frame.repr(self.result)
+        self.explanation = self.explanation or saferepr(self.result)
 
     def run(self, frame):
         # fall-back for unknown statement nodes
         except:
             raise Failure(self)
         if not node.is_builtin(frame) or not self.is_bool(frame):
-            r = frame.repr(self.result)
+            r = saferepr(self.result)
             self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
 
 class Getattr(Interpretable):
         except:
             from_instance = True
         if from_instance:
-            r = frame.repr(self.result)
+            r = saferepr(self.result)
             self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
 
 # == Re-interpretation of full statements ==

File _pytest/assertion/rewrite.py

     AssertionRewriter().run(mod)
 
 
-_saferepr = py.io.saferepr
+from _pytest.assertion.util import saferepr as _saferepr
 from _pytest.assertion.util import format_explanation as _format_explanation
 
 def _should_repr_global_name(obj):

File _pytest/assertion/util.py

 """Utilities for assertion debugging"""
 
 import py
+import threading
 
 BuiltinAssertionError = py.builtin.builtins.AssertionError
 
 # DebugInterpreter.
 _reprcompare = None
 
+
+def saferepr(obj):
+    if saferepr.thread_local.active:
+        return '<[recursive assertion in __repr__] %s object at 0x%x>' % (
+            obj.__class__.__name__, id(obj))
+    saferepr.thread_local.active = True
+    try:
+        return py.io.saferepr(obj)
+    finally:
+        saferepr.thread_local.active = False
+saferepr.thread_local = threading.local()
+saferepr.thread_local.active = False
+
+
 def format_explanation(explanation):
     """This formats an explanation
 

File testing/test_assertion.py

     result = testdir.runpytest("--assert=reinterp")
     assert result.ret == 0
 
+def test_repr_assertion_recursion(testdir):
+    testdir.makepyfile("""
+        import pytest
+
+        class A(object):
+            _in___repr__ = 0
+            def is_valid(self):
+                return False
+            def __repr__(self):
+                if self._in___repr__ > 10:
+                    pytest.exit('recursion in __repr__')
+                self._in___repr__ += 1
+                try:
+                    assert self.is_valid()
+                finally:
+                    self._in___repr__ -= 1
+
+        def test_repr_checks_is_valid():
+            try:
+                with pytest.raises(AssertionError) as ex:
+                    repr(A())
+            except pytest.exit.Exception as ex:
+                pytest.fail(ex)
+    """)
+    for mode in '--assert=reinterp', '--assert=rewrite':
+        assert testdir.runpytest(mode).ret == 0
+
 def test_triple_quoted_string_issue113(testdir):
     testdir.makepyfile("""
         def test_hello():