Commits

Gregory Petukhov committed 86a2cbf

Refactor method of global request counter generation

  • Participants
  • Parent commits 8bc36fd

Comments (0)

Files changed (2)

File grab/base.py

 from random import randint, choice
 from copy import copy
 import threading
+import itertools
 try:
     from urlparse import urljoin
 except ImportError:
 # This could be helpful in debuggin when your script
 # creates multiple Grab instances - in case of shared counter
 # grab instances do not overwrite dump logs
-REQUEST_COUNTER_LOCK = threading.Lock()
 GLOBAL_STATE = {
-    'request_counter': 0,
     'dom_build_time': 0,
     'selector_time': 0,
 }
+REQUEST_COUNTER = itertools.count(1)
 
 # Some extensions need GLOBAL_STATE variable
 # what's why they go after GLOBAL_STATE definition
         # Reset the state setted by previous request
         if not self._request_prepared:
             self.reset()
-            self.request_counter = self.get_request_counter()
+            self.request_counter = REQUEST_COUNTER.next()
             if kwargs:
                 self.setup(**kwargs)
             if self.proxylist and self.config['proxy_auto_change']:
             'Expect': '',
         }
 
-    def get_request_counter(self):
-        """
-        Increase global request counter and return new value
-        which will be used as request number for current request.
-        """
-
-        # TODO: do not use lock in main thread
-        REQUEST_COUNTER_LOCK.acquire()
-        GLOBAL_STATE['request_counter'] += 1
-        counter = GLOBAL_STATE['request_counter']
-        REQUEST_COUNTER_LOCK.release()
-        return counter
-
     def save_dumps(self):
         if self.config['log_dir']:
             tname = threading.currentThread().getName().lower()

File test/grab_api.py

         #SimpleExtension.get_data()['counter'] = 0
         #g = VeryCustomGrab()
         #self.assertEqual(SimpleExtension.get_data()['counter'], 2)
+
+    def test_request_counter(self):
+        import grab.base
+        import itertools
+        import threading
+
+        grab.base.REQUEST_COUNTER = itertools.count(1)
+        g = Grab(transport=GRAB_TRANSPORT)
+        g.go(SERVER.BASE_URL)
+        self.assertEqual(g.request_counter, 1)
+
+        g.go(SERVER.BASE_URL)
+        self.assertEqual(g.request_counter, 2)
+
+        def func():
+            g = Grab(transport=GRAB_TRANSPORT)
+            g.go(SERVER.BASE_URL)
+
+        # Make 10 requests in concurrent threads
+        threads = []
+        for x in xrange(10):
+            th = threading.Thread(target=func)
+            threads.append(th)
+            th.start()
+        for th in threads:
+            th.join()
+
+        g.go(SERVER.BASE_URL)
+        self.assertEqual(g.request_counter, 13)