Commits

Alan Kennedy  committed d3fd05f Merge

merge w/2.5: Tidying up the situation with getsockname() and getpeername()

  • Participants
  • Parent commits 940d38d, e40f741

Comments (0)

Files changed (2)

File Lib/socket.py

     timeout = None
     mode = MODE_BLOCKING
 
-    def getpeername(self):
-        return (self.jsocket.getInetAddress().getHostAddress(), self.jsocket.getPort() )
-
     def config(self, mode, timeout):
         self.mode = mode
         if self.mode == MODE_BLOCKING:
         if how in (SHUT_WR, SHUT_RDWR):
             self.jsocket.shutdownOutput()
 
+    def getsockname(self):
+        return (self.jsocket.getLocalAddress().getHostAddress(), self.jsocket.getLocalPort())
+
+    def getpeername(self):
+        return (self.jsocket.getInetAddress().getHostAddress(), self.jsocket.getPort() )
+
 class _server_socket_impl(_nio_impl):
 
     options = {
         # later cause the user explicit close() call to fail
         pass
 
+    def getsockname(self):
+        return (self.jsocket.getInetAddress().getHostAddress(), self.jsocket.getLocalPort())
+
+    def getpeername(self):
+        # Not a meaningful operation for server sockets.
+        raise error(errno.ENOTCONN, "Socket is not connected")
+
 class _datagram_socket_impl(_nio_impl):
 
     options = {
         else:
             return self._do_receive_nio(0, num_bytes, flags)
 
+    def getsockname(self):
+        return (self.jsocket.getLocalAddress().getHostAddress(), self.jsocket.getLocalPort())
+
+    def getpeername(self):
+        peer_address = self.jsocket.getInetAddress()
+        if peer_address is None:
+            raise error(errno.ENOTCONN, "Socket is not connected")
+        return (peer_address.getHostAddress(), self.jsocket.getPort() )
+
 has_ipv6 = True # IPV6 FTW!
 
 # Name and address functions
         except java.lang.Exception, jlx:
             raise _map_exception(jlx)
 
+    def getsockname(self):
+        try:
+            if self.sock_impl is None:
+                # If the user has already bound an address, return that
+                if self.local_addr:
+                    return self.local_addr
+                # The user has not bound, connected or listened
+                # This is what cpython raises in this scenario
+                raise error(errno.EINVAL, "Invalid argument")
+            return self.sock_impl.getsockname()
+        except java.lang.Exception, jlx:
+            raise _map_exception(jlx)
+
+    def getpeername(self):
+        try:
+            if self.sock_impl is None:
+                raise error(errno.ENOTCONN, "Socket is not connected")
+            return self.sock_impl.getpeername()
+        except java.lang.Exception, jlx:
+            raise _map_exception(jlx)
+
     def _config(self):
         assert self.mode in _permitted_modes
         if self.sock_impl:
 
     sendall = send
 
-    def getsockname(self):
-        try:
-            if not self.sock_impl:
-                host, port = self.local_addr or ("", 0)
-                host = java.net.InetAddress.getByName(host).getHostAddress()
-            else:
-                if self.server:
-                    host = self.sock_impl.jsocket.getInetAddress().getHostAddress()
-                else:
-                    host = self.sock_impl.jsocket.getLocalAddress().getHostAddress()
-                port = self.sock_impl.jsocket.getLocalPort()
-            return (host, port)
-        except java.lang.Exception, jlx:
-            raise _map_exception(jlx)
-
-    def getpeername(self):
-        try:
-            assert self.sock_impl
-            assert not self.server
-            host = self.sock_impl.jsocket.getInetAddress().getHostAddress()
-            port = self.sock_impl.jsocket.getPort()
-            return (host, port)
-        except java.lang.Exception, jlx:
-            raise _map_exception(jlx)
-
     def close(self):
         try:
             if self.istream:
 
     sock_impl = None
     connected = False
+    local_addr = None
 
     def __init__(self):
         _nonblocking_api_mixin.__init__(self)
 
     def bind(self, addr):
-        try:
+        try:            
             assert not self.sock_impl
-            self.sock_impl = _datagram_socket_impl(_get_jsockaddr(addr, self.family, self.type, self.proto, AI_PASSIVE), 
+            assert not self.local_addr
+            # Do the address format check
+            _get_jsockaddr(addr, self.family, self.type, self.proto, 0)
+            self.local_addr = addr
+            self.sock_impl = _datagram_socket_impl(_get_jsockaddr(self.local_addr, self.family, self.type, self.proto, AI_PASSIVE), 
                                                     self.pending_options[ (SOL_SOCKET, SO_REUSEADDR) ])
             self._config()
         except java.lang.Exception, jlx:
             if not self.sock_impl:
                 self.sock_impl = _datagram_socket_impl()
                 self._config()
-                self.sock_impl.connect(_get_jsockaddr(addr, self.family, self.type, self.proto, 0))
+            self.sock_impl.connect(_get_jsockaddr(addr, self.family, self.type, self.proto, 0))
             self.connected = True
         except java.lang.Exception, jlx:
             raise _map_exception(jlx)
         except java.lang.Exception, jlx:
             raise _map_exception(jlx)
 
-    def getsockname(self):
-        try:
-            assert self.sock_impl
-            host = self.sock_impl.jsocket.getLocalAddress().getHostAddress()
-            port = self.sock_impl.jsocket.getLocalPort()
-            return (host, port)
-        except java.lang.Exception, jlx:
-            raise _map_exception(jlx)
-
-    def getpeername(self):
-        try:
-            assert self.sock
-            host = self.sock_impl.jsocket.getInetAddress().getHostAddress()
-            port = self.sock_impl.jsocket.getPort()
-            return (host, port)
-        except java.lang.Exception, jlx:
-            raise _map_exception(jlx)
-
     def __del__(self):
         self.close()
 

File Lib/test/test_socket.py

         else:
             self.fail("Shutdown on unconnected socket should have raised socket exception")
 
+class TestGetSockAndPeerName:
+
+    def testGetpeernameNoImpl(self):
+        try:
+            self.s.getpeername()
+        except socket.error, se:
+            if se[0] == errno.ENOTCONN:
+                return
+        self.fail("getpeername() on unconnected socket should have raised socket.error")
+
+    def testGetsocknameUnboundNoImpl(self):
+        try:
+            self.s.getsockname()
+        except socket.error, se:
+            if se[0] == errno.EINVAL:
+                return
+        self.fail("getsockname() on unconnected socket should have raised socket.error")
+
+    def testGetsocknameBoundNoImpl(self):
+        self.s.bind( ("localhost", 0) )
+        try:
+            self.s.getsockname()
+        except socket.error, se:
+            self.fail("getsockname() on bound socket should have not raised socket.error")
+
+    def testGetsocknameImplCreated(self):
+        self._create_impl_socket()
+        try:
+            self.s.getsockname()
+        except socket.error, se:
+            self.fail("getsockname() on active socket should not have raised socket.error")
+
+    def tearDown(self):
+        self.s.close()
+
+class TestGetSockAndPeerNameTCPClient(unittest.TestCase, TestGetSockAndPeerName):
+
+    def setUp(self):
+        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # This server is not needed for all tests, but create it anyway
+        # It uses an ephemeral port, so there should be no port clashes or
+        # problems with reuse.
+        self.server_peer = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.server_peer.bind( ("localhost", 0) )
+        self.server_peer.listen(5)
+
+    def _create_impl_socket(self):
+        self.s.connect(self.server_peer.getsockname())
+
+    def testGetpeernameImplCreated(self):
+        self._create_impl_socket()
+        try:
+            self.s.getpeername()
+        except socket.error, se:
+            self.fail("getpeername() on active socket should not have raised socket.error")
+        self.failUnlessEqual(self.s.getpeername(), self.server_peer.getsockname())
+
+    def tearDown(self):
+        self.server_peer.close()
+
+class TestGetSockAndPeerNameTCPServer(unittest.TestCase, TestGetSockAndPeerName):
+
+    def setUp(self):
+        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+    def _create_impl_socket(self):
+        self.s.bind(("localhost", 0))
+        self.s.listen(5)
+
+    def testGetpeernameImplCreated(self):
+        self._create_impl_socket()
+        try:
+            self.s.getpeername()
+        except socket.error, se:
+            if se[0] == errno.ENOTCONN:
+                return
+        self.fail("getpeername() on listening socket should have raised socket.error")
+
+class TestGetSockAndPeerNameUDP(unittest.TestCase, TestGetSockAndPeerName):
+
+    def setUp(self):
+        self.s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+    def _create_impl_socket(self):
+        # Binding is enough to cause socket impl creation
+        self.s.bind(("localhost", 0))
+
+    def testGetpeernameImplCreatedNotConnected(self):
+        self._create_impl_socket()
+        try:
+            self.s.getpeername()
+        except socket.error, se:
+            if se[0] == errno.ENOTCONN:
+                return
+        self.fail("getpeername() on unconnected UDP socket should have raised socket.error")
+
+    def testGetpeernameImplCreatedAndConnected(self):
+        # This test also tests that an UDP socket can be bound and connected at the same time
+        self._create_impl_socket()
+        # Need to connect to an UDP port
+        self._udp_peer = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        self._udp_peer.bind( ("localhost", 0) )
+        self.s.connect(self._udp_peer.getsockname())
+        try:
+            try:
+                self.s.getpeername()
+            except socket.error, se:
+                self.fail("getpeername() on connected UDP socket should not have raised socket.error")
+            self.failUnlessEqual(self.s.getpeername(), self._udp_peer.getsockname())
+        finally:
+            self._udp_peer.close()
+
 def test_main():
     tests = [
         GeneralModuleTests,
         SmallBufferedFileObjectClassTestCase,
         UnicodeTest,
         IDNATest,
+        TestGetSockAndPeerNameTCPClient, 
+        TestGetSockAndPeerNameTCPServer, 
+        TestGetSockAndPeerNameUDP,
     ]
     if hasattr(socket, "socketpair"):
         tests.append(BasicSocketPairTest)