Commits

drewthepooh  committed 7addd4f

flagstatscsv with pool instead of process

  • Participants
  • Parent commits f9c22bc

Comments (0)

Files changed (2)

File fastq_subset.py

 
     def recordGenerator(fastq):
         while True:
-            try:
-                yield [next(fastq) for i in xrange(4)]
-            except:
-                raise StopIteration
+            yield [next(fastq) for i in xrange(4)]
 
     random_set = set(random.sample(xrange(numRecords), sample))
 

File flagstats2csv.py

 import csv
 import subprocess
 from os.path import basename
-from multiprocessing import Process, Queue
-import glob
+from multiprocessing import Pool
 import argparse
 import sys
 
 
-def flagstats_calculator(*file_paths):
+def get_flagstats(file_path):
+
+    print('Calculating stats for', basename(file_path), file=sys.stderr)
+    flagstat_bytes = subprocess.check_output(['samtools', 'flagstat', file_path])
+    flagstat_output = flagstat_bytes.decode('UTF-8')
+    flagstats_raw = flagstat_output.split('\n')[:-1]
 
-    def get_flagstats(file_path, q):
+    flagstats_nums = [stat[0] for stat in (flagstats_raw[i].split(' ')
+                      for i in range(len(flagstats_raw)))]
+    flagstats_nums.insert(0, basename(file_path))
+    return flagstats_nums
 
-        print('Calculating stats for', basename(file_path), file=sys.stderr)
-        flagstat_bytes = subprocess.check_output(['samtools', 'flagstat', file_path])
-        flagstat_output = flagstat_bytes.decode('UTF-8')
-        flagstats_raw = flagstat_output.split('\n')[:-1]
 
-        flagstats_nums = [stat[0] for stat in (flagstats_raw[i].split(' ')
-                          for i in range(len(flagstats_raw)))]
-        flagstats_nums.insert(0, basename(file_path))
-        q.put(flagstats_nums)
+def flagstats_calculator(*file_paths):
 
     writer = csv.writer(sys.stdout, delimiter='\t', lineterminator='\n')
     stats = ['sample',
     writer.writerow(stats)
 
     if __name__ == '__main__':
-        jobs = []
-        q = Queue()
-        for file_path in file_paths:
-            j = Process(target=get_flagstats, args=(file_path, q))
-            jobs.append(j)
-            j.start()
-        for j in jobs:
-            j.join()
-        for j in jobs:
-                if j.exitcode != 0:
-                    raise RuntimeError('multi process returned with non-0 '
-                                       'exit code')
-        while not q.empty():
-            writer.writerow(q.get())
+        pool = Pool()
+        rows = pool.map(func=get_flagstats, iterable=file_paths)
+        pool.close()
+        pool.join()
+        writer.writerows(rows)
 
 
 def main():
     helpText = ('Takes a number of bam files (can be input with file name globbing),'
                 'calculates flagstats in parallel, and outputs to stdout.  Capture the'
-                'output to generate a csv file')
+                'output to generate a csv file. Note that due to the nature of the '
+                'multiprocessing, you may have to hit ctrl-C twice if an error is raised.')
     parser = argparse.ArgumentParser(description=helpText)
     parser.add_argument('input_BAM',
-                        help='one or more BAM files for analysis',
+                        help='BAM file for analysis',
                         nargs='+')
     args = parser.parse_args()