Robert Brewer committed 726e0b5

Fix for #775 (Caching has a performance-killing race condition?). The caching tool now provides anti-stampede protection by default.
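
For reference, a minimal sketch of how the new behaviour would be enabled from application config. The handler and mount point below are made up; the option names map to the MemoryCache attributes introduced in this changeset (delay, antistampede_timeout), which _caching.get() copies onto the process-wide cache object from any extra tool kwargs.

import cherrypy

class Root(object):
    def index(self):
        # Stand-in for an expensive handler. With anti-stampede protection,
        # concurrent identical requests wait for the first thread's cached
        # result instead of all recomputing it.
        return 'expensive page'
    index.exposed = True

conf = {'/': {
    'tools.caching.on': True,
    # Extra tool kwargs are copied onto the process-wide cache object by
    # _caching.get(), so these become MemoryCache attributes:
    'tools.caching.delay': 600,                # seconds a variant stays cached
    'tools.caching.antistampede_timeout': 5,   # seconds other threads wait
}}

cherrypy.quickstart(Root(), '/', conf)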

Files changed (3)

cherrypy/_cptools.py

 class CachingTool(Tool):
     """Caching Tool for CherryPy."""
     
-    def _wrapper(self, invalid_methods=("POST", "PUT", "DELETE"), **kwargs):
+    def _wrapper(self, **kwargs):
         request = cherrypy.serving.request
-        
-        if not hasattr(cherrypy, "_cache"):
-            # Make a process-wide Cache object.
-            cherrypy._cache = kwargs.pop("cache_class", _caching.MemoryCache)()
-            
-            # Take all remaining kwargs and set them on the Cache object.
-            for k, v in kwargs.items():
-                setattr(cherrypy._cache, k, v)
-        
-        if _caching.get(invalid_methods=invalid_methods):
+        if _caching.get(**kwargs):
             request.handler = None
         else:
             if request.cacheable:
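
Since cache construction now happens lazily inside _caching.get() rather than in the tool wrapper, an alternative backend can still be plugged in through the cache_class kwarg. A hypothetical sketch: LoggingMemoryCache is invented for illustration, while the cache_class hook, the Cache base class, and MemoryCache all come from the cherrypy/lib/caching.py changes below.

import cherrypy
from cherrypy.lib import caching

class LoggingMemoryCache(caching.MemoryCache):
    """Hypothetical backend that logs cache misses; purely illustrative."""
    def get(self):
        variant = caching.MemoryCache.get(self)
        if variant is None:
            cherrypy.log('cache miss', 'TOOLS.CACHING')
        return variant

conf = {'/': {
    'tools.caching.on': True,
    # cache_class is popped from the tool kwargs before the remaining
    # kwargs are copied onto the cache object (see _caching.get() below).
    'tools.caching.cache_class': LoggingMemoryCache,
}}

Any custom backend should provide the get/put/delete/clear interface defined by the new Cache base class.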

cherrypy/lib/caching.py

 import cherrypy
 from cherrypy.lib import cptools, httputil
 
-class VaryHeaderAwareStore:
+
+class Cache(object):
+    """Base class for the cache backends used by the caching tool."""
+    
+    def get(self):
+        raise NotImplementedError
+    
+    def put(self, obj, size):
+        raise NotImplementedError
+    
+    def delete(self):
+        raise NotImplementedError
+    
+    def clear(self):
+        raise NotImplementedError
+
+
+
+# ------------------------------- Memory Cache ------------------------------- #
+
+
+class AntiStampedeCache(dict):
+    
+    def wait(self, key, timeout=5, debug=False):
+        """Return the cached value for the given key, or None.
+        
+        If timeout is not None (the default) and the value is already
+        being calculated by another thread, wait up to the given timeout
+        for that thread to finish. If the value becomes available before
+        the timeout expires, it is returned. If not, None is returned and
+        a sentinel is placed in the cache so that subsequent threads wait
+        on this one instead.
+        
+        If timeout is None, no waiting is performed nor sentinels used.
+        """
+        value = self.get(key)
+        if isinstance(value, threading._Event):
+            if timeout is None:
+                # Ignore the other thread and recalc it ourselves.
+                if debug:
+                    cherrypy.log('No timeout', 'TOOLS.CACHING')
+                return None
+            
+            # Wait until it's done or times out.
+            if debug:
+                cherrypy.log('Waiting up to %s seconds' % timeout, 'TOOLS.CACHING')
+            value.wait(timeout)
+            if value.result is not None:
+                # The other thread finished its calculation. Use it.
+                if debug:
+                    cherrypy.log('Result!', 'TOOLS.CACHING')
+                return value.result
+            # Timed out. Stick an Event in the slot so other threads wait
+            # on this one to finish calculating the value.
+            if debug:
+                cherrypy.log('Timed out', 'TOOLS.CACHING')
+            e = threading.Event()
+            e.result = None
+            dict.__setitem__(self, key, e)
+            
+            return None
+        elif value is None:
+            # Stick an Event in the slot so other threads wait
+            # on this one to finish calculating the value.
+            if debug:
+                cherrypy.log('No cached value; placing a sentinel', 'TOOLS.CACHING')
+            e = threading.Event()
+            e.result = None
+            dict.__setitem__(self, key, e)
+        return value
+    
+    def __setitem__(self, key, value):
+        """Set the cached value for the given key."""
+        existing = self.get(key)
+        dict.__setitem__(self, key, value)
+        if isinstance(existing, threading._Event):
+            # Set Event.result so other threads waiting on it have
+            # immediate access without needing to poll the cache again.
+            existing.result = value
+            existing.set()
+
+
+class MemoryCache(Cache):
+    """An in-memory cache for varying response content.
+    
+    Each key in self.store is a URI, and each value is an AntiStampedeCache.
+    The response for any given URI may vary based on the values of
+    "selecting request headers"; that is, those named in the Vary
+    response header. We assume the list of header names to be constant
+    for each URI throughout the lifetime of the application, and store
+    that list in self.store[uri].selecting_headers.
+    
+    The items contained in self.store[uri] have keys which are sorted tuples
+    of the values of those selecting request headers, and values which are
+    the actual responses.
     """
-    A cache store that honors the Vary headers and keeps
-    a separate cached copy for each.
-    """
-    def __init__(self):
-        # keep a nested dictionary of cached responses indexed first by
-        #  URI and then by "Vary" header values
-        self.uri_store = {}
-    
-    def get_key_from_request(self, request, response=None):
-        """The key into the index needs to be some combination of
-        the URI and the request headers indicated by the response
-        as Varying in a normalized (e.g. sorted) format."""
-        # First, get a cached response for the URI
-        uri = VaryHeaderUnawareStore.get_key_from_request(request)
-        try:
-            # Try to get the cached response from the uri
-            cached_resp = self._get_any_response(uri)
-            response_headers = cached_resp[1]
-        except KeyError:
-            # if the cached response isn't available, use the immediate
-            #  response
-            response_headers = response.headers
-        vary_header_values = self._get_vary_header_values(request, response_headers)
-        return uri, '|'.join(vary_header_values)
-
-    def _get_vary_header_values(request, response_headers):
-        h = response_headers
-        vary_header_names = [e.value for e in h.elements('Vary')]
-        vary_header_names.sort()
-        vary_header_values = [
-            request.headers.get(h_name, '')
-            for h_name in vary_header_names]
-        return vary_header_values
-    _get_vary_header_values = staticmethod(_get_vary_header_values)
-
-    def _get_any_response(self, uri):
-        """
-        When a request for a URI comes in, we need to check
-        if it is already in the cache, but the response hasn't
-        yet been generated to determine the vary headers.
-        We can assume the Vary headers do not change for a
-        given URI, so use the Vary headers from a previous
-        response (any will do).
-        """
-        vary_store = self.uri_store.get(uri)
-        if not vary_store:
-            # No values exist for this URI
-            raise KeyError(uri)
-        # Return the first value
-        for s in vary_store.values():
-            return s
-
-    def __getitem__(self, key):
-        return self.get(key)
-        
-    def get(self, key, *args, **kwargs):
-        uri, h_vals = key
-        vary_store = self.uri_store.get(uri, {})
-        return vary_store.get(h_vals, *args, **kwargs)
-        
-    def __setitem__(self, key, value):
-        uri, h_vals = key
-        vary_store = self.uri_store.setdefault(uri, {})
-        vary_store[h_vals] = value
-    
-    def __delitem__(self, key):
-        uri, h_vals = key
-        vary_store = self.uri_store[uri]
-        del vary_store[h_vals]
-        if not vary_store:
-            # if the vary store is empty, delete the URI entry also
-            del self.uri_store[uri]
-    
-    def pop(self, key, *args, **kwargs):
-        item = self.get(key, *args, **kwargs)
-        del self[key]
-    
-    def __len__(self):
-        lengths = [len(store) for store in self.uri_store.values()]
-        return sum(lengths)
-
-class VaryHeaderUnawareStore(dict):
-    def get_key_from_request(request, response=None):
-        return cherrypy.url(qs=request.query_string)
-    get_key_from_request = staticmethod(get_key_from_request)
-
-class MemoryCache:
     
     maxobjects = 1000
     maxobj_size = 100000
     maxsize = 10000000
     delay = 600
+    antistampede_timeout = 5
+    expire_freq = 0.1
+    debug = False
     
     def __init__(self):
         self.clear()
+        
+        # Run self.expire_cache in a separate daemon thread.
         t = threading.Thread(target=self.expire_cache, name='expire_cache')
         self.expiration_thread = t
         if hasattr(threading.Thread, "daemon"):
     
     def clear(self):
         """Reset the cache to its initial, empty state."""
-        self.store = VaryHeaderAwareStore()
+        self.store = {}
         self.expirations = {}
         self.tot_puts = 0
         self.tot_gets = 0
         self.tot_non_modified = 0
         self.cursize = 0
     
-    def key(self):
-        request = cherrypy.serving.request
-        try:
-            response = cherrypy.serving.response
-        except AttributeError:
-            response = None
-        return self.store.get_key_from_request(request, response)
-    
     def expire_cache(self):
         # expire_cache runs in a separate thread which the servers are
         # not aware of. It's possible that "time" will be set to None
             # during iteration
             for expiration_time, objects in self.expirations.items():
                 if expiration_time <= now:
-                    for obj_size, obj_key in objects:
+                    for obj_size, uri, sel_header_values in objects:
                         try:
-                            del self.store[obj_key]
+                            del self.store[uri][sel_header_values]
                             self.tot_expires += 1
                             self.cursize -= obj_size
                         except KeyError:
                             # the key may have been deleted elsewhere
                             pass
                     del self.expirations[expiration_time]
-            time.sleep(0.1)
+            time.sleep(self.expire_freq)
     
     def get(self):
-        """Return the object if in the cache, else None."""
+        """Return the current variant if in the cache, else None."""
+        request = cherrypy.serving.request
         self.tot_gets += 1
-        cache_item = self.store.get(self.key(), None)
-        if cache_item:
+        
+        uri = cherrypy.url(qs=request.query_string)
+        uricache = self.store.get(uri)
+        if uricache is None:
+            return None
+        
+        header_values = [request.headers.get(h, '')
+                         for h in uricache.selecting_headers]
+        header_values.sort()
+        variant = uricache.wait(key=tuple(header_values),
+                                timeout=self.antistampede_timeout,
+                                debug=self.debug)
+        if variant is not None:
             self.tot_hist += 1
-            return cache_item
-        else:
-            return None
+        return variant
     
-    def put(self, obj):
+    def put(self, variant, size):
+        """Store the current variant in the cache."""
+        request = cherrypy.serving.request
+        response = cherrypy.serving.response
+        
+        uri = cherrypy.url(qs=request.query_string)
+        uricache = self.store.get(uri)
+        if uricache is None:
+            uricache = AntiStampedeCache()
+            uricache.selecting_headers = [
+                e.value for e in response.headers.elements('Vary')]
+            self.store[uri] = uricache
+        
         if len(self.store) < self.maxobjects:
-            # Size check no longer includes header length
-            obj_size = len(obj[2])
-            total_size = self.cursize + obj_size
+            total_size = self.cursize + size
             
             # checks if there's space for the object
-            if (obj_size < self.maxobj_size and total_size < self.maxsize):
-                # add to the expirations list and cache
-                expiration_time = cherrypy.serving.response.time + self.delay
-                obj_key = self.key()
+            if (size < self.maxobj_size and total_size < self.maxsize):
+                # Compute the variant key (the sorted values of the selecting
+                # request headers) first, so the expirations bucket records a
+                # hashable key that expire_cache can delete by.
+                header_values = [request.headers.get(h, '')
+                                 for h in uricache.selecting_headers]
+                header_values.sort()
+                variant_key = tuple(header_values)
+                
+                # add to the expirations list
+                expiration_time = response.time + self.delay
                 bucket = self.expirations.setdefault(expiration_time, [])
-                bucket.append((obj_size, obj_key))
-                self.store[obj_key] = obj
+                bucket.append((size, uri, variant_key))
+                
+                # add to the cache
+                uricache[variant_key] = variant
                 self.tot_puts += 1
                 self.cursize = total_size
+                return
     
     def delete(self):
-        self.store.pop(self.key(), None)
+        """Remove ALL cached variants of the current resource."""
+        uri = cherrypy.url(qs=cherrypy.serving.request.query_string)
+        self.store.pop(uri, None)
 
 
 def get(invalid_methods=("POST", "PUT", "DELETE"), debug=False, **kwargs):
     request = cherrypy.serving.request
     response = cherrypy.serving.response
     
+    if not hasattr(cherrypy, "_cache"):
+        # Make a process-wide Cache object.
+        cherrypy._cache = kwargs.pop("cache_class", MemoryCache)()
+        
+        # Take all remaining kwargs and set them on the Cache object.
+        for k, v in kwargs.items():
+            setattr(cherrypy._cache, k, v)
+        cherrypy._cache.debug = debug
+    
     # POST, PUT, DELETE should invalidate (delete) the cached copy.
     # See http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.10.
     if request.method in invalid_methods:
         return False
     
     cache_data = cherrypy._cache.get()
-    request.cached = c = bool(cache_data)
-    request.cacheable = not c
-    if c:
+    request.cached = bool(cache_data)
+    request.cacheable = not request.cached
+    if request.cached:
+        # Serve the cached copy.
         if debug:
             cherrypy.log('Reading response from cache', 'TOOLS.CACHING')
-        s, h, b, create_time, original_req_headers = cache_data
+        s, h, b, create_time = cache_data
         
         # Copy the response headers. See http://www.cherrypy.org/ticket/721.
         response.headers = rh = httputil.HeaderMap()
     else:
         if debug:
             cherrypy.log('request is not cached', 'TOOLS.CACHING')
-    return c
+    return request.cached
 
 
 def tee_output():
         if response.headers.get('Pragma', None) != 'no-cache':
             # save the cache data
             body = ''.join(output)
-            vary = [he.value for he in response.headers.elements('Vary')]
-            sel_headers = dict([(k, v) for k, v
-                                in cherrypy.serving.request.headers.items()
-                                if k in vary])
             cherrypy._cache.put((response.status, response.headers or {},
-                                 body, response.time, sel_headers))
+                                 body, response.time), len(body))
     
     response = cherrypy.serving.response
     response.body = tee(response.body)
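
In isolation, the anti-stampede handshake works roughly like this: the first caller of wait() gets None back and leaves an Event sentinel in the cache; later callers block on that Event until the first caller stores its result, which also populates Event.result. A minimal sketch assuming this changeset is installed; the key, the waiter thread, and the one-second sleep are stand-ins for real request threads and handler work.

import threading
import time

from cherrypy.lib.caching import AntiStampedeCache  # added by this changeset

cache = AntiStampedeCache()
KEY = ('example',)              # hypothetical variant key

# First caller: nothing is cached, so wait() plants an Event sentinel and
# returns None, telling this thread to compute the value itself.
assert cache.wait(KEY, timeout=5) is None

results = []

def waiter():
    # A later caller finds the sentinel and blocks (up to the timeout)
    # until the first caller stores its result.
    results.append(cache.wait(KEY, timeout=5))

t = threading.Thread(target=waiter)
t.start()

time.sleep(1)                   # stand-in for the expensive handler
cache[KEY] = 'computed once'    # sets Event.result and wakes the waiter

t.join()
print(results)                  # ['computed once']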

cherrypy/test/test_caching.py

 from cherrypy.test import test
 test.prefer_parent_path()
 
+import datetime
 import gzip
+from itertools import count
 import os
 curdir = os.path.join(os.getcwd(), os.path.dirname(__file__))
-from itertools import count
+import sys
+import threading
+import time
+import urllib
 
 import cherrypy
 from cherrypy.lib import httputil
         
         def __init__(self):
             cherrypy.counter = 0
+            self.longlock = threading.Lock()
         
         def index(self):
             cherrypy.counter += 1
             cherrypy.response.headers['Last-Modified'] = httputil.HTTPDate()
             return gif_bytes
         a_gif.exposed = True
-
+        
+        def long_process(self, seconds='1'):
+            try:
+                self.longlock.acquire()
+                time.sleep(float(seconds))
+            finally:
+                self.longlock.release()
+            return 'success!'
+        long_process.exposed = True
+        
+        def clear_cache(self, path):
+            cherrypy._cache.store[cherrypy.request.base + path].clear()
+        clear_cache.exposed = True
+    
     class VaryHeaderCachingServer(object):
         
         _cp_config = {'tools.caching.on': True,
         self.assertStatus("200 OK")
         self.assertHeaderItemValue('Vary', 'Our-Varying-Header')
         self.assertBody('visit #1')
-
-        #Now check that diffrent 'Vary'-fields don't evict eachother.
-        # This test creates a 2 requests with different 'Our-Varying-Header'
-        # and then test if the first one still exists.
+        
+        # Now check that different 'Vary'-fields don't evict each other.
+        # This test creates 2 requests with different 'Our-Varying-Header'
+        # and then tests if the first one still exists.
         self.getPage("/varying_headers/", headers=[('Our-Varying-Header', 'request 2')])
         self.assertStatus("200 OK")
         self.assertBody('visit #2')
         self.assertNoHeader("Last-Modified")
         if not getattr(cherrypy.server, "using_apache", False):
             self.assertHeader("Age")
+    
+    def test_antistampede(self):
+        SECONDS = 4
+        # We MUST make an initial synchronous request in order to create the
+        # AntiStampedeCache object, and populate its selecting_headers,
+        # before the actual stampede.
+        self.getPage("/long_process?seconds=%d" % SECONDS)
+        self.assertBody('success!')
+        self.getPage("/clear_cache?path=" +
+            urllib.quote('/long_process?seconds=%d' % SECONDS, safe=''))
+        self.assertStatus(200)
+        sys.stdout.write("prepped... ")
+        sys.stdout.flush()
+        
+        start = datetime.datetime.now()
+        def run():
+            self.getPage("/long_process?seconds=%d" % SECONDS)
+            # The response should be the same every time
+            self.assertBody('success!')
+        ts = [threading.Thread(target=run) for i in xrange(100)]
+        for t in ts:
+            t.start()
+        for t in ts:
+            t.join()
+        self.assertEqualDates(start, datetime.datetime.now(),
+                              # Allow a second for our thread/TCP overhead etc.
+                              seconds=SECONDS + 1)
 
 
 if __name__ == '__main__':