Commits

Stephen Skory committed aac45ed

A few changes for inline Rockstar.

  • Parent commits 88f17e5

Files changed (1)

File yt/analysis_modules/halo_finding/rockstar/rockstar.py

             stride = int(ceil(float(pool.comm.size) / self.num_writers))
             while len(self.writers) < self.num_writers:
                 self.writers.extend(avail[::stride])
-                for r in readers:
+                for r in avail[::stride]:
                     avail.pop(avail.index(r))
 
     def run(self, handler, pool):
         # Start writers.
         writer_pid = 0
         if pool.comm.rank in self.writers:
-            time.sleep(0.1 + pool.comm.rank/10.0)
+            time.sleep(0.05 + pool.comm.rank/10.0)
             writer_pid = os.fork()
             if writer_pid == 0:
                 handler.start_writer()
                 os._exit(0)
         # Start readers, not forked.
         if pool.comm.rank in self.readers:
-            time.sleep(0.1 + pool.comm.rank/10.0)
+            time.sleep(0.05 + pool.comm.rank/10.0)
             handler.start_reader()
         # Make sure the forks are done, which they should be.
         if writer_pid != 0:
             os.waitpid(writer_pid, 0)
         if server_pid != 0:
             os.waitpid(server_pid, 0)
 
+    def setup_pool(self):
+        pool = ProcessorPool()
+        # Everyone is a reader, and when we're inline, that's all that matters.
+        readers = np.arange(ytcfg.getint("yt", "__global_parallel_size"))
+        pool.add_workgroup(ranks=readers, name="readers")
+        return pool, pool.workgroups[0]
+
 class StandardRunner(ParallelAnalysisInterface):
     def __init__(self, num_readers, num_writers):
         self.num_readers = num_readers
                     self.num_readers, self.num_writers, psize)
             raise RuntimeError
     
-    def split_work(self):
+    def split_work(self, pool):
         self.readers = np.arange(self.num_readers) + 1
         self.writers = np.arange(self.num_writers) + 1 + self.num_readers
     
         if wg.name == "server":
             handler.start_server()
         if wg.name == "readers":
-            time.sleep(0.1)
+            time.sleep(0.05)
             handler.start_reader()
         if wg.name == "writers":
-            time.sleep(0.2)
+            time.sleep(0.1)
             handler.start_writer()
+    
+    def setup_pool(self):
+        pool = ProcessorPool()
+        pool, workgroup = ProcessorPool.from_sizes(
+           [ (1, "server"),
+             (self.num_readers, "readers"),
+             (self.num_writers, "writers") ]
+        )
+        return pool, workgroup
 
 class RockstarHaloFinder(ParallelAnalysisInterface):
     def __init__(self, ts, num_readers = 1, num_writers = None,
             del tpf
         else:
             self.force_res = force_res
-        # We set up the workgroups *before* initializing
-        # ParallelAnalysisInterface. Everyone is their own workgroup!
-        self.pool = ProcessorPool()
-        self.pool, self.workgroup = ProcessorPool.from_sizes(
-           [ (1, "server"),
-             (self.num_readers, "readers"),
-             (self.num_writers, "writers") ]
-        )
+        # Setup pool and workgroups.
+        self.pool, self.workgroup = self.runner.setup_pool()
         p = self._setup_parameters(ts)
         params = self.comm.mpi_bcast(p, root = self.pool['readers'].ranks[0])
         self.__dict__.update(params)
 
 
     def __del__(self):
-        self.pool.free_all()
+        try:
+            self.pool.free_all()
+        except AttributeError:
+            # This really only acts to cut down on the misleading
+            # error messages when/if this class is called incorrectly
+            # or some other error happens and self.pool hasn't been
+            # created yet.
+            pass
 
     def _get_hosts(self):
         if self.comm.rank == 0 or self.comm.size == 1:
             self.handler.call_rockstar()
         else:
             # Split up the work.
-            self.runner.split_work()
+            self.runner.split_work(self.pool)
             # And run it!
             self.runner.run(self.handler, self.workgroup)
         self.comm.barrier()
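
The pattern this commit converges on is that each runner class owns its pool construction through a setup_pool() method returning a (pool, workgroup) pair, while the inline runner launches its writer with the standard POSIX fork/exit/waitpid idiom. The sketch below is a minimal, self-contained illustration of those two ideas only; FakePool, ToyInlineRunner, and the do-nothing worker callables are stand-ins invented for illustration, not yt's ProcessorPool or the Rockstar handler API.

    import os
    import time

    class FakePool:
        """Stand-in for yt's ProcessorPool; it only records named workgroups."""
        def __init__(self):
            self.workgroups = []

        def add_workgroup(self, ranks, name):
            self.workgroups.append((name, list(ranks)))

    class ToyInlineRunner:
        def setup_pool(self, size):
            # Mirror the shape of InlineRunner.setup_pool(): build the pool
            # here and hand (pool, workgroup) back to the caller.
            pool = FakePool()
            pool.add_workgroup(ranks=range(size), name="readers")
            return pool, pool.workgroups[0]

        def run(self, writer_work, reader_work):
            # Fork/exit/waitpid idiom used by the inline writer: the child
            # runs writer_work() and leaves through os._exit() so it never
            # re-enters the parent's control flow; the parent does its own
            # (reader) work and then reaps the child to avoid a zombie.
            writer_pid = os.fork()
            if writer_pid == 0:
                writer_work()
                os._exit(0)
            reader_work()
            os.waitpid(writer_pid, 0)

    if __name__ == "__main__":
        runner = ToyInlineRunner()
        pool, workgroup = runner.setup_pool(size=4)
        runner.run(lambda: time.sleep(0.05), lambda: None)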