Commits

Eric Raymond committed 1cae726

Make the methods of the IRC object thread-safe.

How it's done: We equip each IRC instance with a reentrant thread lock. Then, method bodies that modify mutable shared state (the connection list, the handlers list, and the delayed-events list) are guarded by the mutex using the Python `with` statement, which implicitly invokes the lock acquire()/release() methods at its scope bundaries.

The body of the process_once() function is also guarded, so any event handlers that it fires also won't be able to asynchronously mess with the shared state.

These changes have been tested in irker.

Comments (0)

Files changed (1)

 import struct
 import logging
 import itertools
+import threading
 
 try:
     import pkg_resources
 
 # TODO
 # ----
-# (maybe) thread safety
 # (maybe) color parser convenience functions
 # documentation (including all event types)
 # (maybe) add awareness of different types of ircds
     This will connect to the IRC server irc.some.where on port 6667
     using the nickname my_nickname and send the message "Hi there!"
     to the nickname a_nickname.
+
+    The methods of this class are thread-safe; accesses to and modifications of
+    its internal lists of connections, handlers, and delayed commands
+    are guarded by a mutex.
     """
 
     def __init__(self, fn_to_add_socket=None,
         self.connections = []
         self.handlers = {}
         self.delayed_commands = []  # list of DelayedCommands
+        # Modifications to these shared lists and dict need to be thread-safe
+        self.mutex = threading.RLock()
 
         self.add_global_handler("ping", _ping_ponger, -42)
 
         """Creates and returns a ServerConnection object."""
 
         c = ServerConnection(self)
-        self.connections.append(c)
+        with self.mutex:
+            self.connections.append(c)
         return c
 
     def process_data(self, sockets):
 
         See documentation for IRC.__init__.
         """
-        log.log(logging.DEBUG-2, "process_data()")
-        for s, c in itertools.product(sockets, self.connections):
-            if s == c._get_socket():
-                c.process_data()
+        with self.mutex:
+            log.log(logging.DEBUG-2, "process_data()")
+            for s, c in itertools.product(sockets, self.connections):
+                if s == c._get_socket():
+                    c.process_data()
 
     def process_timeout(self):
         """Called when a timeout notification is due.
 
         See documentation for IRC.__init__.
         """
-        while self.delayed_commands:
-            command = self.delayed_commands[0]
-            if not command.due():
-                break
-            command.function(*command.arguments)
-            if isinstance(command, PeriodicCommand):
-                self._schedule_command(command.next())
-            del self.delayed_commands[0]
+        with self.mutex:
+            while self.delayed_commands:
+                command = self.delayed_commands[0]
+                if not command.due():
+                    break
+                command.function(*command.arguments)
+                if isinstance(command, PeriodicCommand):
+                    self._schedule_command(command.next())
+                del self.delayed_commands[0]
 
     def process_once(self, timeout=0):
         """Process data from connections once.
         incoming data, if there are any.  If that seems boring, look
         at the process_forever method.
         """
-        log.log(logging.DEBUG-2, "process_once()")
-        sockets = [x._get_socket() for x in self.connections if x is not None]
-        if sockets:
-            (i, o, e) = select.select(sockets, [], [], timeout)
-            self.process_data(i)
-        else:
-            time.sleep(timeout)
-        self.process_timeout()
+        with self.mutex:
+            log.log(logging.DEBUG-2, "process_once()")
+            sockets = [x._get_socket() for x in self.connections if x is not None]
+            if sockets:
+                (i, o, e) = select.select(sockets, [], [], timeout)
+                self.process_data(i)
+            else:
+                time.sleep(timeout)
+            self.process_timeout()
 
     def process_forever(self, timeout=0.2):
         """Run an infinite loop, processing data from connections.
 
             timeout -- Parameter to pass to process_once.
         """
+        # This loop should specifically *not* be mutex-locked.
+        # Otherwise no other thread would ever be able to change
+        # the shared state of an IRC object running this function.
         log.debug("process_forever(timeout=%s)", timeout)
         while 1:
             self.process_once(timeout)
 
     def disconnect_all(self, message=""):
         """Disconnects all connections."""
-        for c in self.connections:
-            c.disconnect(message)
+        with self.mutex:
+            for c in self.connections:
+                c.disconnect(message)
 
     def add_global_handler(self, event, handler, priority=0):
         """Adds a global handler function for a specific event type.
         number is highest priority).  If a handler function returns
         "NO MORE", no more handlers will be called.
         """
-        event_handlers = self.handlers.setdefault(event, [])
-        bisect.insort(event_handlers, (priority, handler))
+        with self.mutex:
+            event_handlers = self.handlers.setdefault(event, [])
+            bisect.insort(event_handlers, (priority, handler))
 
     def remove_global_handler(self, event, handler):
         """Removes a global handler function.
 
         Returns 1 on success, otherwise 0.
         """
-        if not event in self.handlers:
-            return 0
-        for h in self.handlers[event]:
-            if handler == h[1]:
-                self.handlers[event].remove(h)
+        with self.mutex:
+            if not event in self.handlers:
+                return 0
+            for h in self.handlers[event]:
+                if handler == h[1]:
+                    self.handlers[event].remove(h)
         return 1
 
     def execute_at(self, at, function, arguments=()):
         self._schedule_command(command)
 
     def _schedule_command(self, command):
-        bisect.insort(self.delayed_commands, command)
-        if self.fn_to_add_timeout:
-            self.fn_to_add_timeout(util.total_seconds(command.delay))
+        with self.mutex:
+            bisect.insort(self.delayed_commands, command)
+            if self.fn_to_add_timeout:
+                self.fn_to_add_timeout(util.total_seconds(command.delay))
 
     def dcc(self, dcctype="chat"):
         """Creates and returns a DCCConnection object.
                        incoming data will be split in newline-separated
                        chunks. If "raw", incoming data is not touched.
         """
-        c = DCCConnection(self, dcctype)
-        self.connections.append(c)
+        with self.mutex:
+            c = DCCConnection(self, dcctype)
+            self.connections.append(c)
         return c
 
     def _handle_event(self, connection, event):
         """[Internal]"""
-        h = self.handlers
-        th = sorted(h.get("all_events", []) + h.get(event.eventtype(), []))
-        for handler in th:
-            if handler[1](connection, event) == "NO MORE":
-                return
+        with self.mutex:
+            h = self.handlers
+            th = sorted(h.get("all_events", []) + h.get(event.eventtype(), []))
+            for handler in th:
+                if handler[1](connection, event) == "NO MORE":
+                    return
 
     def _remove_connection(self, connection):
         """[Internal]"""
-        self.connections.remove(connection)
-        if self.fn_to_remove_socket:
-            self.fn_to_remove_socket(connection._get_socket())
+        with self.mutex:
+            self.connections.remove(connection)
+            if self.fn_to_remove_socket:
+                self.fn_to_remove_socket(connection._get_socket())
 
 class DelayedCommand(datetime.datetime):
     """