Commits

Drew Smathers committed e3d99b0

range-based parallel downloads

Comments (0)

Files changed (1)

 import sys
 import os
 import optparse
+import time
 from Queue import Queue
 
 import boto
 
 class CompletionCounter:
 
-    def __init__(self, total):
+    def __init__(self, total, sending=True):
         self.completed = 0
         self.total = total
+        self.sending = sending
 
     def count(self, fd, piece):
         self.completed += 1
-        sys.stderr.write('%-40s \r' % ('transferred chunks %d/%d' % (self.completed, self.total)))
+        verb = ('received', 'transferred')[self.sending]
+        sys.stderr.write('%-40s \r' % ('%s chunks %d/%d' % (verb, self.completed, self.total)))
 
 
 def generate_chunk_files(path, feedback=False):
     return coiterate(chunker).addCallback(
         lambda ign: sys.stderr.write('\nfinished generating chunks for threads\n'))
 
+
+class DownloadChunkMultiplexer(object):
+
+    def __init__(self, path):
+        self.file = open(path, 'w+b')
+
+    def update(self, piece, chunk):
+        start = piece * CHUNK_SIZE
+        self.file.seek(start)
+        self.file.write(chunk)
+
+    def close(self):
+        self.file.close()
+
+
+class DownloadTask(object):
+
+    def __init__(self, bucket, key, queue, download_manager, callback):
+        self.bucket = bucket
+        self.key = key
+        self.queue = queue
+        self.download_manager = download_manager
+        self.callback = callback
+        self.buf = StringIO()
+
+    def _get_key(self, key, bucket=None):
+        if not bucket:
+            conn = boto.connect_s3()
+            bucket = conn.lookup(self.bucket)
+        return bucket, bucket.get_key(key)
+
+    def get(self):
+        bucket, key = self._get_key(self.key)
+        while 1:
+            message = self.queue.get()
+            if message == DONE:
+                return
+            piece, start, end = message
+            bucket, key = self._get(piece, start, end, bucket, key)
+
+    def _get(self, piece, start, end, bucket, key):
+        while 1:
+            try:
+                self.buf.seek(0)
+                self.buf.truncate()
+                key.get_contents_to_file(self.buf, headers={'Range':'bytes=%d-%d' % (start, end)})
+                reactor.callFromThread(self.download_manager.update, piece, self.buf.getvalue())
+            except Exception, e:
+                reactor.callFromThread(sys.stderr.write,
+                    'Failure getting piece %d. (%s) Trying again.\n' % (piece, e))
+                bucket, key = self._get_key(self.key)
+                time.sleep(0.1)
+            else:
+                break
+        bucket, key = self._get_key(self.key, bucket)
+        reactor.callFromThread(self.callback, None, piece)
+        return bucket, key
+
+def download_ranges(bucket, s3path, localpath, threads):
+    sys.stderr.write('thread count: %d\n' %  threads)
+    reactor.suggestThreadPoolSize(threads)
+    reactor.callWhenRunning(_download_ranges, bucket, s3path, localpath, threads)
+    reactor.run()
+
+def _download_ranges(bucket, s3path, localpath, threads):
+    conn = boto.connect_s3()
+    bucket2 = conn.lookup(bucket)
+    key = bucket2.get_key(s3path)
+    size = key.size
+    chunks = (size / CHUNK_SIZE) + 1
+
+    if not localpath:
+        localpath = os.path.basename(s3path)
+
+    counter = CompletionCounter(chunks, sending=False)
+    q = Queue(threads * 4)
+    download_manager = DownloadChunkMultiplexer(localpath)
+    deferreds = []
+    for i in range(threads):
+        task = DownloadTask(bucket, s3path, q, download_manager, counter.count)
+        deferreds.append(deferToThread(task.get))
+
+    def end(ignore):
+        sys.stderr.write('\ndone\n')
+        download_manager.close()
+        reactor.stop()
+    gatherResults(deferreds).addErrback(log.err).addCallback(end)
+
+    def ranger():
+        for index in range(chunks):
+            start = CHUNK_SIZE * index
+            if start == size:
+                break
+            end = min(start + CHUNK_SIZE - 1, size)
+            q.put((index, start, end))
+            yield
+        for i in range(threads):
+            q.put(DONE)
+            yield
+    ranger = ranger()
+    return coiterate(ranger).addCallback(
+        lambda ign: sys.stderr.write('\nfinished generating ranges for threads\n'))
+
+
 def main():
     parser = optparse.OptionParser()
     parser.add_option('-b', '--bucket', dest='bucket', help='The bucket name')
                       help='Number of threads to use')
     parser.add_option('-p', '--public', dest='public', action='store_true',
                       help='Use this flag to set acl to public-read')
+    parser.add_option('-d', '--download', dest='download', action='store_true',
+                      help='Download a resource grabbing chunks in parallel')
+    parser.add_option('-o', '--output-file', dest='outputfile',
+                      help='For download, the output file to save to. '
+                           'If not given the basename of the key will be used.')
     opts, args = parser.parse_args()
     path = args[0]
-    upload_multipart(opts.bucket, path, opts.threads, opts.public)
+    if opts.download:
+        sys.stderr.write('downloading %s\n' % path)
+        download_ranges(opts.bucket, path, opts.outputfile, opts.threads or 10)
+    else:
+        upload_multipart(opts.bucket, path, opts.threads, opts.public)
 
 if __name__ == '__main__':
     main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.