Commits

Anonymous committed 6677f0f

Fixes #2 - Added some basic synchronization to Connection class.

Comments (0)

Files changed (1)

stompclient/connection.py

     
     This is useful for publish-only clients when desiring a connection pool to be used in a 
     multi-threaded context (e.g. web servers).  This notably does NOT work for publish-
-    subscribe clients, since the message messages are received by a separate thread. 
+    subscribe clients, since a listener thread needs to be able to share the *same* socket 
+    with other publisher thread(s). 
     """
     pass
 
     Manages TCP connection to the STOMP server and provides an abstracted interface for sending
     and receiving STOMP frames.
     
-    This class is notably not thread-safe.  You need to use external mechanisms to guard access
-    to connections.  This is typically accomplished by using a thread-safe connection pool 
-    implementation (e.g. L{stompclient.connection.ThreadLocalConnectionPool}).
+    This class provides some basic synchronization to avoid threads stepping on eachother. 
+    Specifically the following activities are each protected by [their own] C{threading.RLock}
+    instances:
+    - connect() and disconnect() methods (share a lock).
+    - read()
+    - send()
+    
+    It is assumed that send() and recv() should be allowed to happen concurrently, so these do 
+    not *share* a lock.  If you need more thread-isolation, consider using a thread-safe 
+    connection pool implementation (e.g. L{stompclient.connection.ThreadLocalConnectionPool}).
     
     @ivar host: The hostname/address for this connection.
     @type host: C{str}
         self._sock = None
         self._buffer = FrameBuffer()
         self._connected = threading.Event()
-
+        self._connect_lock = threading.RLock()
+        self._send_lock = threading.RLock()
+        self._read_lock = threading.RLock()
+        
     @property
     def connected(self):
         """
         """
         Connects to the STOMP server if not already connected.
         """
-        if self._sock:
-            return
-        try:
-            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            sock.connect((self.host, self.port))
-        except socket.timeout as exc:
-            raise ConnectionTimeoutError(*exc.args)
-        except socket.error as exc:
-            raise ConnectionError(*exc.args)
-        
-        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
-        sock.settimeout(self.socket_timeout)
-        self._sock = sock
-        self._connected.set()
+        with self._connect_lock:
+            if self._sock:
+                return
+            try:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                sock.connect((self.host, self.port))
+            except socket.timeout as exc:
+                raise ConnectionTimeoutError(*exc.args)
+            except socket.error as exc:
+                raise ConnectionError(*exc.args)
+            
+            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+            sock.settimeout(self.socket_timeout)
+            self._sock = sock
+            self._connected.set()
         
     def disconnect(self, conf=None):
         """
         
         @raise NotConnectedError: If the connection is not currently connected. 
         """
-        if self._sock is None:
-            raise NotConnectedError()
-        try:
-            self._sock.close()
-        except socket.error:
-            pass
-        self._sock = None
-        self._buffer.clear()
-        self._connected.clear()
+        with self._connect_lock:
+            if self._sock is None:
+                raise NotConnectedError()
+            try:
+                self._sock.close()
+            except socket.error:
+                pass
+            self._sock = None
+            self._buffer.clear()
+            self._connected.clear()
     
     def send(self, frame):
         """
         @param frame: The frame to send to server.
         @type frame: L{stompclient.frame.Frame}
         """
-        self.connect()
-        try:
-            self._sock.sendall(str(frame))
-        except socket.error, e:
-            if e.args[0] == errno.EPIPE:
-                self.disconnect()
-            raise ConnectionError("Error %s while writing to socket. %s." % e.args)
+        with self._send_lock:
+            self.connect()
+            try:
+                self._sock.sendall(str(frame))
+            except socket.error, e:
+                if e.args[0] == errno.EPIPE:
+                    self.disconnect()
+                raise ConnectionError("Error %s while writing to socket. %s." % e.args)
 
     def read(self):
         """
         @return: A frame read from socket or buffered from previous socket read.
         @rtype: L{stompclient.frame.Frame}
         """
-        self.connect()
-        
-        buffered_frame = self._buffer.extract_frame()
-        
-        if buffered_frame:
-            return buffered_frame
-        else:
-            # Read bytes from socket until we have read a frame (or timeout out) and then return it.
-            received_frame = None
-            try:
-                while self._connected.is_set():
-                    bytes = self._sock.recv(8192)
-                    self._buffer.append(bytes)
-                    received_frame = self._buffer.extract_frame()
-                    if received_frame:
-                        break
-            except socket.timeout:
-                pass
-            except socket.error, e:
-                if e.args[0] == errno.EPIPE:
-                    self.disconnect()
-                raise ConnectionError("Error %s while reading from socket. %s." % e.args)
+        with self._read_lock:
+            self.connect()
             
-            return received_frame
+            buffered_frame = self._buffer.extract_frame()
+            
+            if buffered_frame:
+                return buffered_frame
+            else:
+                # Read bytes from socket until we have read a frame (or timeout out) and then return it.
+                received_frame = None
+                try:
+                    while self._connected.is_set():
+                        bytes = self._sock.recv(8192)
+                        self._buffer.append(bytes)
+                        received_frame = self._buffer.extract_frame()
+                        if received_frame:
+                            break
+                except socket.timeout:
+                    pass
+                except socket.error, e:
+                    if e.args[0] == errno.EPIPE:
+                        self.disconnect()
+                    raise ConnectionError("Error %s while reading from socket. %s." % e.args)
+                
+                return received_frame
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.