Commits

Joe LaFata committed 5dc4a21

added more tests
fixed a bug when a key was present
only one key present per report run
sort output by runtime in descending order
add the fibonacci example

Comments (0)

Files changed (3)

src/examples/fibonacci.py

+'''
+Copyright (c) 2011, Joseph LaFata
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+    * Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    * Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    * Neither the name of the unitbench nor the
+      names of its contributors may be used to endorse or promote products
+      derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+import unitbench
+
+profiler = unitbench.FunctionProfiler()
+
+@profiler
+def fib(n):
+    if n == 1 or n == 0:
+        return n
+    return fib(n - 1) + fib(n - 2)
+
+if __name__ == '__main__':
+    
+    fib(15)
+    fib(15)
+    for i in xrange(5):
+        fib(i)
+        fib(i)
+        
+    profiler.write_results()
+    profiler.reset(track_parameters=True)
+    
+    fib(15)
+    fib(15)
+    for i in xrange(5):
+        fib(i)
+        fib(i)
+        
+    print(profiler)
+    profiler.write_results(unitbench.CsvReporter())

src/tests/test_unitbench.py

     def test_add_time(self):
         result = BenchResult("bench_sample2", 16)
         
+        eq_(result.calculate_stats(), result)
+        
         result.add_time(TimeSet(3, 1, 0))
         eq_(len(result.times), 1)
         
         results = map(lambda x: x.calculate_stats(), fp.results.values())
         self.result_check(results, 'sum23', 'call_count', 4)
         
+    def test_reset(self):
+        fp = FunctionProfiler(1, 0, False)
+        
+        @fp
+        def fib(n):
+            if n == 1 or n == 0:
+                return n
+            return fib(n - 1) + fib(n - 2)
+        
+        for i in xrange(5):
+            fib(10)
+            fib(i)
+            fib(i)
+            
+        eq_(len(str(fp).strip().split("\n")), 3)
+        fp.reset(track_parameters=True)
+        
+        for i in xrange(5):
+            fib(10)
+            fib(i)
+            fib(i)
+        
+        eq_(len(str(fp).strip().split("\n")), 13)
+        
     def test_args_to_str(self):
         eq_(FunctionProfiler._args_to_str(),
            "None")
         
         bm = sample()
         bm.run()
-        assert bm.setup_count == 0
+        eq_(bm.setup_count, 0)
     
     def test_exception(self):
         class sample(OneRun):
                 1/0
                 
         bm = sample()
-        self.assertRaises(ZeroDivisionError, bm.run)
+        assert_raises(ZeroDivisionError, bm.run)
+        
+        class sample2(sample):
+            def warmup(self):
+                return 1
+        
+        bm = sample2()
+        assert_raises(ZeroDivisionError, bm.run)
     
     def test_input(self):
         class SampleBase(OneRun):
         eq_(Benchmark._function_name_to_title("XMLBenchmark"), "Xml Benchmark")
 
 class TestConsoleReporter(TestCase):
+    def create_bm(self):
+        class sample(OneRun):
+            def input(self):
+                yield 100
+                yield "this_is_a_long_value_that_requires_use_of_a_key"
+            
+            def bench_sample1(self):
+                pass
+            
+            def bench_sample2(self):
+                pass
+            
+            def bench_this_is_a_long_name_that_requires_use_of_a_key(self):
+                pass
+            
+        return sample()
+    
     def test_count_digits(self):
         eq_(ConsoleReporter.count_digits(1), 1)
         eq_(ConsoleReporter.count_digits(5), 1)
     def test_make_key(self):
         cr = ConsoleReporter()
         title = "012345678901234567890123456789"
-        key = []
+        key = {}
         t2 = cr._make_key(title, 30, key)
         eq_(len(key), 1)
-        eq_(t2, title[23:] + '...[1]')
+        eq_(t2, title[:23] + '...[1]')
         
-        key = []
+        key = {}
         t2 = cr._make_key(title, 31, key)
         eq_(len(key), 0)
         eq_(t2, title)
         
-        key = []
+        key = {}
         t2 = cr._make_key(title, 25, key)
         eq_(len(key), 1)
-        eq_(t2, title[24-6:] + "...[1]")
+        eq_(t2, title[:24-6] + "...[1]")
         
-    def test_write_results(self):
-        class sample(OneRun):
-            def input(self):
-                yield 100
-            
-            def bench_sample1(self):
-                pass
-            
-            def bench_sample2(self):
-                pass
-        
-        bm = sample()
-        
+    def test_write_results_by_value(self):
+        bm = self.create_bm()
         stream = StringIO()
         bm.run(ConsoleReporter(stream, ConsoleReporter.GroupType.VALUE)) 
         
         
         lines = output.strip().split("\n")
         
-        eq_(len(lines), 4)
+        print output
         eq_(lines[0].find("100"), 7)
         eq_(lines[2].find("Sample1"), 0)
         eq_(lines[3].find("Sample2"), 0)
+        eq_(lines[4].find("This Is A Long Name That Requires...[1]"), 0)
+        eq_(lines[6].find("this_is_a_long_value_that_...[2]"), 7)
+        eq_(lines[13].find("[1]: This Is A Long Name That Requires Use Of A Key"), 0)
+        eq_(lines[14].find("[2]: this_is_a_long_value_that_requires_use_of_a_key"), 0)
         
+    def test_write_results_by_function(self):
+        bm = self.create_bm()
+        stream = StringIO()
+        bm.run(ConsoleReporter(stream, ConsoleReporter.GroupType.FUNCTION)) 
+        
+        output = stream.getvalue()
+        stream.close()
+        
+        lines = output.strip().split("\n")
+        
+        print output
+        eq_(lines[0].find("Sample1"), 6)
+        eq_(lines[1], "======================================================================")
+        eq_(lines[2].find("100"), 0)
+        eq_(lines[3].find("this_is_a_long_value_that_require...[1]"), 0)
+        eq_(lines[5].find("Sample2"), 6)
+        eq_(lines[7].find("100"), 0)
+        eq_(lines[10][6:39], "This Is A Long Name That Re...[2]")
+        eq_(lines[17], "[2]: This Is A Long Name That Requires Use Of A Key")   
 
 class TestCsvReporter(TestCase):        
     def test_write_results(self):
         self.call_count = len(self.times) 
         
         time_types = ["wall", "user", "system"]
-        for type in time_types:
+        for ttype in time_types:
             minimum = maxint
             maximum = -maxint
             count = 0
-            sum = 0
+            time_sum = 0
             sum_2 = 0
 
             mean = 0.0
 
             if len(self.times) > 0:
                 for t in self.times:
-                    currentTime = getattr(t, type)
+                    currentTime = getattr(t, ttype)
     
                     count += 1
-                    sum += currentTime
+                    time_sum += currentTime
                     sum_2 += currentTime ** 2
                     minimum = min(currentTime, minimum)
                     maximum = max(currentTime, maximum)
 
-                mean = sum / count
+                mean = time_sum / count
                 variance = (sum_2 / count) - (mean ** 2)
                 if variance < 0.0:
                     variance = 0.0
                 std_dev = math.sqrt(variance)
 
-                setattr(self, type + "_min", minimum)
-                setattr(self, type + "_max", maximum)
-                setattr(self, type + "_mean", mean)
-                setattr(self, type + "_variance", variance)
-                setattr(self, type + "_std_dev", std_dev)
+                setattr(self, ttype + "_min", minimum)
+                setattr(self, ttype + "_max", maximum)
+                setattr(self, ttype + "_mean", mean)
+                setattr(self, ttype + "_variance", variance)
+                setattr(self, ttype + "_std_dev", std_dev)
         return self
 
 class BenchBase(object):
         def wrapper(*args, **kwargs):
             func_call = partial(func, *args, **kwargs);
             
-            func_call.value = FunctionProfiler._args_to_str(args, kwargs)   
+            func_call.value = FunctionProfiler._args_to_str(*args, **kwargs)   
             func_call.__name__ = func.__name__
             func_call.func_name = func.func_name
             
     def write_results(self, reporter=None):
         if reporter == None:
             reporter = ConsoleReporter(group_by=ConsoleReporter.GroupType.FUNCTION)
+        
+        if not self.track_parameters:
+            for val in self.results.itervalues():
+                val.value = "### call count: " + str(len(val.times))
+        
         reporter.write_results(self.results.values())
+        
+    def reset(self, repeats=1, warmup=0, track_parameters=False):
+        self.repeats = FunctionProfiler._wrap_int(repeats)
+        self.warmup = FunctionProfiler._wrap_int(warmup)
+        self.track_parameters = track_parameters
+        
+        self.results = {}
 
 class Benchmark(BenchBase):
     """
     
     def _make_key(self, title, length, key):
         if len(title) >= length:
+            if title in key and key[title][1] == length:
+                return key[title][2]
             number = len(key)+1
             digits = ConsoleReporter.count_digits(number)
-            key.append(title)            
-            return title[length-6-digits:] + "...[{0}]".format(number)
+            output = title[:length-6-digits] + "...[{0}]".format(number)
+            key[title] = (number, length, output)
+            return output
         return title
         
     def write_results(self, results):
         
         first_title = self.group_by.title() + ": "
         title_len = len(first_title)
+        key = {}
         while len(results) > 0:
-            key = []
             
             length = 40 - title_len
             title_format = first_title + "{0:<" + str(length) + "}{1:>10}{2:>10}{3:>10}\n"
             self.stream.write(title_format.format(display_name, "user", "sys", "real"))
             self.stream.write("=" * 70 + "\n")
             
-            while len(results) > 0:
-                if getter(results[0]) != title_name:
-                    break
-                r = results.popleft()
+            title_set = []
+            while len(results) > 0 and getter(results[0]) == title_name:
+                title_set.append(results.popleft())
+            
+            title_set.sort(key=attrgetter("user_mean"), reverse=True)
+            
+            for r in title_set:
                 if (hasattr(r, "user_mean") and
                     hasattr(r, "system_mean") and hasattr(r, "wall_mean")):
                     
                     self.stream.write("{0:<40} {1:>9.4} {2:>9.4} {3:>9.4}\n"
                                       .format(o, r.user_mean,
                                               r.system_mean, r.wall_mean))
-            if len(key) > 0:
-                self.stream.write("Key:\n")
-                for idx, k in enumerate(key, start=1):
-                    self.stream.write("{0}: {1}\n".format(idx, k))
+                    
             self.stream.write("\n")
+                
+        if len(key) > 0:
+            self.stream.write("* Key:\n")
+            items = sorted(key.items(), key=lambda x: x[1][0])
+            for i in items:
+                self.stream.write("[{0}]: {1}\n".format(i[1][0], i[0]))
 
  
 class CsvReporter(Reporter):