Commits

Matt Joiner committed 8c4db7d

Move a lot of stuff around.

Comments (0)

Files changed (1)

 import random
 import struct
 
-from gthread import queue, socket
+from gthread import queue
+import gthread.socket as socket
 import gthread
 
 from bencoding import buncode, bencode
 
-special_logger = logging.getLogger('special')
-special_logger.addHandler(logging.FileHandler('special.log'))
-special_logger.setLevel(logging.DEBUG)
-special_logger.propagate = False
 
 class ConnectionClosed(Exception): pass
 
     CANCEL = 8
     EXTENDED = 20
 
-    logger = logging.getLogger('conn')
-
     __slots__ = (
-        'socket', 'peer_bitfield',
+        'logger', 'socket', 'peer_bitfield',
         'last_recv_time', 'last_send_time',
         'peer_requests', 'our_requests',
         'peer_choked', 'am_choked',
         'peer_extended_protocols', 'peer_reqq')
 
     protocol = b'\x13BitTorrent protocol'
+    our_extensions = b'\0\0\0\0\0\x10\0\0'
 
-    def __init__(self, socket):
-        self.socket = socket
+    def __init__(self, sock):
+        self.logger = logging.getLogger('peer.{:x}'.format(id(self)))
+        self.logger.addHandler(logging.FileHandler('log/peer.{:x}'.format(id(self))))
+        self.logger.setLevel(1)
+        self.socket = sock
         self.peer_bitfield = set()
         self.last_recv_time = None
         self.last_send_time = None
         self.bytes_sent = 0
         self.bytes_received = 0
         self.out_queue = queue.Queue()
-        #~ self.send(self.protocol + b'\0\0\0\0\0\x10\0\0')
-        self.send(self.protocol + b'\0\0\0\0\0\x00\0\0')
+        self.send(self.protocol + self.our_extensions)
         self.in_buffer = b''
         self.send_routine_exc = None
         import collections
         self.peer_extended_protocols = collections.defaultdict(int)
         self.peer_reqq = 20
+        try:
+            localaddr = sock.getsockname()
+            peeraddr = sock.getpeername()
+        except socket.error as exc:
+            self.logger.warning('Error getting socket endpoints: %s', exc)
+        else:
+            self.logger.info('New connection: %s-%s', localaddr, peeraddr)
 
     def close(self):
-        try:
-            self.socket.shutdown(socket.SHUT_RDWR)
-        except socket.error as exc:
-            if exc.errno != errno.ENOTCONN:
-                raise
+        self.socket.close()
         self.out_queue.put(None)
         self.on_close()
 
+    def __del__(self):
+        self.logger.debug('Deleting %r', self)
+        handlers = self.logger.handlers
+        assert len(handlers) == 1, handlers
+        for h in handlers:
+            self.logger.removeHandler(h)
+
     def __enter__(self):
         return self
 
         self.out_queue.put(data)
 
     def send_routine(self):
-        try:
+        with self:
             while True:
-                data = b''
-                while True:
-                    try:
-                        data1 = self.out_queue.get(block=(not data))
-                    except queue.Empty:
-                        del data1
-                        break
-                    else:
-                        if data1 is None:
-                            return
-                        data += data1
+                data = self.out_queue.get()
+                if data is None:
+                    return
                 try:
                     self.socket.sendall(data)
                 except socket.error as exc:
                     if exc.errno in {errno.EPIPE, errno.ECONNRESET}:
                         return
-                    raise
-                #~ self.logger.debug('Sent %d bytes', len(data))
-        except Exception as exc:
-            self.send_routine_exc = exc
-        finally:
-            self.close()
+                self.logger.debug('Sent %r bytes', len(data))
 
     def recv(self, count):
         data = b''
         while len(data) < count:
             try:
-                data1 = self.socket.recv(count - len(data), socket.MSG_DONTWAIT)
+                data1 = self.socket.recv(count - len(data))
             except socket.error as exc:
                 # TODO ECONNRESET, EAGAIN and EWOULDBLOCK might want special treatment here
                 raise ConnectionClosed(exc)
             if not data1:
                 raise ConnectionClosed()
             data += data1
-        #~ self.logger.debug('Received %d bytes', len(data))
+        self.logger.debug('Received %d bytes', len(data))
         return data
 
     def recv_routine(self):
+        import functools
         self.recv_protocol()
         self.recv_extensions()
         self.recv_info_hash()
         while True:
             length, = struct.unpack('>I', self.recv(4))
             if length == 0:
+                self.logger.info('Received keep-alive')
                 continue
+            self.logger.debug('Next message length: %s', length)
             type = self.recv(1)[0]
-            try:
-                handler = {
-                    self.CHOKE: self.recv_choke,
-                    self.UNCHOKE: self.recv_unchoke,
-                    self.INTERESTED: self.recv_interested,
-                    self.NOT_INTERESTED: self.recv_not_interested,
-                    self.HAVE: self.recv_have,
-                    self.BITFIELD: self.recv_bitfield,
-                    self.REQUEST: self.recv_request,
-                    self.PIECE: self.recv_piece,
-                    self.CANCEL: self.recv_cancel,
-                    self.EXTENDED: self.recv_extended,
-                }[type]
-            except KeyError as exc:
-                special_logger.debug(
-                    '%s: Got unknown message: type=%s, data=%r',
-                    self,
-                    type,
-                    self.recv(length - 1))
-                #~ self.logger.error('%s: Error looking up handler for message: %s', self, exc)
-                #~ return
-            else:
-                handler(length - 1)
+            handler = {
+                self.CHOKE: self.recv_choke,
+                self.UNCHOKE: self.recv_unchoke,
+                self.INTERESTED: self.recv_interested,
+                self.NOT_INTERESTED: self.recv_not_interested,
+                self.HAVE: self.recv_have,
+                self.BITFIELD: self.recv_bitfield,
+                self.REQUEST: self.recv_request,
+                self.PIECE: self.recv_piece,
+                self.CANCEL: self.recv_cancel,
+                self.EXTENDED: self.recv_extended,
+            }.get(type, functools.partial(self.recv_unknown, type))
+            handler(length - 1)
 
     def run(self):
         try:
             self.recv_routine()
         except ConnectionClosed as exc:
             if exc.args:
-                logging.error('%s: Connection terminated: %s', self, exc)
+                self.logger.error('%s: Connection terminated: %s', self, exc)
         finally:
             self.close()
             if self.send_routine_exc:
                 raise self.send_routine_exc
 
+
     def recv_protocol(self):
         self.got_protocol(self.recv(20))
 
     def got_protocol(self, protocol):
         if protocol != self.protocol:
-            self.logger.warning('%s: Peer using unknown protocol: %r', self, protocol)
-            self.close()
+            raise ConnectionClosed('Unknown protocol %s' % protocol)
 
     def recv_extensions(self):
         self.got_extensions(self.recv(8))
 
     def got_extensions(self, extensions):
-        special_logger.debug('%s: Extensions: %r', self, extensions)
+        self.logger.debug('Received extensions: %s', extensions)
         self.peer_extensions = extensions
 
+    # INFO HASH
+
+    def send_info_hash(self, info_hash):
+        assert len(info_hash) == 20, len(info_hash)
+        self.send(info_hash)
+
     def recv_info_hash(self):
         self.got_info_hash(self.recv(20))
 
     def got_info_hash(self, info_hash):
         self.peer_info_hash = info_hash
 
+    # PEER ID
+
+    def send_peer_id(self, peer_id):
+        assert len(peer_id) == 20, peer_id
+        self.send(peer_id)
+
     def recv_peer_id(self):
         self.got_peer_id(self.recv(20))
 
     def got_peer_id(self, peer_id):
         self.peer_peer_id = peer_id
 
+    # keep alive
+
+    def send_keep_alive(self):
+        self.send(struct.pack('>I', 0))
+
+    def recv_unknown(self, type, length):
+        self.logger.warning('Received message with unknown type %r, length %r', type, length)
+        max_length = 0x100
+        data = self.recv(min(0x100, length))
+        self.logger.debug('Unknown message begins: %s', data)
+        if length > max_length:
+            raise ConnectionClosed('Unknown message is too long')
+        else:
+            self.got_unknown(type, data)
+
+    def got_unknown(self, type, data):
+        pass
+
+    # choke
+
+    def send_choke(self):
+        if not self.peer_choked:
+            self.logger.debug('%s: Choking peer', self)
+            self.peer_choked = True
+            self.send_message(self.CHOKE)
+        self.peer_requests.clear()
+
     def recv_choke(self, length):
         if length != 0:
             raise ConnectionClosed('Received choke with length=%d' % length)
         self.am_choked = True
         self.our_requests.clear()
 
+    # unchoke
+
+    def send_unchoke(self):
+        if self.peer_choked:
+            self.logger.debug('%s: Unchoking peer', self)
+            self.peer_choked = False
+            self.send_message(self.UNCHOKE)
+
     def recv_unchoke(self, length):
-        assert length == 0, length
+        if length != 0:
+            raise ConnectionClosed('Received unchoked with length %s' % length)
         self.got_unchoke()
 
     def got_unchoke(self):
         self.logger.debug('%s: Peer unchoked us', self)
         self.am_choked = False
 
+    # interested
+
+    def send_interested(self):
+        if not self.am_interested:
+            self.logger.debug('Sending interest')
+            self.am_interested = True
+            self.send_message(self.INTERESTED)
+
     def recv_interested(self, length):
         if length != 0:
             raise ConnectionClosed('Received interested with length=%d' % length)
         self.got_interested()
 
     def got_interested(self):
+        self.logger.debug('Peer interested')
         self.peer_interested = True
 
+    # not interested
+
+    def send_not_interested(self):
+        if self.am_interested:
+            self.logger.debug('Sending not interested')
+            self.am_interested = False
+            self.send_message(self.NOT_INTERESTED)
+
     def recv_not_interested(self, length):
         assert length == 0, length
         self.got_not_interested()
 
     def got_not_interested(self):
+        self.logger.debug('Peer not interested')
         self.peer_interested = False
 
+    # have
+
+    def send_have(self, index):
+        self.logger.debug('%s: Sending HAVE(%d)', self, index)
+        self.send_message(self.HAVE, struct.pack('>I', index))
+
     def recv_have(self, length):
-        index = struct.unpack('>I', self.recv(length))
+        index = struct.unpack('>I', self.recv(length))[0]
         self.got_have(index)
 
     def got_have(self, index):
+        self.logger.debug('Peer has piece %r', index)
         self.peer_bitfield.add(index)
 
+    # bitfield
+
+    def send_bitfield(self, bytes):
+        if any(bytes):
+            self.logger.debug('%s: Sending bitfield %s', self, bytes)
+            self.send_message(self.BITFIELD, bytes)
+
     def recv_bitfield(self, length):
         import itertools
         index_iter = itertools.count()
                     bitfield.add(index)
         self.got_bitfield(bitfield)
 
+    @staticmethod
+    def pretty_bitfield(bitfield):
+        def ranges(bitfield=bitfield):
+            bitfield = iter(sorted(bitfield))
+            start = end = next(bitfield)
+            for index in bitfield:
+                if index == end + 1:
+                    end = index
+                else:
+                    yield start, end
+                    start = end = index
+            yield start, end
+        return '[{}]'.format(
+            ', '.join(
+                str(s) if s == e else '{s}..{e}'.format(**vars())
+                for s, e in ranges()))
+
     def got_bitfield(self, bitfield):
-        self.logger.debug('%s: Got bitfield: %s', self, bitfield)
+        self.logger.debug('%s: Got bitfield: %s', self, self.pretty_bitfield(bitfield))
         self.peer_bitfield = bitfield
 
     def recv_request(self, length):
         else:
             self.logger.info('%s: Received block %s', self, request)
 
+    # cancel
+
     def recv_cancel(self, length):
         request = struct.unpack('>III', self.recv(length))
         self.got_cancel(request)
         except KeyError:
             self.logger.warning('%s: Peer canceled unknown request %s', self, request)
 
+    # extended
+
     def recv_extended(self, length):
+        import io
         type = self.recv(1)[0]
-        self.got_extended(type, buncode(self.socket.makefile('rb')).__next__())
+        self.logger.debug('Receiving extended message type: %s', type)
+        message = self.recv(length - 1)
+        self.logger.debug('Received extended message payload: %s', message)
+        self.got_extended(type, buncode(io.BytesIO(message)).__next__())
 
     def got_extended(self, type, msg):
         import pprint
-        special_logger.debug('%s: Got extended message (type=%s):\n%s', self, type, pprint.pformat(msg))
+        self.logger.debug('%s: Got extended message (type=%s):\n%s', self, type, pprint.pformat(msg))
         if type == 0:
             self.peer_extended_protocols.update(msg['m'])
             if 'reqq' in msg:
                 self.peer_reqq = msg['reqq']
 
-    # SEND ROUTINES
-
-    def send_info_hash(self, info_hash):
-        assert len(info_hash) == 20, len(info_hash)
-        self.send(info_hash)
-
-    def send_peer_id(self, peer_id):
-        assert len(peer_id) == 20, peer_id
-        self.send(peer_id)
-
-    def send_keep_alive(self):
-        self.send(struct.pack('>I', 0))
-
     def send_message(self, type, payload=b''):
         self.send(struct.pack('>IB', len(payload) + 1, type) + payload)
 
-    def send_have(self, index):
-        self.logger.debug('%s: Sending HAVE(%d)', self, index)
-        self.send_message(self.HAVE, struct.pack('>I', index))
-
-    def send_bitfield(self, bytes):
-        if any(bytes):
-            self.logger.debug('%s: Sending bitfield %s', self, bytes)
-            self.send_message(self.BITFIELD, bytes)
-
-    def send_choke(self):
-        if not self.peer_choked:
-            self.logger.debug('%s: Choking peer', self)
-            self.peer_choked = True
-            self.send_message(self.CHOKE)
-        self.peer_requests.clear()
-
-    def send_unchoke(self):
-        if self.peer_choked:
-            self.logger.debug('%s: Unchoking peer', self)
-            self.peer_choked = False
-            self.send_message(self.UNCHOKE)
-
-    def send_not_interested(self):
-        if self.am_interested:
-            self.logger.debug('%s: Sending disinterest', self)
-            self.am_interested = False
-            self.send_message(self.NOT_INTERESTED)
-
-    def send_interested(self):
-        if not self.am_interested:
-            self.logger.debug('%s: Sending interest', self)
-            self.am_interested = True
-            self.send_message(self.INTERESTED)
-
     def send_request(self, request):
         self.logger.info('%s: Sending request: %s', self, request)
         assert not self.am_choked
 
     def send_extended(self, type, msg):
         assert self.peer_extensions[5] & 0x10, self.peer_extensions
+        assert self.our_extensions[5] & 0x10, self.our_extensions
         assert 'm' in msg, msg
         data = bytes([type]) + bencode(msg)
         self.logger.debug('%s: Sending extended message: type=%d, data=%r', self, type, data)