1. Su Lab
  2. BioThings
  3. mygene.info

Commits

Cyrus Afrasiabi  committed d0b634c

First commit to add fetch_all to mygene.info

Comments (0)

Files changed (3)

File src/api_v2/handlers.py Modified

View file
  • Ignore whitespace
  • Hide word diff
             size
             sort
             species
+            fetch_all
 
             explain
         '''
         kwargs = self.get_query_params()
         q = kwargs.pop('q', None)
+        scroll_id = kwargs.pop('scroll_id', None)
         _has_error = False
-        if q:
+        if scroll_id:
+            res = self.esq.scroll(scroll_id, fields=None, **kwargs)
+        elif q:
             explain = self.get_argument('explain', None)
             if explain and explain.lower() == 'true':
                 kwargs['explain'] = True
                         _has_error = True
             if not _has_error:
                 res = self.esq.query(q, **kwargs)
+                if kwargs.get('fetch_all', False):
+                    self.ga_track(event={'category': 'v2_api',
+                                         'action': 'fetch_all',
+                                         'label': 'total',
+                                         'value': res.get('total', None)})
         else:
             res = {'success': False, 'error': "Missing required parameters."}
 

File src/tests.py Modified

View file
  • Ignore whitespace
  • Hide word diff
 def test_static():
     get_ok(host + '/favicon.ico')
     get_ok(host + '/robots.txt')
+
+
+def test_fetch_all():
+    res = json_ok(get_ok(api + '/query?q=cdk2&fetch_all=true'))
+    assert '_scroll_id' in res
+
+    res2 = json_ok(get_ok(api + '/query?scroll_id=' + res['_scroll_id']))
+    assert 'hits' in res2
+    ok_(len(res2['hits']) >= 2)

File src/utils/es.py Modified

View file
  • Ignore whitespace
  • Hide word diff
     pass
 
 
+class MGScrollSetupError(Exception):
+    pass
+
+
 class ESQuery:
     def __init__(self):
         # self.conn0 = es0
         # self._doc_type = 'gene'
         self._index = ES_INDEX_NAME_ALL
         self._doc_type = ES_INDEX_TYPE
+        
+        # Scroll setup
+        self._scroll_time = '1m'
+        self._total_scroll_size = 1000
+        if self._total_scroll_size % self.get_number_of_shards() == 0:
+            self._scroll_size = int(self._total_scroll_size / self.get_number_of_shards())
+        else:
+            raise MGScrollSetupError("_total_scroll_size of {} can't be ".format(self._total_scroll_size) +
+                                     "divided evenly among {} shards.".format(self.get_number_of_shards()))
 
         # self._doc_type = 'gene_sample'
         self._default_fields = ['name', 'symbol', 'taxid', 'entrezgene']
         self._default_species = [9606, 10090, 10116]               # human, mouse, rat
         self._tier_1_species = set(taxid_d.values())
 
-    def _search(self, q, species='all'):
+    def _search(self, q, species='all', scroll_options={}):
         self._set_index(species)
         # body = '{"query" : {"term" : { "_all" : ' + q + ' }}}'
         res = self.conn.search(index=self._index, doc_type=self._doc_type,
-                               body=q)
+                               body=q, **scroll_options)
         self._index = ES_INDEX_NAME_ALL     # reset self._index
         return res
 
         else:
             return [self._get_genedoc(hit, dotfield=dotfield) for hit in hits['hits']]
 
+    def _clean_res2(self, res):
+        ''' res is the dictionary returned from a query.
+            do some reformatting of raw ES results before returning.
+
+            This method is used by the self.query and self.scroll methods.
+        '''
+        _res = res['hits']
+        for attr in ['took', 'facets', 'aggregations', '_scroll_id']:
+            if attr in res:
+                _res[attr] = res[attr]
+        _res['hits'] = [self._get_genedoc(hit) for hit in _res['hits']]
+        return _res
+
+
     def _cleaned_res_2(self, res, empty=[], error={'error': True},
                        single_hit=False, dotfield=True, fields=None):
         if 'error' in res:
         options.rawquery = kwargs.pop('rawquery', False)
         # if dotfield is false, returned fields containing dot notation will be restored as objects.
         options.dotfield = kwargs.pop('dotfield', True) not in [False, 'false']
+        options.fetch_all = kwargs.pop('fetch_all', False)
         scopes = kwargs.pop('scopes', None)
         if scopes:
             options.scopes = self._cleaned_scopes(scopes)
         options.kwargs = kwargs
         return options
 
+    def get_number_of_shards(self):
+        r = self.conn.indices.get_settings(self._index)
+        n_shards = r[list(r.keys())[0]]['settings']['index']['number_of_shards']
+        n_shards = int(n_shards)
+        return n_shards
+
     def get_gene(self, geneid, fields='all', **kwargs):
         kwargs['fields'] = self._cleaned_fields(fields)
         raw = kwargs.pop('raw', False)
         q = re.sub(u'[\t\n\x0b\x0c\r\x00]+', ' ', q)
         q = q.strip()
         _q = None
+        scroll_options = {}
+        if options.fetch_all:
+            scroll_options.update({'search_type': 'scan', 'size': self._scroll_size, 'scroll': self._scroll_time})
         # Check if special interval query pattern exists
         interval_query = self._parse_interval_query(q)
         try:
                 return _q
 
             try:
-                res = self._search(_q, species=kwargs['species'])
+                res = self._search(_q, species=kwargs['species'], scroll_options=scroll_options)
             except Exception as e:
                 if PY3:
                     msg = str(e)
                 return {'success': False, 'error': msg}
 
             if not options.raw:
-                _res = res['hits']
-                _res['took'] = res['took']
-                if "facets" in res:
-                    _res['facets'] = res['facets']
-                for v in _res['hits']:
-                    del v['_type']
-                    del v['_index']
-                    for attr in ['fields', '_source']:
-                        if attr in v:
-                            v.update(v[attr])
-                            del v[attr]
-                            break
-                    if not options.dotfield:
-                        parse_dot_fields(v)
-                res = _res
+                res = self._clean_res2(res)
+                #_res = res['hits']
+                #_res['took'] = res['took']
+                #if "facets" in res:
+                #    _res['facets'] = res['facets']
+                #for v in _res['hits']:
+                #    del v['_type']
+                #    del v['_index']
+                #    for attr in ['fields', '_source']:
+                #        if attr in v:
+                #            v.update(v[attr])
+                #            del v[attr]
+                #            break
+                #    if not options.dotfield:
+                #        parse_dot_fields(v)
+                #res = _res
         else:
             res = {'success': False,
                    'error': "Invalid query. Please check parameters."}
 
         return res
 
+    def scroll(self, scroll_id, fields=None, **kwargs):
+        options = self._get_cleaned_query_options(fields, kwargs)
+        r = self.conn.scroll(scroll_id, scroll=self._scroll_time)
+        scroll_id = r.get('_scroll_id')
+        if scroll_id is None or not r['hits']['hits']:
+            return {'success': False, 'error': 'No results to return.'}
+        else:
+            if not options.raw:
+                res = self._clean_res2(r)
+        #res.update({'_scroll_id': scroll_id})
+        if r['_shards']['failed']:
+            res.update({'_warning': 'Scroll request has failed on {} shards out of {}.'.format(r['_shards']['failed'], r['_shards']['total'])})
+        return res
+
     def query_interval(self, taxid, chr, gstart, gend, **kwargs):
         '''deprecated! Use query method with interval query string.'''
         kwargs.setdefault('fields', ['symbol', 'name', 'taxid'])