Commits

Matthew Turk committed 16a2374

Have ParticleOctree collect domains on a per-oct basis, not per-particle.
Mandate only a single domain per Oct.

Comments (0)

Files changed (2)

yt/geometry/oct_container.pxd

     Oct *oct
     ParticleArrays *next
     np.float64_t **pos
-    np.int64_t *domain_id
     np.int64_t np

yt/geometry/oct_container.pyx

                 count[cur.my_octs[i - cur.offset].domain - 1] += 1
         return count
 
-    def check(self, int curdom):
-        cdef int dind, pi
-        cdef Oct oct
-        cdef OctAllocationContainer *cont = self.domains[curdom - 1]
-        cdef int nbad = 0
-        for pi in range(cont.n_assigned):
-            oct = cont.my_octs[pi]
-            for i in range(2):
-                for j in range(2):
-                    for k in range(2):
-                        if oct.children[i][j][k] != NULL and \
-                           oct.children[i][j][k].level != oct.level + 1:
-                            if curdom == 61:
-                                print pi, oct.children[i][j][k].level,
-                                print oct.level
-                            nbad += 1
-        print "DOMAIN % 3i HAS % 9i BAD OCTS (%s / %s / %s)" % (curdom, nbad, 
-            cont.n - cont.n_assigned, cont.n_assigned, cont.n)
-
     @cython.boundscheck(False)
     @cython.wraparound(False)
     @cython.cdivision(True)
                 for i in range(3):
                     free(o.sd.pos[i])
                 free(o.sd.pos)
-            free(o.sd.domain_id)
         free(o)
 
     @cython.boundscheck(False)
             malloc(sizeof(ParticleArrays))
         cdef int i, j, k
         my_oct.ind = my_oct.domain = -1
+        my_oct.domain = -1
         my_oct.local_ind = self.nocts - 1
         my_oct.pos[0] = my_oct.pos[1] = my_oct.pos[2] = -1
         my_oct.level = -1
         self.last_sd = sd
         sd.oct = my_oct
         sd.next = NULL
-        sd.domain_id = <np.int64_t *> malloc(sizeof(np.int64_t) * 32)
         sd.pos = <np.float64_t **> malloc(sizeof(np.float64_t*) * 3)
         for i in range(3):
             sd.pos[i] = <np.float64_t *> malloc(sizeof(np.float64_t) * 32)
         for i in range(32):
             sd.pos[0][i] = sd.pos[1][i] = sd.pos[2][i] = 0.0
-            sd.domain_id[i] = -1
         sd.np = 0
         return my_oct
 
             cur = self.root_mesh[ind[0]][ind[1]][ind[2]]
             if cur == NULL:
                 raise RuntimeError
-            if self._check_refine(cur, cp) == 1:
+            if self._check_refine(cur, cp, domain_id) == 1:
                 self.refine_oct(cur, cp)
             while cur.sd.np < 0:
                 for i in range(3):
                         cp[i] += dds[i]/2.0
                 cur = cur.children[ind[0]][ind[1]][ind[2]]
                 level += 1
-                if self._check_refine(cur, cp) == 1:
+                if self._check_refine(cur, cp, domain_id) == 1:
                     self.refine_oct(cur, cp)
             # Now we copy in our particle 
             pi = cur.sd.np
             cur.level = level
             for i in range(3):
                 cur.sd.pos[i][pi] = pp[i]
-            cur.sd.domain_id[pi] = domain_id
+            cur.domain = domain_id
             cur.sd.np += 1
 
-    cdef int _check_refine(self, Oct *cur, np.float64_t cp[3]):
-        cdef int mid = 16384
-        cdef int mad = -16384
-        for i in range(imax(cur.sd.np, 0)):
-            mid = imin(cur.sd.domain_id[i], mid)
-            mad = imax(cur.sd.domain_id[i], mad)
-        if cur.sd.np == 32 or mid < mad:
+    cdef int _check_refine(self, Oct *cur, np.float64_t cp[3], int domain_id):
+        if cur.children[0][0][0] != NULL:
+            return 0
+        elif cur.sd.np == 32:
+            return 1
+        elif cur.domain >= 0 and cur.domain != domain_id:
             return 1
         return 0
 
                     noct.pos[2] = (o.pos[2] << 1) + k
                     noct.parent = o
                     o.children[i][j][k] = noct
-        for m in range(32):
+        for m in range(o.sd.np):
             for i in range(3):
                 if o.sd.pos[i][m] < pos[i]:
                     ind[i] = 0
             k = noct.sd.np
             for i in range(3):
                 noct.sd.pos[i][k] = o.sd.pos[i][m]
-            noct.sd.domain_id[k] = o.sd.domain_id[k]
+            noct.domain = o.domain
             noct.sd.np += 1
         o.sd.np = -1
         for i in range(3):
             free(o.sd.pos[i])
-        free(o.sd.domain_id)
         free(o.sd.pos)
 
     def recursively_count(self):
                     m = 1
                     break
             if m == 0: continue
-            for i in range(o.sd.np):
-                dmask[o.sd.domain_id[i]] = 1
+            dmask[o.domain] = 1
         return dmask.astype("bool")
 
     @cython.boundscheck(False)