Commits

Anonymous committed fa596f9

[econ/patterns/nber][m]: major speedups in extracting flows data.

Comments (0)

Files changed (2)

             self._get_nclass_stats()
         self.nclasss = sj.load(file(self.nclass_list))
 
+    def full_join(self):
+        full = db.patent.join(db.citation)
+        p1 = db.patent
+        p2 = db.patent.alias('cited')
+        c1 = db.citation
+        full = full.outerjoin(p2, p2.c.id==db.citation.c.cited_id)
+        sel = sql.select(
+                [p1.c.id, p2.c.id, p1.c.cmade,
+                    p1.c.nclass, p2.c.nclass,
+                    p1.c.subcat, p2.c.subcat,
+                    ],
+                from_obj=full)
+        sel = sel.apply_labels()
+        return sel
+
     def _get_subcat_stats(self):
         # hand-crafted sql is probably faster
         q = sql.select([db.patent.c.subcat, sql.func.count('*')])
 
     def all_flows(self):
         self.flows = {}
-        for year in range(1985, 1986):
+        for year in range(1975, 1995):
             self.flows[year] = self.get_flows_by_year(year)
         return self.flows
 
         fp2 = os.path.join(D.flowdatadir, fn2)
         if not os.path.exists(fp1):
             print '## Extracting flow information for year: ', year
-            patents = db.Patent.query.filter_by(gyear=year)
+            patents = self.full_join()
+            patents = patents.where(db.patent.c.gyear==year)
             if limit > 0:
                 patents = patents.limit(limit)
-            msubcat, mnclass = self.get_flows_slow(patents)
+            msubcat, mnclass = self.get_flows(patents)
             scipy.io.write_array(fp1, msubcat) 
             scipy.io.write_array(fp2, mnclass)
         else:
             mnclass = scipy.io.read_array(fp2)
         return (msubcat, mnclass)
 
-    def full_join(self):
-        full = db.patent.join(db.citation)
-        p1 = db.patent
-        p2 = db.patent.alias('p2')
-        c1 = db.citation
-        full = full.outerjoin(p2, p2.c.id==db.citation.c.cited_id)
-        print full
-        sel = sql.select(
-                [p1.c.id, p2.c.id, p1.c.cmade,
-                    p1.c.nclass, p2.c.nclass,
-                    p1.c.subcat, p2.c.subcat,
-                    ],
-                from_obj=full)
-        sel = sel.apply_labels()
-        return sel
+    def get_flows(self, query):
+        '''Citation flows from different areas (subcat/nclass).
+        
+        Sources for citations are patents in query so we are measuring all flow
+        into those areas from *all* other patents (i.e. cited patents can be
+        outside of our set of patents).
 
-    def get_flows(self, query):
-        '''We include cited patents outside of our set of patents.
+        Currently count each citation as 1 unit.
 
-        That is we are simply looking for all inflows into our set.
+        @return 2 matrices (subcat and nclass respectively) with M(i,j) =
+        citation flow from area i to area j.
+
+        Flow i->j is measured by a citation from i->j (so movement of 'ideas'
+        from j -> i).
 
         Where patent is unknown assign this to an extra last category
 
         For efficiency do both nclass and subcat at once.
         '''
-        for row in query.execute():
-            src
         class CategoryProcessor(object):
             def __init__(self, catname, cat_summary):
                 self.catname = catname
+                self.src_cat_col = getattr(db.patent.c, self.catname)
+                cited = db.patent.alias('cited')
+                self.cited_cat_col = getattr(cited.c, self.catname)
+
                 self.size = len(cat_summary)
                 self.catlist = [ x[0] for x in cat_summary]
                 self.size = len(self.catlist)
                 # add 1 for case where patents unknown
                 self.matrix = N.zeros( (self.size+1, self.size+1) )
 
-            def process(self, srccat, destcat, flow):
-                i = self.index(srccat)
-                j = self.index(destcat)
-                # self.matrix[i,j] = self.matrix[i,j] + 
+            def process(self, row):
+                # weight in flow by number of citations this patent has
+                # flow = 1.0/row[db.patent.c.cmade]
+                flow = 1.0
+                i = self.cat2idx(row[self.src_cat_col])
+                j = self.cat2idx(row[self.cited_cat_col])
+                self.matrix[i,j] = self.matrix[i,j] + flow
 
-            def index(self, cat):
+            def cat2idx(self, cat):
                 # index into matrix
                 if cat is None:
                     return self.size
                 else:
                     return self.catlist.index(cat)
-            proc_subcat = CategoryProcessor('subcat', self.subcats)
-            proc_nclass = CategoryProcessor('nclass', self.nclasss)
-
-    def get_flows_slow(self, patents):
-        '''This is just too slow ...
-        '''
-        class CategoryProcessor(object):
-            def __init__(self, catname, cat_summary):
-                self.catname = catname
-                self.size = len(cat_summary)
-                self.catlist = [ x[0] for x in cat_summary]
-                self.size = len(self.catlist)
-                # add 1 for case where patents unknown
-                self.matrix = N.zeros( (self.size+1, self.size+1) )
-
-            def index(self, patent):
-                # index into matrix
-                if patent is None:
-                    return self.size
-                else:
-                    cat = getattr(patent, self.catname)
-                    return self.catlist.index(cat)
 
         proc_subcat = CategoryProcessor('subcat', self.subcats)
         proc_nclass = CategoryProcessor('nclass', self.nclasss)
-        total = patents.count()
+        # very slow
+        # countq = query.alias().count()
+        # total = countq.execute().fetchall()[0][0]
+        total = 'Unknown'
         print 'Total to process:', total
         count = -1
-        for p in patents:
+        for row in query.execute():
             count += 1
-            # print every 1% point
-            if count % max(1,(total/100)) == 0: print count
-            print count
-            for cite in p.citations:
-                flow = 1.0/p.cmade
-                dest = db.Patent.query.get(cite.cited_id)
-                for proc in [ proc_subcat, proc_nclass ]:
-                    srcindex = proc.index(p)
-                    destindex = proc.index(dest)
-                    proc.matrix[srcindex, destindex] = proc.matrix[srcindex, destindex] + flow
+            # if count % max(1,(total/100)) == 0: # every 1%
+            if count % 10000 == 0:
+                print count
+            for proc in [ proc_subcat, proc_nclass ]:
+                    proc.process(row)
+
         return (proc_subcat.matrix, proc_nclass.matrix)
     
 
     pylab.savefig(fn)
     pylab.clf()
 
+def plot_flows(year):
+    a = Analyzer()
+    msubcat, mnclass = a.get_flows_by_year(1985)
+    # subcat 13
+    idx = 2
+    subcats = [ x[0] for x in a.subcats ]
+    subcat = subcats[idx]
+    # for undefined area
+    subcats.append(subcats[-1] + 1)
+    vals = msubcat[2]
+    pylab.bar(subcats, vals)
+    fn = os.path.join(outdir, 'flows_subcat_%s_%s.png' % (subcat, year))
+    pylab.savefig(fn)
+
 def main():
     plot_ddist(1975)
     plot_ddist(1985)
 if __name__ == '__main__':
     # main()
     a = Analyzer()
-    # a.all_flows()
-    a.get_flows_fast()
+    a.all_flows()
+    # plot_flows(1985)
 
 
         # these should actually be the same since all citations are post-1975
         assert len(pat.citations_by) == pat.creceive
 
+    def test_full_join(self):
+        q = self.a.full_join()
+        # q = q.where(db.patent.c.gyear==1975)
+        # q = q.limit(10)
+        q = q.where(db.patent.c.id==3883377)
+        out = q.execute().fetchall()[0]
+        print out.keys()
+        print out
+        assert out[1] ==  3357959
+        assert out['patent_cmade'] == 1
+        assert out['cited_id'] == 3357959
+        # check other access method
+        citing = db.patent.alias('cited')
+        assert out[citing.c.id] == 3357959
+
     def test_subcat_stats(self):
         out = self.a.subcats
         print out
         y1975 = out[0][1]['cmade']
         assert y1975[0][0] == 0
 
-    def test_get_flows_slow_nclass(self):
-        pats = db.Patent.query.filter_by(gyear=1975).filter_by(nclass=2).limit(20)
-        msubcat, mnclass = self.a.get_flows_slow(pats)
+    def test_get_flows(self):
+        baseq = self.a.full_join()
+        baseq = baseq.where(db.patent.c.gyear==1975) 
+        exptotal = 20
+        pats = baseq.where(db.patent.c.nclass==2).limit(exptotal)
+        msubcat, mnclass = self.a.get_flows(pats)
         # total flow should sum to the number of rows (each citation row counts as 1 unit)
-        total_flow = mnclass.sum()
-        assert round(total_flow, 1) == 20.0, total_flow
+        print mnclass
+        total_flow = round(mnclass.sum(), 1)
+        assert total_flow == exptotal, total_flow
         # should have self refs
         assert mnclass[1,1] > 0.0, mnclass
 
-    def get_get_flows_slow_subcat(self):
-        pats = db.Patent.query.filter_by(gyear=1975).filter_by(subcat=13).limit(20)
-        msubcat, mnclass = self.a.get_flows_slow(pats)
-        # total flow should sum to the number of patents
-        total_flow = msubcat.sum()
-        assert round(total_flow, 1) == 20.0, total_flow
+        pats = baseq.where(db.patent.c.subcat==13).limit(exptotal)
+        msubcat, mnclass = self.a.get_flows(pats)
+        total_flow = round(msubcat.sum(),1)
+        assert total_flow == exptotal, total_flow
         # should have self refs and subcat 13 is col 3
         assert msubcat[2,2] > 0.0, msubcat
 
         total = round(matrix.sum(),0)
         assert total == exptotal, total
 
-    def test_full_join(self):
+    def _test_full_join_count(self):
+        # very slow
         q = self.a.full_join()
-        # q = q.where(db.patent.c.gyear==1975)
-        # q = q.limit(10)
-        q = q.where(db.patent.c.id==3883377)
-        out = q.execute().fetchall()[0]
-        print out.keys()
-        print out
-        assert out[1] ==  3357959
-        assert out['patent_cmade'] == 1
-        print out['p2_id'] == 3357959
+        q = q.where(db.patent.c.gyear==1975)
+        q = q.alias()
+        q = q.count()
+        print q
+        q.execute()
+        # out = q.execute().fetchall()[0]
+        assert out == 0, out
 
+