Anonymous avatar Anonymous committed 2a7825e

got sets working

Comments (0)

Files changed (5)

                   sql.Column('record_id', sql.Unicode, primary_key=True),
                   sql.Column('modified', sql.DateTime, index=True),
                   sql.Column('deleted', sql.Boolean),
-                  sql.Column('data', sql.String))
+                  sql.Column('metadata', sql.String))
         
         sql.Table('sets', db,
                   sql.Column('set_id', sql.Unicode, primary_key=True),
         self._cache = {'records': {}, 'sets': {}, 'setrefs': {}}
         
             
-    def update_record(self, oai_id, modified, deleted, sets, data):
+    def update_record(self, oai_id, modified, deleted, sets, metadata):
         # adds a record, call flush to actually store in db
 
         check_type(oai_id,
                    suffix='for parameter "deleted"')
         check_type(sets,
                    dict,
-                   unicode_keys=True,
                    unicode_values=True,
                    recursive=True,
                    prefix="record %s" % oai_id,
                    suffix='for parameter "sets"')
-        check_type(data,
+        check_type(metadata,
                    dict,
                    prefix="record %s" % oai_id,
-                   suffix='for parameter "dict"')
+                   suffix='for parameter "metadata"')
         
-        data = json.dumps(data)
+        metadata = json.dumps(metadata)
         self._cache['records'][oai_id] = (dict(modified=modified,
                                                deleted=deleted,
-                                               data=data))
+                                               metadata=metadata))
         self._cache['setrefs'][oai_id] = []
         for set_id in sets:
             self._cache['sets'][set_id] = dict(
         record = {'id': row.record_id,
                   'deleted': row.deleted,
                   'modified': row.modified,
-                  'data': json.loads(row.data),
+                  'metadata': json.loads(row.metadata),
                   'sets': self.get_setrefs(oai_id)}
         return record
 
                   offset=0,
                   batch_size=20,
                   sets=[],
-                  not_sets=[],
-                  filter_sets=[],
+                  disallowed_sets=[],
+                  allowed_sets=[],
                   from_date=None,
                   until_date=None,
                   identifier=None):
 
         setclauses = []
         for set_id in sets:
+            alias = self._setrefs.alias()
             setclauses.append(
                 sql.and_(
-                self._setrefs.c.set_id == set_id,
-                self._setrefs.c.record_id == self._records.c.record_id))
+                alias.c.set_id == set_id,
+                alias.c.record_id == self._records.c.record_id))
             
         if setclauses:
-            query.append_whereclause(sql.or_(*setclauses))
+            query.append_whereclause((sql.and_(*setclauses)))
             
-        # extra filter sets
-        
-        filter_setclauses = []
-        for set_id in filter_sets:
-            filter_setclauses.append(
+        allowed_setclauses = []
+        for set_id in allowed_sets:
+            alias = self._setrefs.alias()
+            allowed_setclauses.append(
                 sql.and_(
-                self._setrefs.c.set_id == set_id,
-                self._setrefs.c.record_id == self._records.c.record_id))
+                alias.c.set_id == set_id,
+                alias.c.record_id == self._records.c.record_id))
             
-        if filter_setclauses:
-            query.append_whereclause(sql.or_(*filter_setclauses))
+        if allowed_setclauses:
+            query.append_whereclause(sql.or_(*allowed_setclauses))
 
-        # filter not_sets
+        disallowed_setclauses = []
+        for set_id in disallowed_sets:
+            alias = self._setrefs.alias()
+            disallowed_setclauses.append(
+                sql.exists([self._records.c.record_id],
+                           sql.and_(
+                alias.c.set_id == set_id,
+                alias.c.record_id == self._records.c.record_id)))
+            
+        if disallowed_setclauses:
+            query.append_whereclause(sql.not_(sql.or_(*disallowed_setclauses)))
+            
+        for row in query.distinct().offset(offset).limit(batch_size).execute():
+            yield {'id': row.record_id,
+                   'deleted': row.deleted,
+                   'modified': row.modified,
+                   'metadata': json.loads(row.metadata),
+                   'sets': self.get_setrefs(row.record_id)
+                   }
 
-        not_setclauses = []
-        for set_id in not_sets:
-            not_setclauses.append(
-                sql.and_(
-                self._setrefs.c.set_id == set_id,
-                self._setrefs.c.record_id == self._records.c.record_id))
-            
-        if not_setclauses:
-            query.append_whereclause(sql.not_(sql.or_(*not_setclauses)))
-
-        for row in query.distinct().offset(offset).limit(batch_size).execute():
-            record = {'id': row.record_id,
-                      'deleted': row.deleted,
-                      'modified': row.modified,
-                      'data': json.loads(row.data),
-                      'sets': self.get_setrefs(row.record_id)
-                      }
-            yield {'record': record,
-                   'sets': record['sets'],
-                   'metadata': record['data'],
-                   'assets':{}}
-       
-    def empty_database(self):
-        self._records.delete().execute()
-        self._sets.delete().execute()
-        self._setrefs.delete().execute()
-
                                 'firstname': [first],
                                 'role': [u'aut']})
 
-        self.data = {'identifier': [u'http://example.org/data/%s' % id],
-                     'title': [xpath.string('//x:title')],
-                     'subject': xpath.strings('//x:subject'),
-                     'description': [xpath.string('//x:abstract')],
-                     'creator': [d['name'][0] for d in author_data],
-                     'author_data': author_data,
-                     'language': [u'en'],
-                     'date': [xpath.string('//x:issued')]}
-
+        self.metadata = {'identifier': [u'http://example.org/data/%s' % id],
+                         'title': [xpath.string('//x:title')],
+                         'subject': xpath.strings('//x:subject'),
+                         'description': [xpath.string('//x:abstract')],
+                         'creator': [d['name'][0] for d in author_data],
+                         'author_data': author_data,
+                         'language': [u'en'],
+                         'date': [xpath.string('//x:issued')]}
+        
         self.sets = {u'example': {u'name':u'example',
                                   u'description':u'An Example Set'}}
 
         if access == 'public':
             self.sets[u'public'] = {u'name':u'public',
                                     u'description':u'Public access'}
-            self.data['rights'] = [u'open access']
+            self.metadata['rights'] = [u'open access']
         elif access == 'private':
             self.sets[u'private'] = {u'name':u'private',
                                      u'description':u'Private access'}
-            self.data['rights'] = [u'restricted access']
+            self.metadata['rights'] = [u'restricted access']
 
             raise oaipmh.error.CannotDisseminateFormatError
 
     def _createHeader(self, record):
-        oai_id = record['record']['id']
-        datestamp = record['record']['modified']
-        sets = record['sets']
-        deleted = record['record']['deleted']
+        deleted = record['deleted']
         for deleted_set in self.config.sets_deleted:
             if deleted_set in record['sets']:
                 deleted = True
                 break
-        return oaipmh.common.Header(oai_id, datestamp, sets, deleted)
+        return oaipmh.common.Header(record['id'],
+                                    record['modified'],
+                                    record['sets'],
+                                    deleted)
 
     def _createHeaderAndMetadata(self, record):
         header = self._createHeader(record)
-        metadata = oaipmh.common.Metadata(record['metadata'])
+        metadata = oaipmh.common.Metadata(record)
         metadata.record = record
         return header, metadata
     
+# coding=utf8
+from unittest import TestCase, TestSuite, makeSuite
+import doctest
+import datetime
+
+from lxml import etree
+
+from moai.utils import XPath
+from moai.database import Database
+
+FLAGS = doctest.NORMALIZE_WHITESPACE + doctest.ELLIPSIS
+GLOBS = {}
+
+class XPathUtilTest(TestCase):
+
+    def test_string(self):
+        doc = etree.fromstring(
+            '''<doc>
+                 <string>test</string>
+                 <string/>
+                 <string>   test     </string>
+                 <string>test<foo/>more test</string>
+                 <string>tëst</string>
+               </doc>''')
+        xpath = XPath(doc)
+        self.assertEquals(xpath.strings('/doc/string'),
+                          [u'test', u'test', u'test', u'tëst'])
+    def test_boolean(self):
+        doc = etree.fromstring(
+            '''<doc>
+                 <bool>yes</bool>
+                 <bool>true</bool>
+                 <bool>False</bool>
+                 <bool>NO</bool>
+               </doc>''')
+        xpath = XPath(doc)
+        self.assertEquals(xpath.booleans('/doc/bool'),
+                          [True, True, False, False])
+
+    def test_number(self):
+        doc = etree.fromstring(
+            '''<doc>
+                 <number>1</number>
+                 <number>3.33333333333</number>
+                 <number>-75</number>
+                 <number>-0.75</number>
+               </doc>''')
+        xpath = XPath(doc)
+        numbers = xpath.numbers('/doc/number')
+        self.assertEquals(numbers,
+                          [1, 3.33333333333, -75, -0.75])
+        self.assertEquals([type(i) for i in numbers],
+                          [int, float, int, float])
+
+        
+    def test_date(self):
+        doc = etree.fromstring(
+            '''<doc>
+                 <date>2010-01-04</date>
+                 <date>2010/01/04</date>
+                 <date>2010-01-04T12:43:33Z</date>
+                 <date>2010-01-04T12:43:33</date>
+               </doc>''')
+        xpath = XPath(doc)
+        self.assertEquals(xpath.dates('/doc/date'),
+                          [datetime.datetime(2010, 1, 4, 0, 0),
+                           datetime.datetime(2010, 1, 4, 0, 0),
+                           datetime.datetime(2010, 1, 4, 12, 43, 33),
+                           datetime.datetime(2010, 1, 4, 12, 43, 33)])
+
+    def test_tags(self):
+        doc = etree.fromstring('<doc><a/><b/><s:c xmlns:s="urn:spam"/></doc>')
+        xpath = XPath(doc)
+        self.assertEquals(xpath.tags('/doc/*'), [u'a', u'b', u'c'])
+
+    def test_namespaces(self):
+        doc = etree.fromstring(
+            '<doc xmlns="urn:spam"><string>Spam!</string></doc>')
+        xpath = XPath(doc, nsmap={'spam': 'urn:spam'})
+        self.assertEquals(xpath.string('//spam:string'), u'Spam!')
+
+class DatabaseTest(TestCase):
+    def setUp(self):
+        self.db = Database()
+        
+    def tearDown(self):
+        del self.db
+        
+    def test_update(self):
+        # db is empty
+        self.assertEquals(self.db.record_count(), 0)
+        # let's add a record
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False,
+                              {},
+                              {u'title': u'Spam!'})
+        self.db.flush()
+        self.assertEquals(self.db.record_count(), 1)
+        # check if all values are there
+        record = self.db.get_record(u'oai:spam')
+        self.assertEquals(record['id'], u'oai:spam')
+        self.assertEquals(record['deleted'], False)
+        self.assertEquals(record['modified'],
+                          datetime.datetime(2010, 10, 13, 12, 30, 00))
+        self.assertEquals(record['sets'], [])
+        self.assertEquals(record['metadata'], {u'title': u'Spam!'})
+        # change a metadata value
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 10, 13, 12, 30, 01),
+                              False,
+                              {},
+                              {u'title': u'Ham!'})
+        self.db.flush()
+        self.assertEquals(self.db.record_count(), 1)
+        # check if metadata was changed
+        record = self.db.get_record(u'oai:spam')
+        self.assertEquals(record['metadata'], {u'title': u'Ham!'})
+        # remove the record
+        self.db.remove_record(u'oai:spam')
+        self.assertEquals(self.db.record_count(), 0)
+        
+    def test_setrefs(self):
+        # add a record that references a set
+        self.assertEquals(self.db.set_count(), 0)
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False,
+                              {u'spamset': {u'name': u'Spam Set',
+                                            u'description': u'spam spam spam',
+                                            u'hidden': False}},
+                              {u'title': u'Spam!'})
+        self.db.flush()
+        self.assertEquals(self.db.record_count(), 1)
+        self.assertEquals(self.db.set_count(), 1)
+        # check if all values are there
+        record = self.db.get_record(u'oai:spam')
+        self.assertEquals(record['sets'], [u'spamset'])
+        set = self.db.get_set(u'spamset')
+        self.assertEquals(set['id'], u'spamset')
+        self.assertEquals(set['name'], u'Spam Set')
+        self.assertEquals(set['description'], u'spam spam spam')
+        self.assertEquals(set['hidden'], False)
+        # now, we'll change the record to use the hamset
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False,
+                              {u'hamset': {u'name': u'Ham Set',
+                                            u'description': u'ham ham ham',
+                                            u'hidden': False}},
+                              {u'title': u'Ham!'})
+        self.db.flush()
+        self.assertEquals(self.db.record_count(), 1)
+        # note that we now have 2 sets, the spam set is not removed
+        self.assertEquals(self.db.set_count(), 2)
+        # however, the spam record only has one reference 
+        record = self.db.get_record(u'oai:spam')
+        self.assertEquals(record['sets'], [u'hamset'])
+        # if the set is removed then all references to that set are
+        # also removed
+        self.db.remove_set(u'hamset')
+        record = self.db.get_record(u'oai:spam')
+        self.assertEquals(record['sets'], [])
+        self.assertEquals(self.db.set_count(), 1)
+
+    def test_hidden_sets(self):
+        # hidden sets are not added to the record setrefs list,
+        # they are there though, for filtering purposes
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False,
+                              {u'spamset': {u'name': u'Spam Set',
+                                            u'description': u'spam spam spam',
+                                            u'hidden': False},
+                               u'hamset': {u'name': u'Ham Set',
+                                           u'description': u'ham ham ham',
+                                           u'hidden': True}},
+                              {u'title': u'Spam!'})
+        self.db.flush()
+        self.assertEquals(self.db.get_setrefs(u'oai:spam'), [u'spamset'])
+        self.assertEquals(self.db.get_setrefs(u'oai:spam',
+                                              include_hidden_sets=True),
+                          [u'hamset', u'spamset'])
+        # hidden sets are also never shown in the oai sets listing
+        self.assertEquals(list(self.db.oai_sets()),
+                          [{'description': u'spam spam spam',
+                            'id': u'spamset',
+                            'name': u'Spam Set'}] )
+
+    def test_earliest_datestamp(self):
+        self.assertEquals(self.db.oai_earliest_datestamp(),
+                          datetime.datetime(1970, 1, 1, 0, 0))
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2009, 10, 13, 12, 30, 00),
+                              False, {}, {})
+        self.db.update_record(u'oai:ham',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False, {}, {})
+        self.db.flush()
+        self.assertEquals(self.db.oai_earliest_datestamp(),
+                          datetime.datetime(2009, 10, 13, 12, 30))
+
+    def test_oai_query_dates(self):
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 01, 01, 00, 00, 00),
+                              False, {u'spamset':{u'name':u'spam'}},
+                              {})
+        self.db.update_record(u'oai:ham',
+                              datetime.datetime(2009, 01, 01, 00, 00, 00),
+                              False, {u'hamset':{u'name':u'ham'}},
+                              {})
+        self.db.flush()
+        self.assertEquals(
+            list(self.db.oai_query()),
+            [{'deleted': False,
+              'sets': [u'spamset'],
+              'metadata': {},
+              'id': u'oai:spam',
+              'modified': datetime.datetime(2010, 1, 1, 0, 0)},
+             {'deleted': False,
+              'sets': [u'hamset'],
+              'metadata': {},
+              'id': u'oai:ham',
+              'modified': datetime.datetime(2009, 1, 1, 0, 0)}])
+        # date slices
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            from_date=datetime.datetime(2009, 6, 1, 0, 0))],
+            [u'oai:spam'])
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            until_date=datetime.datetime(2009, 6, 1, 0, 0))],
+            [u'oai:ham'])
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            from_date=datetime.datetime(2008, 6, 1, 0, 0),
+            until_date=datetime.datetime(2010, 6, 1, 0, 0))],
+            [u'oai:spam', u'oai:ham'])
+        # no matches
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            from_date=datetime.datetime(2011, 1, 1, 0, 0))],
+            [])
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            until_date=datetime.datetime(2008, 1, 1, 0, 0))],
+            [])
+        # test inclusiveness
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(
+            from_date=datetime.datetime(2009, 1, 1, 0, 0),
+            )],
+            [u'oai:spam', u'oai:ham'])
+
+    def test_oai_query_identifier(self):
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2010, 01, 01, 00, 00, 00),
+                              False, {u'spamset':{u'name':u'spam'}},
+                              {})
+        self.db.flush()
+        self.assertEquals(
+            [r['id'] for r in self.db.oai_query(identifier=u'oai:spam')],
+            [u'oai:spam'])
+        
+    def test_oai_query_future_dates(self):
+        # records with a timestamp in the future should never
+        # be returned, this feature can be used to create embargo dates
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2020, 01, 01, 00, 00, 00),
+                              False, {u'spamset':{u'name':u'spam'}},
+                              {})
+        self.db.flush()
+        self.assertEquals(list(self.db.oai_query()), [])
+        self.assertEquals(list(self.db.oai_query(
+            until_date=datetime.datetime(2030, 01, 01, 00, 00, 00))), [])
+        self.assertEquals(list(self.db.oai_query(identifier=u'oai:spam')), [])
+        
+    def test_oai_sets(self):
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2009, 10, 13, 12, 30, 00),
+                              False, {u'spam': dict(name=u'spamset'),
+                                      u'test': dict(name=u'testset')}, {})
+        self.db.update_record(u'oai:spamspamspam',
+                              datetime.datetime(2009, 06, 13, 12, 30, 00),
+                              False, {u'spam': dict(name=u'spamset')}, {})
+        self.db.update_record(u'oai:ham',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False, {u'ham': dict(name=u'hamset'),
+                                      u'test': dict(name=u'testset')}, {})
+        self.db.flush()
+        # all records
+        self.assertEquals([r['id'] for r in self.db.oai_query()],
+                          [u'oai:ham', u'oai:spam', u'oai:spamspamspam'])
+        # only set ham
+        self.assertEquals([r['id'] for r in self.db.oai_query(sets=[u'ham'])],
+                          [u'oai:ham'])
+        # only set spam
+        self.assertEquals([r['id'] for r in self.db.oai_query(sets=[u'spam'])],
+                          [u'oai:spam', u'oai:spamspamspam'])
+        # records in spam set and test set
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            sets=[u'test', u'spam'])], [u'oai:spam'])
+        # only allow records from certain sets
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            allowed_sets=[u'test'])], [u'oai:ham', u'oai:spam'])
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            allowed_sets=[u'spam', u'ham'])],
+                          [u'oai:ham', u'oai:spam', u'oai:spamspamspam'])
+        # only allow records from certain sets, combined with set
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            allowed_sets=[u'test'], sets=['spam'])],
+                          [u'oai:spam'])
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            allowed_sets=[u'spam'], sets=['test'])],
+                          [u'oai:spam'])
+        # certain records should always be disallowed
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            disallowed_sets=[u'spam'])],
+                           [u'oai:ham'])
+        # disallowed sets has precedence over allowed sets
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            disallowed_sets=[u'test'], allowed_sets=[u'spam'])],
+                           [u'oai:spamspamspam'])
+    def test_oai_batching(self):
+        self.db.update_record(u'oai:spam',
+                              datetime.datetime(2009, 10, 13, 12, 30, 00),
+                              False, {u'spam': dict(name=u'spamset'),
+                                      u'test': dict(name=u'testset')}, {})
+        self.db.update_record(u'oai:spamspamspam',
+                              datetime.datetime(2009, 06, 13, 12, 30, 00),
+                              False, {u'spam': dict(name=u'spamset')}, {})
+        self.db.update_record(u'oai:ham',
+                              datetime.datetime(2010, 10, 13, 12, 30, 00),
+                              False, {u'ham': dict(name=u'hamset'),
+                                      u'test': dict(name=u'testset')}, {})
+        self.db.flush()
+        self.assertEquals(len(list(self.db.oai_query())), 3)
+        self.assertEquals(len(list(self.db.oai_query(batch_size=1))), 1)
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            batch_size=1)], [u'oai:ham'])
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            batch_size=1, offset=1)], [u'oai:spam'])
+        self.assertEquals([r['id'] for r in self.db.oai_query(
+            batch_size=1, offset=2)], [u'oai:spamspamspam'])
+        
+
+        
+def suite():
+    test_suite = TestSuite()
+    test_suite.addTest(makeSuite(XPathUtilTest))
+    test_suite.addTest(makeSuite(DatabaseTest))
+    return test_suite
                                    content.modified,
                                    content.deleted,
                                    content.sets,
-                                   content.data)
+                                   content.metadata)
         except Exception, err:
             if options.debug:
                 raise
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.