Commits

Matt Joiner committed a2d127f

Yeah whatever.

Comments (0)

Files changed (4)

lib/peer.py

-import errno
-import logging
-import pdb
-import random
-import struct
-
-from . import bencode
-from .util import pretty_bitfield, extract_packed_peer_addrs
-
-
-class Connection:
-
-    KEEP_ALIVE = None
-    CHOKE = 0
-    UNCHOKE = 1
-    INTERESTED = 2
-    NOT_INTERESTED = 3
-    HAVE = 4
-    BITFIELD = 5
-    REQUEST = 6
-    PIECE = 7
-    CANCEL = 8
-    EXTENDED = 20
-
-    protocol = b'\x13BitTorrent protocol'
-    our_extensions = b'\0\0\0\0\0\x10\0\0'
-
-    logger = logging.getLogger('peer')
-
-    class Closed(Exception): pass
-    class ProtocolError(Exception): pass
-
-    def __init__(self, sock, torrent, upload_limiter):
-        self.socket = sock
-        self.torrent = torrent
-        self.upload_limiter = upload_limiter
-        self.peer_bitfield = set()
-        self.last_recv_time = None
-        self.last_send_time = None
-        self.peer_requests = set()
-        self.our_requests = set()
-        self.peer_choked = True
-        self.am_choked = True
-        self.peer_interested = False
-        self.am_interested = False
-        self.bytes_sent = 0
-        self.bytes_received = 0
-        import queue
-        self.out_queue = queue.Queue()
-        self.send(self.protocol + self.our_extensions)
-        self.in_buffer = b''
-        self.send_routine_exc = None
-        import collections
-        self.peer_extended_protocols = collections.defaultdict(int)
-        self.peer_reqq = 20 # how many requests the peer lets us queue
-        try:
-            localaddr = sock.getsockname()
-            peeraddr = sock.getpeername()
-        except socket.error as exc:
-            self.logger.warning('Error getting socket endpoints: %s', exc)
-        else:
-            self.logger.info('New connection: %s-%s', localaddr, peeraddr)
-
-    def close(self):
-        self.socket.close()
-        self.out_queue.put(None)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args):
-        self.close()
-
-    def send(self, data):
-        self.out_queue.put(data)
-
-    def send_routine(self):
-        with self:
-            import socket, errno
-            while True:
-                data = self.out_queue.get()
-                if data is None:
-                    return
-                self.upload_limiter.reserve(len(data))
-                try:
-                    self.socket.sendall(data)
-                except socket.error as exc:
-                    if exc.errno in {errno.EPIPE, errno.ECONNRESET}:
-                        return
-                    elif exc.errno == errno.EBADF:
-                        assert self.closed
-                        return
-                    raise
-                self.logger.debug('Sent %r bytes', len(data))
-
-    def _recv(self, count):
-        import socket
-        data = b''
-        while len(data) < count:
-            try:
-                data1 = self.socket.recv(count - len(data), socket.MSG_WAITALL)
-            except socket.error as exc:
-                # TODO ECONNRESET, EAGAIN and EWOULDBLOCK might want special treatment here
-                raise self.Closed from exc
-            if not data1:
-                raise self.Closed
-            data += data1
-        self.logger.debug('Received %d bytes', len(data))
-        return data
-
-    def _recv_routine(self):
-        import functools
-        self._recv_protocol()
-        self._recv_extensions()
-        self._recv_info_hash()
-        self._recv_peer_id()
-        while True:
-            length, = struct.unpack('>I', self.recv(4))
-            if length == 0:
-                self.logger.info('Received keep-alive')
-                continue
-            self.logger.debug('Next message length: %s', length)
-            type = self.recv(1)[0]
-            handler = {
-                self.CHOKE: self._recv_choke,
-                self.UNCHOKE: self._recv_unchoke,
-                self.INTERESTED: self._recv_interested,
-                self.NOT_INTERESTED: self._recv_not_interested,
-                self.HAVE: self._recv_have,
-                self.BITFIELD: self._recv_bitfield,
-                self.REQUEST: self._recv_request,
-                self.PIECE: self._recv_piece,
-                self.CANCEL: self._recv_cancel,
-                self.EXTENDED: self._recv_extended,
-            }.get(type, functools.partial(self._recv_unknown, type))
-            handler(length - 1)
-
-    def _recv_protocol(self):
-        self.got_protocol(self.recv(20))
-
-    def got_protocol(self, protocol):
-        if protocol != self.protocol:
-            raise self.ProtocolError('Unknown protocol %s' % protocol)
-
-    # EXTENSIONS
-
-    # send extensions done during __init__
-
-    def _recv_extensions(self):
-        self.got_extensions(self.recv(8))
-
-    def got_extensions(self, extensions):
-        self.logger.debug('Received extensions: %s', extensions)
-        self.peer_extensions = extensions
-
-    # INFO HASH
-
-    def send_info_hash(self, info_hash):
-        assert len(info_hash) == 20, len(info_hash)
-        self.send(info_hash)
-
-    def _recv_info_hash(self):
-        self.got_info_hash(self.recv(20))
-
-    def got_info_hash(self, info_hash):
-        self.peer_info_hash = info_hash
-        self.torrent.peer_sent_info_hash(self)
-
-    # PEER ID
-
-    def send_peer_id(self, peer_id):
-        assert len(peer_id) == 20, peer_id
-        self.send(peer_id)
-
-    def _recv_peer_id(self):
-        self.got_peer_id(self.recv(20))
-
-    def got_peer_id(self, peer_id):
-        self.peer_peer_id = peer_id
-        self.torrent.peer_sent_peer_id(self)
-
-    # keep alive
-
-    def send_keep_alive(self):
-        self.send(struct.pack('>I', 0))
-
-    def _recv_unknown(self, type, length):
-        self.logger.warning('Received message with unknown type %r, length %r', type, length)
-        max_length = 0x100
-        data = self.recv(min(0x100, length))
-        self.logger.debug('Unknown message begins: %s', data)
-        if length > max_length:
-            raise self.ProtocolError('Unknown message is too long')
-        else:
-            self.got_unknown(type, data)
-
-    def got_unknown(self, type, data):
-        pass
-
-    # choke
-
-    def send_choke(self):
-        if not self.peer_choked:
-            self.logger.debug('%s: Choking peer', self)
-            self.peer_choked = True
-            self.send_message(self.CHOKE)
-        self.peer_requests.clear()
-
-    def _recv_choke(self, length):
-        if length != 0:
-            raise self.ProtocolError('Received choke with length=%d' % length)
-        self.got_choke()
-
-    def got_choke(self):
-        self.logger.debug('%s: Peer choked us', self)
-        self.am_choked = True
-        self.our_requests.clear()
-
-    # unchoke
-
-    def send_unchoke(self):
-        if self.peer_choked:
-            self.logger.debug('%s: Unchoking peer', self)
-            self.peer_choked = False
-            self.send_message(self.UNCHOKE)
-
-    def _recv_unchoke(self, length):
-        if length != 0:
-            raise self.ProtocolError('Received unchoked with length %s' % length)
-        self.got_unchoke()
-
-    def got_unchoke(self):
-        self.logger.debug('%s: Peer unchoked us', self)
-        self.am_choked = False
-
-    # interested
-
-    def send_interested(self):
-        if not self.am_interested:
-            self.logger.debug('Sending interest')
-            self.am_interested = True
-            self.send_message(self.INTERESTED)
-
-    def _recv_interested(self, length):
-        if length != 0:
-            raise self.ProtocolError('Received interested with length=%d' % length)
-        self.got_interested()
-
-    def got_interested(self):
-        self.logger.debug('Peer interested')
-        self.peer_interested = True
-
-    # not interested
-
-    def send_not_interested(self):
-        if self.am_interested:
-            self.logger.debug('Sending not interested')
-            self.am_interested = False
-            self.send_message(self.NOT_INTERESTED)
-
-    def _recv_not_interested(self, length):
-        assert length == 0, length
-        self.got_not_interested()
-
-    def got_not_interested(self):
-        self.logger.debug('Peer not interested')
-        self.peer_interested = False
-
-    # have
-
-    def send_have(self, index):
-        self.logger.debug('%s: Sending HAVE(%d)', self, index)
-        self.send_message(self.HAVE, struct.pack('>I', index))
-
-    def _recv_have(self, length):
-        index = struct.unpack('>I', self.recv(length))[0]
-        self.got_have(index)
-
-    def got_have(self, index):
-        self.logger.debug('Peer has piece %r', index)
-        self.peer_bitfield.add(index)
-        self.torrent.top_up_requests(self)
-
-    # bitfield
-
-    def send_bitfield(self, bytes):
-        if any(bytes):
-            self.logger.debug('%s: Sending bitfield %s', self, bytes)
-            self.send_message(self.BITFIELD, bytes)
-
-    def _recv_bitfield(self, length):
-        if length != (self.torrent.piece_count + 7) // 8:
-            raise self.ProtocolError('Bitfield packet has wrong length: {}'.format(length))
-        import itertools
-        index_iter = itertools.count()
-        bitfield = set()
-        for byte in self.recv(length):
-            for bit_shift in range(7, -1, -1):
-                index = next(index_iter)
-                if byte >> bit_shift & 1:
-                    if index not in range(self.torrent.piece_count):
-                        raise self.ProtocolError('Bitfield packet incorrectly sets padding bits')
-                    bitfield.add(index)
-        self.got_bitfield(bitfield)
-
-    def got_bitfield(self, bitfield):
-        if self.logger.isEnabledFor(logging.DEBUG):
-            self.logger.debug('%s: Got bitfield: %s', self, pretty_bitfield(bitfield))
-        assert bitfield <= set(range(self.torrent.piece_count))
-        self.peer_bitfield = bitfield
-        self.torrent.top_up_requests(self)
-
-    def _recv_request(self, length):
-        request = struct.unpack('>III', self.recv(length))
-        self.peer_requests.add(request)
-        self.got_request(request)
-
-    def got_request(self, request):
-        self.logger.debug('Got request %s', request)
-        self.peer_requests.add(request)
-
-    def _recv_piece(self, length):
-        index, offset = struct.unpack_from('>II', self.recv(8))
-        self.got_piece(index, offset, self.recv(length - 8))
-
-    def got_piece(self, index, offset, data):
-        request = index, offset, len(data)
-        try:
-            self.our_requests.remove(request)
-        except KeyError:
-            self.logger.warning('%s: Received unrequested block %s', self, request)
-        else:
-            self.logger.info('%s: Received block %s', self, request)
-
-    # cancel
-
-    def _recv_cancel(self, length):
-        request = struct.unpack('>III', self.recv(length))
-        self.got_cancel(request)
-
-    def got_cancel(self, request):
-        try:
-            self.peer_requests.remove(request)
-        except KeyError:
-            self.logger.warning('%s: Peer canceled unknown request %s', self, request)
-
-    # extended
-
-    def _recv_extended(self, length):
-        import io
-        type = self.recv(1)[0]
-        self.logger.debug('Receiving extended message type: %s', type)
-        message = self.recv(length - 1)
-        self.logger.debug('Received extended message payload: %s', message)
-        self.got_extended(type, next(bencode.decode(io.BytesIO(message))))
-
-    def got_extended(self, type, msg):
-        if self.logger.isEnabledFor(logging.DEBUG):
-            import pprint
-            self.logger.debug(
-                '%s: Got extended message (type=%s):\n%s',
-                self, type, pprint.pformat(msg))
-        if type == 0:
-            self.peer_extended_protocols.update(msg['m'])
-            if 'reqq' in msg:
-                self.peer_reqq = msg['reqq']
-        elif type == 1:
-            addrs = list(extract_packed_peer_addrs(msg.get('added', b'')))
-            self.logger.info('%s: Sent us %d new peer addrs', self, len(addrs))
-            self.torrent.add_swarm_peers(addrs)
-        else:
-            raise ProtocolError('Extended message unknown type: {}'.format(type))
-
-    def send_message(self, type, payload=b''):
-        self.send(struct.pack('>IB', len(payload) + 1, type) + payload)
-
-    def send_request(self, request):
-        self.logger.info('%s: Sending request: %s', self, request)
-        assert not self.am_choked
-        assert self.am_interested
-        assert request not in self.our_requests
-        self.our_requests.add(request)
-        self.send_message(self.REQUEST, struct.pack('>III', *request))
-
-    def send_piece(self, index, begin, data):
-        request = index, begin, len(data)
-        self.logger.debug('%s: Sending piece %s', self, request)
-        self.peer_requests.remove(request)
-        self.send_message(self.PIECE, struct.pack('>II', index, begin) + data)
-        self.bytes_sent += len(data)
-
-    def send_cancel(self, request):
-        try:
-            self.our_requests.remove(request)
-        except KeyError:
-            pass
-        else:
-            self.logger.info('%s: Sending cancel: %s', self, request)
-            self.send_message(self.CANCEL, struct.pack('>III', *request))
-
-    def send_extended(self, type, msg):
-        assert self.extended_protocol_enabled
-        assert 'm' in msg, msg
-        data = bytes([type]) + bencode.encode(msg)
-        import pprint
-        self.logger.debug(
-            'Sending extended message (type=%s):\n%s',
-            type,
-            pprint.pformat(msg))
-        self.send_message(self.EXTENDED, data)
-
-    @property
-    def extended_protocol_enabled(self):
-        return self.our_extensions[5] & 0x10 and self.peer_extensions[5] & 0x10
-
+import errno
+import logging
+import pdb
+import random
+import struct
+
+from . import bencode
+from .util import pretty_bitfield, extract_packed_peer_addrs
+
+KEEP_ALIVE = None
+CHOKE = 0
+UNCHOKE = 1
+INTERESTED = 2
+NOT_INTERESTED = 3
+HAVE = 4
+BITFIELD = 5
+REQUEST = 6
+PIECE = 7
+CANCEL = 8
+EXTENDED = 20
+
+
+class Connection:
+
+    our_protocol = b'\x13BitTorrent protocol'
+    our_extensions = b'\0\0\0\0\0\x10\0\0'
+
+    logger = logging.getLogger('peer')
+
+    class Closed(Exception): pass
+    class ProtocolError(Exception): pass
+
+    def __init__(self, sock):
+        import queue, collections
+        self.socket = sock
+        self.peer_bitfield = set()
+        self.last_recv_time = None
+        self.last_send_time = None
+        self.peer_requests = set()
+        self.our_requests = set()
+        self.peer_choked = True
+        self.am_choked = True
+        self.peer_interested = False
+        self.am_interested = False
+        self.bytes_sent = 0
+        self.bytes_received = 0
+        self.out_queue = queue.Queue()
+        self.peer_extended_protocols = collections.defaultdict(int)
+        self.peer_reqq = 20 # how many requests the peer lets us queue
+        self.peer_protocol = None
+        self.peer_info_hash = None
+        self.peer_peer_id = None
+        self.closed = False
+
+    def send_protocol(self):
+        self.send(self.our_protocol)
+
+    def send_extensions(self):
+        self.send(self.our_extensions)
+
+    def close(self):
+        self.closed = True
+        self.socket.close()
+        self.out_queue.put(None)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, tp, exc, tb):
+        self.close()
+        if isinstance(exc, self.Closed):
+            self.logger.debug('%s: %s', self, exc)
+            return True
+
+    def send(self, data):
+        if self.closed:
+            raise self.Closed
+        self.out_queue.put(data)
+
+    def _recv(self, count):
+        import socket, errno
+        data = b''
+        while len(data) < count:
+            try:
+                data1 = self.socket.recv(count - len(data), socket.MSG_WAITALL)
+            except socket.error as exc:
+                if exc.errno == errno.EBADF and self.closed:
+                    raise self.Closed
+                elif exc.errno == errno.ECONNRESET:
+                    raise self.Closed(exc) from exc
+                raise
+            if not data1:
+                raise self.Closed
+            data += data1
+        self.logger.debug('Received %d bytes', len(data))
+        return data
+
+    def recv_protocol(self):
+        assert self.peer_protocol is None
+        self.peer_protocol = self._recv(20)
+        return self.peer_protocol
+
+    def recv_extensions(self):
+        self.peer_extensions = self._recv(8)
+
+    def send_info_hash(self, info_hash):
+        assert len(info_hash) == 20, len(info_hash)
+        self.send(info_hash)
+
+    # PEER ID
+
+    def send_peer_id(self, peer_id):
+        assert len(peer_id) == 20, peer_id
+        self.send(peer_id)
+
+    def send_keep_alive(self):
+        self.send(struct.pack('>I', 0))
+
+    # choke
+
+    def send_choke(self):
+        if not self.peer_choked:
+            self.logger.debug('%s: Choking peer', self)
+            self.peer_choked = True
+            self.send_message(CHOKE)
+        self.peer_requests.clear()
+
+    def _recv_choke(self, length):
+        if length != 0:
+            raise self.ProtocolError('Received choke with length=%d' % length)
+        self.got_choke()
+
+    # unchoke
+
+    def send_unchoke(self):
+        if self.peer_choked:
+            self.logger.debug('%s: Unchoking peer', self)
+            self.peer_choked = False
+            self.send_message(UNCHOKE)
+
+    def _recv_unchoke(self, length):
+        if length != 0:
+            raise self.ProtocolError('Received unchoked with length %s' % length)
+        self.got_unchoke()
+
+    def got_unchoke(self):
+        self.logger.debug('%s: Peer unchoked us', self)
+        self.am_choked = False
+
+    # interested
+
+    def send_interested(self):
+        if not self.am_interested:
+            self.logger.debug('Sending interest')
+            self.am_interested = True
+            self.send_message(INTERESTED)
+
+    def _recv_interested(self, length):
+        if length != 0:
+            raise self.ProtocolError('Received interested with length=%d' % length)
+        self.got_interested()
+
+    def got_interested(self):
+        self.logger.debug('Peer interested')
+        self.peer_interested = True
+
+    # not interested
+
+    def send_not_interested(self):
+        if self.am_interested:
+            self.logger.debug('Sending not interested')
+            self.am_interested = False
+            self.send_message(NOT_INTERESTED)
+
+    def _recv_not_interested(self, length):
+        assert length == 0, length
+        self.got_not_interested()
+
+    def got_not_interested(self):
+        self.logger.debug('Peer not interested')
+        self.peer_interested = False
+
+    # have
+
+    def send_have(self, index):
+        self.logger.debug('%s: Sending HAVE(%d)', self, index)
+        self.send_message(HAVE, struct.pack('>I', index))
+
+    def _recv_have(self, length):
+        index = struct.unpack('>I', self.recv(length))[0]
+        self.got_have(index)
+
+    def got_have(self, index):
+        self.logger.debug('Peer has piece %r', index)
+        self.peer_bitfield.add(index)
+        self.torrent.top_up_requests(self)
+
+    # bitfield
+
+    def send_bitfield(self, bytes):
+        if any(bytes):
+            self.logger.debug('%s: Sending bitfield %s', self, bytes)
+            self.send_message(BITFIELD, bytes)
+
+    def recv_bitfield(self, length):
+        def decode(buf):
+            for byte in buf:
+                assert not byte & ~0xff, byte
+                for bit_shift in range(7, -1, -1):
+                    yield {0: False, 1: True}[byte >> bit_shift & 1]
+        return set(index for index, have in enumerate(decode(self._recv(length))))
+
+    def got_bitfield(self, bitfield):
+        if self.logger.isEnabledFor(logging.DEBUG):
+            self.logger.debug('%s: Got bitfield: %s', self, pretty_bitfield(bitfield))
+        assert bitfield <= set(range(self.torrent.piece_count))
+        self.peer_bitfield = bitfield
+        self.torrent.top_up_requests(self)
+
+    def _recv_request(self, length):
+        request = struct.unpack('>III', self.recv(length))
+        self.peer_requests.add(request)
+        self.got_request(request)
+
+    def got_request(self, request):
+        self.logger.debug('Got request %s', request)
+        self.peer_requests.add(request)
+
+    def _recv_piece(self, length):
+        index, offset = struct.unpack_from('>II', self.recv(8))
+        self.got_piece(index, offset, self.recv(length - 8))
+
+    def got_piece(self, index, offset, data):
+        request = index, offset, len(data)
+        try:
+            self.our_requests.remove(request)
+        except KeyError:
+            self.logger.warning('%s: Received unrequested block %s', self, request)
+        else:
+            self.logger.info('%s: Received block %s', self, request)
+
+    # extended
+
+    def _recv_extended(self, length):
+        import io
+        type = self.recv(1)[0]
+        self.logger.debug('Receiving extended message type: %s', type)
+        message = self.recv(length - 1)
+        self.logger.debug('Received extended message payload: %s', message)
+        self.got_extended(type, next(bencode.decode(io.BytesIO(message))))
+
+    def got_extended(self, type, msg):
+        if self.logger.isEnabledFor(logging.DEBUG):
+            import pprint
+            self.logger.debug(
+                '%s: Got extended message (type=%s):\n%s',
+                self, type, pprint.pformat(msg))
+        if type == 0:
+            self.peer_extended_protocols.update(msg['m'])
+            if 'reqq' in msg:
+                self.peer_reqq = msg['reqq']
+        elif type == 1:
+            addrs = list(extract_packed_peer_addrs(msg.get('added', b'')))
+            self.logger.info('%s: Sent us %d new peer addrs', self, len(addrs))
+            self.torrent.add_swarm_peers(addrs)
+        else:
+            raise self.ProtocolError('Extended message unknown type: {}'.format(type))
+
+    def send_message(self, type, payload=b''):
+        self.send(struct.pack('>IB', len(payload) + 1, type) + payload)
+
+    def send_request(self, request):
+        self.logger.info('%s: Sending request: %s', self, request)
+        assert not self.am_choked
+        assert self.am_interested
+        assert request not in self.our_requests
+        self.our_requests.add(request)
+        self.send_message(REQUEST, struct.pack('>III', *request))
+
+    def send_piece(self, index, begin, data):
+        request = index, begin, len(data)
+        self.logger.debug('%s: Sending piece %s', self, request)
+        self.peer_requests.remove(request)
+        self.send_message(PIECE, struct.pack('>II', index, begin) + data)
+        self.bytes_sent += len(data)
+
+    def send_cancel(self, request):
+        try:
+            self.our_requests.remove(request)
+        except KeyError:
+            pass
+        else:
+            self.logger.info('%s: Sending cancel: %s', self, request)
+            self.send_message(CANCEL, struct.pack('>III', *request))
+
+    def send_extended(self, type, msg):
+        assert self.extended_protocol_enabled
+        assert 'm' in msg, msg
+        data = bytes([type]) + bencode.encode(msg)
+        import pprint
+        self.logger.debug(
+            'Sending extended message (type=%s):\n%s',
+            type,
+            pprint.pformat(msg))
+        self.send_message(EXTENDED, data)
+
+    @property
+    def extended_protocol_enabled(self):
+        return self.our_extensions[5] & 0x10 and self.peer_extensions[5] & 0x10
+
+    def recv_message(self):
+        length = struct.unpack('>I', self._recv(4))[0] - 1
+        if length == -1:
+            return KEEP_ALIVE,
+        type = self._recv(1)[0]
+        if type in {CHOKE, UNCHOKE, INTERESTED, NOT_INTERESTED}:
+            if length != 0:
+                raise self.ProtocolError('Invalid message length')
+            return type,
+        elif type == HAVE:
+            return type, struct.unpack('>I', self._recv(length))
+        elif type == BITFIELD:
+            return type, self.recv_bitfield(length)
+        elif type == REQUEST:
+            return type, struct.unpack('>III', self._recv(length))
+        elif type == PIECE:
+            return type, struct.unpack('>II', self._recv(8)) + (self._recv(length - 8),)
+        elif type == CANCEL:
+            return type, struct.unpack('>III', self._recv(length))
+        elif type == EXTENDED:
+            return type, self.recv_extended(length)
+        else:
+            raise self.ProtocolError('Unknown message type: {}'.format(type))
+
+    def recv_extended(self, length):
+        import io
+        buf = self._recv(length)
+        msg_it = bencode.decode(io.BytesIO(buf[1:]))
+        try:
+            msg = next(msg_it)
+        except StopIteration:
+            raise self.ProtocolError('Invalid extended message')
+        try:
+            next(msg_it)
+        except StopIteration:
+            return buf[0], msg
+        raise self.ProtocolError('Invalid extended message')
+
+    def recv_info_hash(self):
+        assert self.peer_info_hash is None
+        self.peer_info_hash = self._recv(20)
+        return self.peer_info_hash
+
+    def recv_peer_id(self):
+        assert self.peer_peer_id is None
+        self.peer_peer_id = self._recv(20)
+        return self.peer_peer_id
         super().__init__()
         self.level = 0
         self._lock = threading.Lock()
-        self._below_zero = threading.Condition(self._lock)
-        self._not_full = threading.Condition(self._lock)
         self._above_zero = threading.Condition(self._lock)
         self.rate = rate
         self.interval = interval
 from .util import pretty_bitfield
 from .trackers import AnnounceError
-from . import peer
+from . import pwp
 
 
 class Torrent:
                 bytes[index // 8] |= 1 << 7 - index % 8
         return bytes
 
-    def peer_sent_info_hash(self, conn):
-        if conn.peer_info_hash != self.info_hash:
-            self.logger.warning(
-                "%s: Peer %r's info hash does not match: %s",
-                self, conn, conn.peer_info_hash)
-            conn.close()
-
     def routine_wrapper(routine):
         def wrapper(self, *args, **kwds):
             try:
                 self.logger.debug('Routine %s returned', routine)
         return wrapper
 
+    def spawn(self, target, *args, **kwargs):
+        import threading
+        thread = threading.Thread(target=target, args=args, kwargs=kwargs)
+        thread.name = target
+        thread.daemon = True
+        thread.start()
+        return thread
+
+    @routine_wrapper
+    def peer_conn_send_routine(self, conn):
+        with conn:
+            import socket, errno
+            while not conn.closed:
+                data = conn.out_queue.get()
+                if data is None:
+                    assert conn.closed
+                    return
+                self.upload_limiter.reserve(len(data))
+                try:
+                    conn.socket.sendall(data)
+                except socket.error as exc:
+                    if exc.errno in {errno.EPIPE, errno.ECONNRESET}:
+                        return
+                    elif exc.errno == errno.EBADF:
+                        assert conn.closed
+                        return
+                    raise
+                self.logger.debug('%s: Sent %r bytes', conn, len(data))
+
     @routine_wrapper
     def peer_routine(self, sock, addr):
+        import threading, socket
         # create and connect a socket if a socket wasn't given
         if sock is None:
             with self.half_open:
-                import socket
                 sock = socket.socket()
                 sock.settimeout(30)
                 try:
                     sock.connect(addr)
                 except socket.error as exc:
-                    self.logger.info('Connecting to %s: %s', addr, exc)
+                    self.logger.info('%s: %s', addr, exc)
                     return
                 else:
-                    self.logger.info('Connected to %s', addr)
+                    self.logger.info('%s: Connected to %r', self, addr)
                     sock.setblocking(True)
 
-        with sock:
-            with peer.Connection(
-                sock,
-                self,
-                upload_limiter=self.upload_limiter
-            ) as conn:
-                import threading
-                threading.current_thread().name = conn
-                conn.send_info_hash(self.info_hash)
-                conn.send_peer_id(self.peer_id)
-                conn.run()
-
-    def peer_sent_peer_id(self, conn):
-        # add the connection to the active peers
-        new_peer_id = conn.peer_peer_id
-        with self.active_peers_lock:
-            for conn1 in self.active_peers:
-                if conn1.peer_peer_id == new_peer_id:
-                    self.logger.warning(
-                        '%s: Already connected to peer %s', self, new_peer_id)
-                    assert conn1 is not conn
-                    conn.close()
-                    break
-            else:
-                self.active_peers.add(conn)
-                conn.send_bitfield(self.bitfield_bytes())
-                if conn.extended_protocol_enabled:
-                    import socket
-                    conn.send_extended(0, {
-                        'm': {'ut_pex': 1},
-                        'reqq': 5,
-                        'v': 'erutor-0.1',
-                        'yourip': socket.inet_aton(conn.socket.getpeername()[0]),
-                        'p': self.port})
+        with pwp.Connection(sock) as conn:
+            threading.current_thread().name = conn
+            self.spawn(self.peer_conn_send_routine, conn)
+            conn.send_protocol()
+            conn.send_extensions()
+            conn.send_info_hash(self.info_hash)
+            conn.send_peer_id(self.peer_id)
+            try:
+                conn.recv_protocol()
+                conn.recv_extensions()
+                conn.recv_info_hash()
+                if conn.peer_info_hash != self.info_hash:
+                    self.logger.debug(
+                        '%s: Peer sent wrong info hash: %s',
+                        conn, conn.peer_info_hash)
+                    return
+                conn.recv_peer_id()
+            except conn.Closed as exc:
+                self.logger.debug('%s: %s', conn, exc)
+                return
+            with self.active_peers_lock:
+                if self.closed:
+                    return
+                for conn1 in self.active_peers:
+                    if conn1.peer_peer_id == conn.peer_peer_id:
+                        self.logger.warning(
+                            '%s: Already connected to peer %s',
+                            self, conn.peer_peer_id)
+                        assert conn1 is not conn
+                        return
+            self.active_peers.add(conn)
+            conn.send_bitfield(self.bitfield_bytes())
+            if conn.extended_protocol_enabled:
+                import socket
+                conn.send_extended(0, {
+                    'm': {'ut_pex': 1},
+                    'reqq': 5,
+                    'v': 'erutor-0.1',
+                    'yourip': socket.inet_aton(conn.socket.getpeername()[0]),
+                    'p': self.port})
+            from .pwp import BITFIELD, HAVE
+            while not conn.closed:
+                type, *fields = conn.recv_message()
+                if type == BITFIELD:
+                    if self.wanted_pieces() & conn.peer_bitfield and not conn.am_interested:
+                        conn.send_interested()
+                elif type == HAVE:
+                    index, = fields
+                    if self.want_piece(index) and not conn.am_interested:
+                        conn.send_interested()
 
     def peer_disconnected(self, conn):
         self.logger.debug('%s: Removing connection: %s', self, conn)
                 self.cond.notify_all()
 
     def activate_peer(self, sock, addr):
-        import threading
-        thread = threading.Thread(target=self.peer_routine, args=[sock, addr])
-        thread.daemon = True
-        thread.start()
+        self.spawn(self.peer_routine, sock, addr)
 
     @routine_wrapper
     def connect_routine(self):
             self.logger.warning('Closing %s', self)
             self.closed = True
             self.closed_cond.notify_all()
+        with self.active_peers_lock:
+            for conn in self.active_peers:
+                conn.close()
         with self.swarm_addrs_not_empty:
             self.swarm_addrs_not_empty.notify_all()
         import socket
             self.logger.info(
                 'Missing pieces: %s',
                 pretty_bitfield(self.wanted_blocks.keys()))
-            import threading
             for tracker in self.trackers:
-                thread = threading.Thread(target=self.tracker_routine, args=[tracker])
-                thread.daemon = True
-                thread.start()
+                self.spawn(self.tracker_routine, tracker)
             for target in self.accept_routine, self.connect_routine, self.stat_routine:
-                thread = threading.Thread(target=target)
-                thread.start()
+                self.spawn(target)
             self.wait_closed()
         finally:
             self.close()