Commits

Anonymous committed dcd6545

use custom queryset

Comments (0)

Files changed (2)

flaskext/mongoengine.py

 import mongoengine
 
 from mongoengine.queryset import MultipleObjectsReturned, DoesNotExist
+from mongoengine.queryset import QuerySet as BaseQuerySet
 from mongoengine import ValidationError
 
 from flask import abort
 
         _include_mongoengine(self)
 
-        self.Model = Model
-
+        self.QuerySet = QuerySet
+        self.BaseQuerySet = BaseQuerySet
+        
         if app is not None:
             self.init_app(app)
 
         self.connection = mongoengine.connect(
             db=db, username=username, password=password)
 
-    # once v0.4 lands these will be deprecated and moved 
-    # into a QuerySet subclass, as the next version allows
-    # for custom QuerySets.
 
-    def get_or_404(self, qs, *args, **kwargs):
+class QuerySet(BaseQuerySet):
+
+    def get_or_404(self, *args, **kwargs):
         try:
-            return qs.get(*args, **kwargs)
+            return self.get(*args, **kwargs)
         except (MultipleObjectsReturned, DoesNotExist, ValidationError):
             abort(404)
 
-    def first_or_404(self, qs):
+    def first_or_404(self):
 
-        obj = qs.first()
+        obj = self.first()
         if obj is None:
             abort(404)
 
         return obj
 
-    def paginate(self, qs, page, per_page, error_out=True):
+    def paginate(self, page, per_page, error_out=True):
         
         if error_out and page < 1:
             abort(404)
         
         offset = (page - 1) * per_page
-        items = qs[offset:per_page]
+        items = self[offset:per_page]
 
         if not items and page != 1 and error_out:
             abort(404)
 
-        return Pagination(self, qs, page, per_page, qs.count(), items)
+        return Pagination(self, page, per_page, qs.count(), items)
 
 
 class Pagination(object):
                 last = num
 
 
-class Model(mongoengine.Document):
 
-    pass
 

test_mongoengine.py

 
 
 def make_todo_model(db):
-    class Todo(db.Model):
+    class Todo(db.Document):
         title = db.StringField(max_length=60)
         text = db.StringField()
         done = db.BooleanField(default=False)
         pub_date = db.DateTimeField(default=datetime.utcnow)
+        meta = dict(queryset_class=db.QuerySet)
     return Todo
 
 
 
         @app.route('/show/<id>/')
         def show(id):
-            todo = db.get_or_404(self.Todo.objects, id=id)
+            todo = self.Todo.objects.get_or_404(id=id)
 
         self.app = app
         self.db = db
         c.post('/add', data=dict(title='First Item', text='The text'))
         c.post('/add', data=dict(title='2nd Item', text='The text'))
         rv = c.get('/')
-        assert rv.data == 'First Item\n2nd Item'
+        print rv.data
+        assert rv.data == '2nd Item\nFirst Item'
 
     def test_request_context(self):
         with self.app.test_request_context():