Jason R. Coombs avatar Jason R. Coombs committed 18ccf50

Factored out much of the connection establishement logic into irc.connection.Factory. This change will enable simplifying the signature of the connect methods and also enables more custom construction of the connection (such as custom SSL parameters) without the proliferation of parameters to the connect method.

Comments (0)

Files changed (3)

+3.3
+===
+
+* Added `connection` module with a Factory for creating socket connections.
+* Added `connect_factory` parameter to the ServerConnection.
+
+It's now possible to create connections with custom SSL parameters or other
+socket wrappers. For example, to create a connection with a custom SSL cert::
+
+    import ssl
+    import irc.client
+    import irc.connection
+    import functools
+
+    irc = irc.client.IRC()
+    server = irc.server()
+    wrapper = functools.partial(ssl.wrap_socket, ssl_cert=my_cert())
+    server.connect(connect_factory = irc.connection.Factory(wrapper=wrapper))
+
+With this release, many of the parameters to `ServerConnection.connect` are
+now deprecated:
+
+    - localaddress
+    - localport
+    - ssl
+    - ipv6
+
+Instead, one should pass the appropriate values to a `connection.Factory`
+instance and pass that factory to the .connect method. Backwards-compatibility
+will be maintained for these parameters until the release of irc 4.0.
+
 3.2.3
 =====
 
 import socket
 import string
 import time
-import ssl as ssl_mod
+import warnings
 import datetime
 import struct
 import logging
 from . import util
 from . import strings
 from . import modes
+from . import connection
 
 log = logging.getLogger(__name__)
 
         super(ServerConnection, self).__init__(irclibobj)
         self.connected = False
         self.socket = None
-        self.ssl = None
 
     # save the method args to allow for easier reconnection.
     @irc_functools.save_method_args
     def connect(self, server, port, nickname, password=None, username=None,
-            ircname=None, localaddress="", localport=0, ssl=False, ipv6=False):
+            ircname=None, localaddress="", localport=0, ssl=False, ipv6=False,
+            connect_factory=connection.Factory):
         """Connect/reconnect to a server.
 
         Arguments:
             password -- Password (if any).
             username -- The username.
             ircname -- The IRC name ("realname").
+            server_address -- The remote host/port of the server.
+            connect_factory -- A callable that takes the server address and
+                returns a connection (with a socket interface).
+
+        Deprecated Arguments:
             localaddress -- Bind the connection to a specific local IP address.
             localport -- Bind the connection to a specific local port.
             ssl -- Enable support for ssl.
         log.debug("connect(server=%r, port=%r, nickname=%r, ...)", server,
             port, nickname)
 
+        if localaddress or localport or ssl or ipv6:
+            warnings.warn("localaddress, localport, ssl, and ipv6 parameters "
+                "are deprecated. Use connect_factory instead.",
+                DeprecationWarning)
+            connect_factory.use_legacy_params(localaddress, localport, ssl,
+                ipv6)
+
         if self.connected:
             self.disconnect("Changing servers")
 
         self.real_nickname = nickname
         self.server = server
         self.port = port
+        self.server_address = (server, port)
         self.nickname = nickname
         self.username = username or nickname
         self.ircname = ircname or nickname
         self.password = password
-        self.localaddress = localaddress
-        self.localport = localport
-        self.localhost = socket.gethostname()
-        self.ipv6 = ipv6
-        if ipv6:
-            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-        else:
-            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.connect_factory = connect_factory
         try:
-            self.socket.bind((self.localaddress, self.localport))
-            self.socket.connect((self.server, self.port))
-            if ssl:
-                self.ssl = ssl_mod.wrap_socket(self.socket)
+            self.socket = self.connect_factory(self.server_address)
         except socket.error as err:
-            self.socket.close()
             self.socket = None
             raise ServerConnectionError("Couldn't connect to socket: %s" % err)
         self.connected = True
         with self.irclibobj.mutex:
             self.disconnect("Closing object")
             self.irclibobj._remove_connection(self)
+
     def _get_socket(self):
         """[Internal]"""
         return self.socket
         """[Internal]"""
 
         try:
-            reader = self.ssl.read if self.ssl else self.socket.recv
+            reader = getattr(self.socket, 'read', self.socket.recv)
             new_data = reader(2 ** 14)
         except socket.error:
             # The server hung up.
         # clients should not transmit more than 512 bytes.
         if len(bytes) > 512:
             raise ValueError("Messages limited to 512 bytes")
-        sender = self.ssl.write if self.ssl else self.socket.send
         if self.socket is None:
             raise ServerNotConnectedError("Not connected.")
+        sender = getattr(self.socket, 'write', self.socket.send)
         try:
             sender(bytes)
             log.debug("TO SERVER: %s", string)
     def _dcc_disconnect(self, c, e):
         self.dcc_connections.remove(c)
 
-    def connect(self, server, port, nickname, password=None, username=None,
-                ircname=None, localaddress="", localport=0, ssl=False, ipv6=False):
-        """Connect/reconnect to a server.
-
-        Arguments:
-
-            server -- Server name.
-
-            port -- Port number.
-
-            nickname -- The nickname.
-
-            password -- Password (if any).
-
-            username -- The username.
-
-            ircname -- The IRC name.
-
-            localaddress -- Bind the connection to a specific local IP address.
-
-            localport -- Bind the connection to a specific local port.
-
-            ssl -- Enable support for ssl.
-
-            ipv6 -- Enable support for ipv6.
-
-        This function can be called to reconnect a closed connection.
-        """
-        self.connection.connect(server, port, nickname,
-                                password, username, ircname,
-                                localaddress, localport, ssl, ipv6)
+    def connect(self, *args, **kwargs):
+        """Connect using the underlying connection"""
+        self.connection.connect(*args, **kwargs)
 
     def dcc_connect(self, address, port, dcctype="chat"):
         """Connect to a DCC peer.

irc/connection.py

+
+from __future__ import absolute_import
+
+import socket
+import importlib
+
+identity = lambda x: x
+
+class Factory(object):
+    """
+    A class for creating custom socket connections.
+
+    To create a simple connection:
+
+        server_address = ('localhost', 80)
+        Factory()(server_address)
+
+    To create an SSL connection:
+
+        Factory(wrapper=ssl.wrap_socket)(server_address)
+
+    To create an SSL connection with parameters to wrap_socket:
+
+        wrapper = functools.partial(ssl.wrap_socket, ssl_cert=get_cert())
+        Factory(wrapper=wrapper)(server_address)
+
+    Note that Factory doesn't save the state of the socket itself. The
+    caller must do that, as necessary. As a result, the Factory may be
+    re-used to create new connections with the same settings.
+
+    """
+
+    family = socket.AF_INET
+
+    def __init__(self, bind_address=('', 0), wrapper=identity):
+        self.bind_address = bind_address
+        self.wrapper = wrapper
+
+    def from_legacy_params(self, localaddress='', localport=0, ssl=False,
+            ipv6=False):
+        if localaddress or localport:
+            self.bind_address = (localaddress, localport)
+        if ssl:
+            ssl_mod = importlib.importmodule('ssl')
+            self.wrapper = ssl_mod.wrap_socket
+        if ipv6:
+            self.family = socket.AF_INET6
+
+    def connect(self, server_address):
+        sock = self.wrapper(socket.socket(self.family, socket.SOCK_STREAM))
+        sock.bind(self.bind_address)
+        sock.connect(server_address)
+        return sock
+    __call__ = connect
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.