1. Jeremy Banks
  2. irc

Commits

Eric Raymond  committed 1cae726

Make the methods of the IRC object thread-safe.

How it's done: We equip each IRC instance with a reentrant thread lock. Then, method bodies that modify mutable shared state (the connection list, the handlers list, and the delayed-events list) are guarded by the mutex using the Python `with` statement, which implicitly invokes the lock acquire()/release() methods at its scope bundaries.

The body of the process_once() function is also guarded, so any event handlers that it fires also won't be able to asynchronously mess with the shared state.

These changes have been tested in irker.

  • Participants
  • Parent commits 81b2e36
  • Branches default

Comments (0)

Files changed (1)

File irc/client.py

View file
 import struct
 import logging
 import itertools
+import threading
 
 try:
     import pkg_resources
 
 # TODO
 # ----
-# (maybe) thread safety
 # (maybe) color parser convenience functions
 # documentation (including all event types)
 # (maybe) add awareness of different types of ircds
     This will connect to the IRC server irc.some.where on port 6667
     using the nickname my_nickname and send the message "Hi there!"
     to the nickname a_nickname.
+
+    The methods of this class are thread-safe; accesses to and modifications of
+    its internal lists of connections, handlers, and delayed commands
+    are guarded by a mutex.
     """
 
     def __init__(self, fn_to_add_socket=None,
         self.connections = []
         self.handlers = {}
         self.delayed_commands = []  # list of DelayedCommands
+        # Modifications to these shared lists and dict need to be thread-safe
+        self.mutex = threading.RLock()
 
         self.add_global_handler("ping", _ping_ponger, -42)
 
         """Creates and returns a ServerConnection object."""
 
         c = ServerConnection(self)
-        self.connections.append(c)
+        with self.mutex:
+            self.connections.append(c)
         return c
 
     def process_data(self, sockets):
 
         See documentation for IRC.__init__.
         """
-        log.log(logging.DEBUG-2, "process_data()")
-        for s, c in itertools.product(sockets, self.connections):
-            if s == c._get_socket():
-                c.process_data()
+        with self.mutex:
+            log.log(logging.DEBUG-2, "process_data()")
+            for s, c in itertools.product(sockets, self.connections):
+                if s == c._get_socket():
+                    c.process_data()
 
     def process_timeout(self):
         """Called when a timeout notification is due.
 
         See documentation for IRC.__init__.
         """
-        while self.delayed_commands:
-            command = self.delayed_commands[0]
-            if not command.due():
-                break
-            command.function(*command.arguments)
-            if isinstance(command, PeriodicCommand):
-                self._schedule_command(command.next())
-            del self.delayed_commands[0]
+        with self.mutex:
+            while self.delayed_commands:
+                command = self.delayed_commands[0]
+                if not command.due():
+                    break
+                command.function(*command.arguments)
+                if isinstance(command, PeriodicCommand):
+                    self._schedule_command(command.next())
+                del self.delayed_commands[0]
 
     def process_once(self, timeout=0):
         """Process data from connections once.
         incoming data, if there are any.  If that seems boring, look
         at the process_forever method.
         """
-        log.log(logging.DEBUG-2, "process_once()")
-        sockets = [x._get_socket() for x in self.connections if x is not None]
-        if sockets:
-            (i, o, e) = select.select(sockets, [], [], timeout)
-            self.process_data(i)
-        else:
-            time.sleep(timeout)
-        self.process_timeout()
+        with self.mutex:
+            log.log(logging.DEBUG-2, "process_once()")
+            sockets = [x._get_socket() for x in self.connections if x is not None]
+            if sockets:
+                (i, o, e) = select.select(sockets, [], [], timeout)
+                self.process_data(i)
+            else:
+                time.sleep(timeout)
+            self.process_timeout()
 
     def process_forever(self, timeout=0.2):
         """Run an infinite loop, processing data from connections.
 
             timeout -- Parameter to pass to process_once.
         """
+        # This loop should specifically *not* be mutex-locked.
+        # Otherwise no other thread would ever be able to change
+        # the shared state of an IRC object running this function.
         log.debug("process_forever(timeout=%s)", timeout)
         while 1:
             self.process_once(timeout)
 
     def disconnect_all(self, message=""):
         """Disconnects all connections."""
-        for c in self.connections:
-            c.disconnect(message)
+        with self.mutex:
+            for c in self.connections:
+                c.disconnect(message)
 
     def add_global_handler(self, event, handler, priority=0):
         """Adds a global handler function for a specific event type.
         number is highest priority).  If a handler function returns
         "NO MORE", no more handlers will be called.
         """
-        event_handlers = self.handlers.setdefault(event, [])
-        bisect.insort(event_handlers, (priority, handler))
+        with self.mutex:
+            event_handlers = self.handlers.setdefault(event, [])
+            bisect.insort(event_handlers, (priority, handler))
 
     def remove_global_handler(self, event, handler):
         """Removes a global handler function.
 
         Returns 1 on success, otherwise 0.
         """
-        if not event in self.handlers:
-            return 0
-        for h in self.handlers[event]:
-            if handler == h[1]:
-                self.handlers[event].remove(h)
+        with self.mutex:
+            if not event in self.handlers:
+                return 0
+            for h in self.handlers[event]:
+                if handler == h[1]:
+                    self.handlers[event].remove(h)
         return 1
 
     def execute_at(self, at, function, arguments=()):
         self._schedule_command(command)
 
     def _schedule_command(self, command):
-        bisect.insort(self.delayed_commands, command)
-        if self.fn_to_add_timeout:
-            self.fn_to_add_timeout(util.total_seconds(command.delay))
+        with self.mutex:
+            bisect.insort(self.delayed_commands, command)
+            if self.fn_to_add_timeout:
+                self.fn_to_add_timeout(util.total_seconds(command.delay))
 
     def dcc(self, dcctype="chat"):
         """Creates and returns a DCCConnection object.
                        incoming data will be split in newline-separated
                        chunks. If "raw", incoming data is not touched.
         """
-        c = DCCConnection(self, dcctype)
-        self.connections.append(c)
+        with self.mutex:
+            c = DCCConnection(self, dcctype)
+            self.connections.append(c)
         return c
 
     def _handle_event(self, connection, event):
         """[Internal]"""
-        h = self.handlers
-        th = sorted(h.get("all_events", []) + h.get(event.eventtype(), []))
-        for handler in th:
-            if handler[1](connection, event) == "NO MORE":
-                return
+        with self.mutex:
+            h = self.handlers
+            th = sorted(h.get("all_events", []) + h.get(event.eventtype(), []))
+            for handler in th:
+                if handler[1](connection, event) == "NO MORE":
+                    return
 
     def _remove_connection(self, connection):
         """[Internal]"""
-        self.connections.remove(connection)
-        if self.fn_to_remove_socket:
-            self.fn_to_remove_socket(connection._get_socket())
+        with self.mutex:
+            self.connections.remove(connection)
+            if self.fn_to_remove_socket:
+                self.fn_to_remove_socket(connection._get_socket())
 
 class DelayedCommand(datetime.datetime):
     """