Commits

Matt Joiner  committed 1970db9

Save progress so far, might drop this branch.

  • Participants
  • Parent commits 56bb4f8

Comments (0)

Files changed (5)

File download-torrent

 
     from lib import Torrent, Metainfo, bencode
     import lib
+    from lib import throttle
 
+    upload_limiter = lib.throttle.Limiter(rate=5<<10, interval=0.1)
+    upload_limiter.start()
     metainfo = Metainfo(next(bencode.decode(open(namespace.torrent, 'rb'))))
     data = lib.open_data(metainfo, namespace.destination)
     logging.info('Info hash: %s', binascii.b2a_hex(metainfo.info_hash).decode())
         metainfo=metainfo,
         data=data,
         socket=sock,
-        trackers=trackers)
+        trackers=trackers,
+        upload_limiter=upload_limiter)
     print(str(torrent.completion() * 100) + '%')
     torrent.run()
 

File lib/metainfo.py

     @property
     def tracker_urls(self):
         yield self.raw['announce'].decode()
-        for url in self.raw.get('announce-list', ()):
-            yield url.decode()
+        for tier in self.raw.get('announce-list', ()):
+            for url in tier:
+                yield url.decode()
 import random
 import struct
 
-from gthread import queue
-import gthread.socket as socket
-import gthread
+from . import bencode
+from .util import pretty_bitfield, extract_packed_peer_addrs
 
-from bencoding import buncode, bencode
-from util import pretty_bitfield
-
-
-class ConnectionClosed(Exception): pass
 
 class Connection:
 
+    KEEP_ALIVE = None
     CHOKE = 0
     UNCHOKE = 1
     INTERESTED = 2
     CANCEL = 8
     EXTENDED = 20
 
-    __slots__ = (
-        'logger', 'socket', 'peer_bitfield',
-        'last_recv_time', 'last_send_time',
-        'peer_requests', 'our_requests',
-        'peer_choked', 'am_choked',
-        'peer_interested', 'am_interested',
-        'bytes_sent', 'bytes_received',
-        'out_queue', 'in_buffer', 'send_routine_exc',
-        'peer_extensions', 'peer_info_hash', 'peer_peer_id',
-        'peer_extended_protocols', 'peer_reqq')
-
     protocol = b'\x13BitTorrent protocol'
     our_extensions = b'\0\0\0\0\0\x10\0\0'
 
-    def __init__(self, sock):
-        self.logger = logging.getLogger('peer.{:x}'.format(id(self)))
-        self.logger.addHandler(logging.FileHandler('log/peer.{:x}'.format(id(self))))
-        self.logger.setLevel(1)
+    logger = logging.getLogger('peer')
+
+    class Closed(Exception): pass
+    class ProtocolError(Exception): pass
+
+    def __init__(self, sock, torrent, upload_limiter):
         self.socket = sock
+        self.torrent = torrent
+        self.upload_limiter = upload_limiter
         self.peer_bitfield = set()
         self.last_recv_time = None
         self.last_send_time = None
         self.am_interested = False
         self.bytes_sent = 0
         self.bytes_received = 0
+        import queue
         self.out_queue = queue.Queue()
         self.send(self.protocol + self.our_extensions)
         self.in_buffer = b''
         self.send_routine_exc = None
         import collections
         self.peer_extended_protocols = collections.defaultdict(int)
-        self.peer_reqq = 20
+        self.peer_reqq = 20 # how many requests the peer lets us queue
         try:
             localaddr = sock.getsockname()
             peeraddr = sock.getpeername()
    def close(self):
        """Shut down the connection: close the socket and queue the None
        sentinel so send_routine can exit its loop."""
        # NOTE(review): send_routine's EBADF branch asserts self.closed,
        # but nothing here sets it -- confirm it is assigned elsewhere.
        self.socket.close()
        self.out_queue.put(None)
-        self.on_close()
-
-    def __del__(self):
-        self.logger.debug('Deleting %r', self)
-        handlers = self.logger.handlers
-        assert len(handlers) == 1, handlers
-        for h in handlers:
-            self.logger.removeHandler(h)
 
    def __enter__(self):
        """Context-manager entry; the connection itself is the resource."""
        return self
 
     def send_routine(self):
         with self:
+            import socket, errno
             while True:
                 data = self.out_queue.get()
                 if data is None:
                     return
+                self.upload_limiter.reserve(len(data))
                 try:
                     self.socket.sendall(data)
                 except socket.error as exc:
                     if exc.errno in {errno.EPIPE, errno.ECONNRESET}:
                         return
+                    elif exc.errno == errno.EBADF:
+                        assert self.closed
+                        return
+                    raise
                 self.logger.debug('Sent %r bytes', len(data))
 
-    def recv(self, count):
+    def _recv(self, count):
+        import socket
         data = b''
         while len(data) < count:
             try:
-                data1 = self.socket.recv(count - len(data))
+                data1 = self.socket.recv(count - len(data), socket.MSG_WAITALL)
             except socket.error as exc:
                 # TODO ECONNRESET, EAGAIN and EWOULDBLOCK might want special treatment here
-                raise ConnectionClosed(exc)
+                raise self.Closed from exc
             if not data1:
-                raise ConnectionClosed()
+                raise self.Closed
             data += data1
         self.logger.debug('Received %d bytes', len(data))
         return data
 
-    def recv_routine(self):
+    def _recv_routine(self):
         import functools
-        self.recv_protocol()
-        self.recv_extensions()
-        self.recv_info_hash()
-        self.recv_peer_id()
+        self._recv_protocol()
+        self._recv_extensions()
+        self._recv_info_hash()
+        self._recv_peer_id()
         while True:
             length, = struct.unpack('>I', self.recv(4))
             if length == 0:
             self.logger.debug('Next message length: %s', length)
             type = self.recv(1)[0]
             handler = {
-                self.CHOKE: self.recv_choke,
-                self.UNCHOKE: self.recv_unchoke,
-                self.INTERESTED: self.recv_interested,
-                self.NOT_INTERESTED: self.recv_not_interested,
-                self.HAVE: self.recv_have,
-                self.BITFIELD: self.recv_bitfield,
-                self.REQUEST: self.recv_request,
-                self.PIECE: self.recv_piece,
-                self.CANCEL: self.recv_cancel,
-                self.EXTENDED: self.recv_extended,
-            }.get(type, functools.partial(self.recv_unknown, type))
+                self.CHOKE: self._recv_choke,
+                self.UNCHOKE: self._recv_unchoke,
+                self.INTERESTED: self._recv_interested,
+                self.NOT_INTERESTED: self._recv_not_interested,
+                self.HAVE: self._recv_have,
+                self.BITFIELD: self._recv_bitfield,
+                self.REQUEST: self._recv_request,
+                self.PIECE: self._recv_piece,
+                self.CANCEL: self._recv_cancel,
+                self.EXTENDED: self._recv_extended,
+            }.get(type, functools.partial(self._recv_unknown, type))
             handler(length - 1)
 
-    def run(self):
-        try:
-            gthread.spawn(self.send_routine)
-            self.recv_routine()
-        except ConnectionClosed as exc:
-            if exc.args:
-                self.logger.error('%s: Connection terminated: %s', self, exc)
-        finally:
-            self.close()
-            if self.send_routine_exc:
-                raise self.send_routine_exc
-
-
-    def recv_protocol(self):
    def _recv_protocol(self):
        # Handshake step 1: the 20-byte protocol header.
        # NOTE(review): uses self.recv, not self._recv -- presumably a
        # recv wrapper still exists after the rename; confirm.
        self.got_protocol(self.recv(20))
 
     def got_protocol(self, protocol):
         if protocol != self.protocol:
-            raise ConnectionClosed('Unknown protocol %s' % protocol)
+            raise self.ProtocolError('Unknown protocol %s' % protocol)
 
-    def recv_extensions(self):
+    # EXTENSIONS
+
+    # send extensions done during __init__
+
    def _recv_extensions(self):
        # Handshake step 2: the 8 reserved/extension-bit bytes.
        self.got_extensions(self.recv(8))
 
     def got_extensions(self, extensions):
         assert len(info_hash) == 20, len(info_hash)
         self.send(info_hash)
 
-    def recv_info_hash(self):
    def _recv_info_hash(self):
        # Handshake step 3: the 20-byte info hash.
        self.got_info_hash(self.recv(20))
 
    def got_info_hash(self, info_hash):
        # Record the peer's claimed info hash and let the torrent verify
        # it (the torrent closes the connection on a mismatch).
        self.peer_info_hash = info_hash
        self.torrent.peer_sent_info_hash(self)
 
     # PEER ID
 
         assert len(peer_id) == 20, peer_id
         self.send(peer_id)
 
-    def recv_peer_id(self):
    def _recv_peer_id(self):
        # Handshake step 4: the 20-byte peer id.
        self.got_peer_id(self.recv(20))
 
    def got_peer_id(self, peer_id):
        # Record the handshake peer id; the torrent dedupes connections
        # by peer id and registers this one as active.
        self.peer_peer_id = peer_id
        self.torrent.peer_sent_peer_id(self)
 
     # keep alive
 
     def send_keep_alive(self):
         self.send(struct.pack('>I', 0))
 
-    def recv_unknown(self, type, length):
+    def _recv_unknown(self, type, length):
         self.logger.warning('Received message with unknown type %r, length %r', type, length)
         max_length = 0x100
         data = self.recv(min(0x100, length))
         self.logger.debug('Unknown message begins: %s', data)
         if length > max_length:
-            raise ConnectionClosed('Unknown message is too long')
+            raise self.ProtocolError('Unknown message is too long')
         else:
             self.got_unknown(type, data)
 
             self.send_message(self.CHOKE)
         self.peer_requests.clear()
 
-    def recv_choke(self, length):
+    def _recv_choke(self, length):
         if length != 0:
-            raise ConnectionClosed('Received choke with length=%d' % length)
+            raise self.ProtocolError('Received choke with length=%d' % length)
         self.got_choke()
 
     def got_choke(self):
             self.peer_choked = False
             self.send_message(self.UNCHOKE)
 
-    def recv_unchoke(self, length):
+    def _recv_unchoke(self, length):
         if length != 0:
-            raise ConnectionClosed('Received unchoked with length %s' % length)
+            raise self.ProtocolError('Received unchoked with length %s' % length)
         self.got_unchoke()
 
     def got_unchoke(self):
             self.am_interested = True
             self.send_message(self.INTERESTED)
 
-    def recv_interested(self, length):
+    def _recv_interested(self, length):
         if length != 0:
-            raise ConnectionClosed('Received interested with length=%d' % length)
+            raise self.ProtocolError('Received interested with length=%d' % length)
         self.got_interested()
 
     def got_interested(self):
             self.am_interested = False
             self.send_message(self.NOT_INTERESTED)
 
-    def recv_not_interested(self, length):
+    def _recv_not_interested(self, length):
         assert length == 0, length
         self.got_not_interested()
 
         self.logger.debug('%s: Sending HAVE(%d)', self, index)
         self.send_message(self.HAVE, struct.pack('>I', index))
 
-    def recv_have(self, length):
+    def _recv_have(self, length):
         index = struct.unpack('>I', self.recv(length))[0]
         self.got_have(index)
 
     def got_have(self, index):
         self.logger.debug('Peer has piece %r', index)
         self.peer_bitfield.add(index)
+        self.torrent.top_up_requests(self)
 
     # bitfield
 
             self.logger.debug('%s: Sending bitfield %s', self, bytes)
             self.send_message(self.BITFIELD, bytes)
 
-    def recv_bitfield(self, length):
+    def _recv_bitfield(self, length):
+        if length != (self.torrent.piece_count + 7) // 8:
+            raise self.ProtocolError('Bitfield packet has wrong length: {}'.format(length))
         import itertools
         index_iter = itertools.count()
         bitfield = set()
         for byte in self.recv(length):
             for bit_shift in range(7, -1, -1):
-                index = index_iter.__next__()
+                index = next(index_iter)
                 if byte >> bit_shift & 1:
+                    if index not in range(self.torrent.piece_count):
+                        raise self.ProtocolError('Bitfield packet incorrectly sets padding bits')
                     bitfield.add(index)
         self.got_bitfield(bitfield)
 
     def got_bitfield(self, bitfield):
-        self.logger.debug('%s: Got bitfield: %s', self, pretty_bitfield(bitfield))
+        if self.logger.isEnabledFor(logging.DEBUG):
+            self.logger.debug('%s: Got bitfield: %s', self, pretty_bitfield(bitfield))
+        assert bitfield <= set(range(self.torrent.piece_count))
         self.peer_bitfield = bitfield
+        self.torrent.top_up_requests(self)
 
-    def recv_request(self, length):
+    def _recv_request(self, length):
         request = struct.unpack('>III', self.recv(length))
         self.peer_requests.add(request)
         self.got_request(request)
         self.logger.debug('Got request %s', request)
         self.peer_requests.add(request)
 
-    def recv_piece(self, length):
+    def _recv_piece(self, length):
         index, offset = struct.unpack_from('>II', self.recv(8))
         self.got_piece(index, offset, self.recv(length - 8))
 
 
     # cancel
 
-    def recv_cancel(self, length):
+    def _recv_cancel(self, length):
         request = struct.unpack('>III', self.recv(length))
         self.got_cancel(request)
 
 
     # extended
 
-    def recv_extended(self, length):
+    def _recv_extended(self, length):
         import io
         type = self.recv(1)[0]
         self.logger.debug('Receiving extended message type: %s', type)
         message = self.recv(length - 1)
         self.logger.debug('Received extended message payload: %s', message)
-        self.got_extended(type, buncode(io.BytesIO(message)).__next__())
+        self.got_extended(type, next(bencode.decode(io.BytesIO(message))))
 
     def got_extended(self, type, msg):
-        import pprint
-        self.logger.debug('%s: Got extended message (type=%s):\n%s', self, type, pprint.pformat(msg))
+        if self.logger.isEnabledFor(logging.DEBUG):
+            import pprint
+            self.logger.debug(
+                '%s: Got extended message (type=%s):\n%s',
+                self, type, pprint.pformat(msg))
         if type == 0:
             self.peer_extended_protocols.update(msg['m'])
             if 'reqq' in msg:
                 self.peer_reqq = msg['reqq']
+        elif type == 1:
+            addrs = list(extract_packed_peer_addrs(msg.get('added', b'')))
+            self.logger.info('%s: Sent us %d new peer addrs', self, len(addrs))
+            self.torrent.add_swarm_peers(addrs)
+        else:
+            raise ProtocolError('Extended message unknown type: {}'.format(type))
 
     def send_message(self, type, payload=b''):
         self.send(struct.pack('>IB', len(payload) + 1, type) + payload)
    def send_request(self, request):
        """Send a REQUEST for the (index, begin, length) triple and
        record it in our_requests.

        Invariants (asserted): we must be unchoked and interested, and
        the triple must not already be outstanding.
        """
        self.logger.info('%s: Sending request: %s', self, request)
        assert not self.am_choked
        assert self.am_interested
        assert request not in self.our_requests
        self.our_requests.add(request)
        self.send_message(self.REQUEST, struct.pack('>III', *request))
 
     def send_piece(self, index, begin, data):
         request = index, begin, len(data)
             self.send_message(self.CANCEL, struct.pack('>III', *request))
 
     def send_extended(self, type, msg):
-        assert self.peer_extensions[5] & 0x10, self.peer_extensions
-        assert self.our_extensions[5] & 0x10, self.our_extensions
+        assert self.extended_protocol_enabled
         assert 'm' in msg, msg
-        data = bytes([type]) + bencode(msg)
+        data = bytes([type]) + bencode.encode(msg)
         import pprint
         self.logger.debug(
             'Sending extended message (type=%s):\n%s',
             pprint.pformat(msg))
         self.send_message(self.EXTENDED, data)
 
+    @property
+    def extended_protocol_enabled(self):
+        return self.our_extensions[5] & 0x10 and self.peer_extensions[5] & 0x10
+

File lib/throttle.py

import threading


class Limiter(threading.Thread):
    """Token-bucket rate limiter running as a daemon thread.

    A background loop adds ``rate`` units to the bucket every
    ``interval`` seconds (capped at ``rate``).  reserve() blocks while
    the bucket is in deficit, then debits it -- a single reservation may
    overdraw the bucket, which simply delays later reservations.
    """

    def __init__(self, rate, interval):
        super().__init__()
        # The refill loop never returns; a non-daemon thread kept the
        # whole process alive after the torrent finished.
        self.daemon = True
        self.level = 0
        self._lock = threading.Lock()
        # Signalled whenever a refill clears the deficit.  (The old
        # _below_zero/_not_full conditions were never used.)
        self._replenished = threading.Condition(self._lock)
        self.rate = rate
        self.interval = interval

    def run(self):
        """Refill loop: top the bucket up once per interval, forever."""
        import time
        now = time.time()
        deadline = now + self.interval  # renamed: ``next`` shadowed the builtin
        while True:
            while now < deadline:
                time.sleep(deadline - now)
                now = time.time()
            with self._lock:
                self.level = min(self.level + self.rate, self.rate)
                # Wake waiters as soon as the deficit clears: the old
                # ``> 0`` test could strand waiters when the level landed
                # on exactly zero.
                if self.level >= 0:
                    self._replenished.notify_all()
            deadline += self.interval

    def report(self, amount):
        """Debit *amount* without blocking (for traffic already sent)."""
        with self._lock:
            self.level -= amount

    def reserve(self, amount):
        """Block until the bucket is out of deficit, then debit *amount*."""
        with self._lock:
            while self.level < 0:
                self._replenished.wait()
            self.level -= amount
+
+
+
+
+
+
+
+

File lib/torrent.py

 from .util import pretty_bitfield
 from .trackers import AnnounceError
+from . import peer
 
 
 class Torrent:
     def left(self):
         return sum(block[1] for blocks in self.wanted_blocks.values() for block in blocks)
 
-    def __init__(self, metainfo, data, socket, limiter=None, trackers=()):
+    def __init__(self, metainfo, data, socket, upload_limiter=None, trackers=()):
         import threading, os
         self.closed = False
         self.closed_cond = threading.Condition()
                 self.wanted_blocks[index] = set(self.piece_blocks(index))
         self.swarm_addrs = set()
         self.swarm_addrs_not_empty = threading.Condition()
-        self.active_peers = {} # peer_id: connection
+        self.active_peers = set()
+        self.active_peers_lock = threading.RLock()
         self.trackers = trackers
         self.name = metainfo.raw['info']['name'].decode()
+        self.upload_limiter = upload_limiter
 
     @property
     def port(self):
             if self.swarm_addrs:
                 self.swarm_addrs_not_empty.notify_all()
 
+    def wanted_pieces(self):
+        return {piece for piece, blocks in self.wanted_blocks.items() if blocks}
+
    def want_piece(self, index):
        """True while piece *index* is still tracked in wanted_blocks.

        NOTE(review): membership only -- an index whose block set has
        been emptied but not removed still counts, whereas
        wanted_pieces() filters empties; confirm which semantics the
        callers expect.
        """
        return index in self.wanted_blocks
 
-    def all_requests(self):
-        requests = multiset()
-        with self.lock:
-            for conn in self.active_peers.values():
-                requests.update(conn.our_requests)
-        return requests
-
     def top_up_requests(self, conn):
         request_generator = self.generate_requests(conn)
         try:
             request_generator.close()
 
     def generate_requests(self, conn):
-        with self.lock:
+        with self.active_peers_lock:
             all_requests = set()
-            for conn in self.active_peers.values():
+            for conn in self.active_peers:
                 all_requests.update(conn.our_requests)
             for index in self.wanted_blocks.keys() & conn.peer_bitfield:
                 choices = {(index,) + block for block in self.wanted_blocks[index]}
 
     def peer_sent_info_hash(self, conn):
         if conn.peer_info_hash != self.info_hash:
-            self.logger.warning("%s: Peer %r's info hash does not match: %s", self, conn, conn.peer_info_hash)
+            self.logger.warning(
+                "%s: Peer %r's info hash does not match: %s",
+                self, conn, conn.peer_info_hash)
             conn.close()
 
     def routine_wrapper(routine):
                 return routine(self, *args, **kwds)
             except:
                 self.logger.exception('%s', self)
+                self.close()
             finally:
                 self.logger.debug('Routine %s returned', routine)
-                self.close()
         return wrapper
 
     @routine_wrapper
     def peer_routine(self, sock, addr):
-        # create and connect the socket if a socket wasn't given
+        # create and connect a socket if a socket wasn't given
         if sock is None:
             with self.half_open:
+                import socket
                 sock = socket.socket()
                 sock.settimeout(30)
                 try:
                     return
                 else:
                     self.logger.info('Connected to %s', addr)
+                    sock.setblocking(True)
 
         with sock:
-            sock.setblocking(True)
-            with ConnectionWrapper(sock, self) as conn:
-                #~ threading.current_thread().name = conn
+            with peer.Connection(
+                sock,
+                self,
+                upload_limiter=self.upload_limiter
+            ) as conn:
+                import threading
+                threading.current_thread().name = conn
                 conn.send_info_hash(self.info_hash)
                 conn.send_peer_id(self.peer_id)
                 conn.run()
 
     def peer_sent_peer_id(self, conn):
         # add the connection to the active peers
-        peer_id = conn.peer_peer_id
-        with self.lock:
-            if peer_id in self.active_peers:
-                self.logger.warning(
-                    '%s: Already connected to peer %s',
-                    self,
-                    peer_id)
-                assert conn is not self.active_peers[peer_id]
-                conn.close()
+        new_peer_id = conn.peer_peer_id
+        with self.active_peers_lock:
+            for conn1 in self.active_peers:
+                if conn1.peer_peer_id == new_peer_id:
+                    self.logger.warning(
+                        '%s: Already connected to peer %s', self, new_peer_id)
+                    assert conn1 is not conn
+                    conn.close()
+                    break
             else:
-                self.active_peers[peer_id] = conn
+                self.active_peers.add(conn)
                 conn.send_bitfield(self.bitfield_bytes())
-                if conn.peer_extensions[5] & 0x10:
+                if conn.extended_protocol_enabled:
+                    import socket
                     conn.send_extended(0, {
                         'm': {'ut_pex': 1},
                         'reqq': 5,
                 self.cond.notify_all()
 
     def activate_peer(self, sock, addr):
-        gthread.spawn(self.peer_routine, sock, addr)
-        return True
+        import threading
+        thread = threading.Thread(target=self.peer_routine, args=[sock, addr])
+        thread.daemon = True
+        thread.start()
 
     @routine_wrapper
     def connect_routine(self):
                     self.swarm_addrs_not_empty.wait()
                     if self.closed:
                         return
+                if self.closed:
+                    return
                 addr = self.swarm_addrs.pop()
             with self.half_open:
                 self.activate_peer(None, addr)
         return conn.peer_interested, conn.am_interested, conn.bytes_received, not conn.am_choked
 
     def do_stats(self):
-        num = sum(len(conn.peer_bitfield) for conn in self.active_peers.values())
-        denom = len(self.active_peers) * self.piece_count
+        with self.active_peers_lock:
+            num = sum(len(conn.peer_bitfield) for conn in self.active_peers)
+            denom = len(self.active_peers) * self.piece_count
         try:
             av_completion = 100 * num / denom
         except ZeroDivisionError:
                 peer.send_choke()
             for peer in peers[-5:]:
                 peer.send_unchoke()
-            self.wait(10)
+            self.wait_closed(10)
 
     def final_announce(self, tracker):
         try:
         with self.closed_cond:
             if self.closed:
                 return
+            self.logger.warning('Closing %s', self)
             self.closed = True
             self.closed_cond.notify_all()
         with self.swarm_addrs_not_empty:
             / self.piece_count)
 
    def wait_for_deadline(self, deadline):
        """Wait until *deadline* or until the torrent closes.

        Returns True when the torrent is still open afterwards, False if
        it closed while waiting -- presumably wait_closed() returns
        truthy on close; confirm against its definition (not visible
        here).
        """
        return not self.wait_closed(self.timeout_from_deadline(deadline))
 
     def timeout_from_deadline(self, deadline):
         import time
                 try:
                     peers, interval = tracker.announce(self, self.timeout_from_deadline(deadline))
                 except AnnounceError as exc:
-                    self.logger.warning('%s: Error announcing %s: %r', self, tracker, exc)
+                    self.logger.warning('%s: %s', tracker, exc)
                 else:
                     break
-                self.wait_for_deadline(deadline)
+                if not self.wait_for_deadline(deadline):
+                    return
                 timeout = min(timeout * 2, 3600)
             peers = set(peers)
             self.logger.info('Got %d peers from tracker %s', len(peers), tracker)
                 self.downloaded += request[2]
                 return True
 
+
 def main():
     import util
     util.configure_logging()