Commits

Anonymous committed 9f066d6

Added AdvFormValidationError which is serializable with a status code of 4000

  • Participants
  • Parent commits 2cc125f

Comments (0)

Files changed (3)

File piston/emitters.py

         self.handler = handler
         self.fields = fields
         self.anonymous = anonymous
-        
+
         if isinstance(self.data, Exception):
             raise
-    
+
     def method_fields(self, data, fields):
         if not data:
             return { }
 
         has = dir(data)
         ret = dict()
-            
+
         for field in fields:
             if field in has and callable(field):
                 ret[field] = getattr(data, field)
-        
+
         return ret
-    
+
     def construct(self):
         """
         Recursively serialize a lot of types, and
             Dispatch, all types are routed through here.
             """
             ret = None
-            
+
             if isinstance(thing, QuerySet):
                 ret = _qs(thing, fields=fields)
             elif isinstance(thing, (tuple, list)):
             Foreign keys.
             """
             return _any(getattr(data, field.name))
-        
+
         def _related(data, fields=()):
             """
             Foreign keys.
             """
             return [ _model(m, fields) for m in data.iterator() ]
-        
+
         def _m2m(data, field, fields=()):
             """
             Many to many (re-route to `_model`.)
             """
             return [ _model(m, fields) for m in getattr(data, field.name).iterator() ]
-        
+
         def _model(data, fields=()):
             """
             Models. Will respect the `fields` and/or
             ret = { }
             handler = self.in_typemapper(type(data), self.anonymous)
             get_absolute_uri = False
-            
+
             if handler or fields:
                 v = lambda f: getattr(data, f.attname)
 
 
                     if 'absolute_uri' in get_fields:
                         get_absolute_uri = True
-                
+
                     if not get_fields:
                         get_fields = set([ f.attname.replace("_id", "", 1)
                             for f in data._meta.fields ])
-                
+
                     # sets can be negated.
                     for exclude in exclude_fields:
                         if isinstance(exclude, basestring):
                             get_fields.discard(exclude)
-                            
+
                         elif isinstance(exclude, re._pattern_type):
                             for field in get_fields.copy():
                                 if exclude.match(field):
                                     get_fields.discard(field)
-                                    
+
                 else:
                     get_fields = set(fields)
 
                 met_fields = self.method_fields(handler, get_fields)
-                
+
                 for f in data._meta.local_fields:
                     if f.serialize and not any([ p in met_fields for p in [ f.attname, f.name ]]):
                         if not f.rel:
                             if f.attname[:-3] in get_fields:
                                 ret[f.name] = _fk(data, f)
                                 get_fields.remove(f.name)
-                
+
                 for mf in data._meta.many_to_many:
                     if mf.serialize and mf.attname not in met_fields:
                         if mf.attname in get_fields:
                             ret[mf.name] = _m2m(data, mf)
                             get_fields.remove(mf.name)
-                
+
                 # try to get the remainder of fields
                 for maybe_field in get_fields:
                     if isinstance(maybe_field, (list, tuple)):
                         # using different names.
                         ret[maybe_field] = _any(met_fields[maybe_field](data))
 
-                    else:                    
+                    else:
                         maybe = getattr(data, maybe_field, None)
                         if maybe:
                             if callable(maybe):
             else:
                 for f in data._meta.fields:
                     ret[f.attname] = _any(getattr(data, f.attname))
-                
+
                 fields = dir(data.__class__) + ret.keys()
                 add_ons = [k for k in dir(data) if k not in fields]
-                
+
                 for k in add_ons:
                     ret[k] = _any(getattr(data, k))
-            
+
             # resouce uri
             if self.in_typemapper(type(data), self.anonymous):
                 handler = self.in_typemapper(type(data), self.anonymous)
                 if hasattr(handler, 'resource_uri'):
                     url_id, fields = handler.resource_uri()
-                    ret['resource_uri'] = permalink( lambda: (url_id, 
-                        (getattr(data, f) for f in fields) ) )()
-            
+                    ret['resource_uri'] = permalink(lambda: (url_id,
+                        (getattr(data, f) for f in fields)))()
+
             if hasattr(data, 'get_api_url') and 'resource_uri' not in ret:
                 try: ret['resource_uri'] = data.get_api_url()
                 except: pass
-            
+
             # absolute uri
             if hasattr(data, 'get_absolute_url') and get_absolute_uri:
                 try: ret['absolute_uri'] = data.get_absolute_url()
                 except: pass
-            
+
             return ret
-        
+
         def _qs(data, fields=()):
             """
             Querysets.
             """
             return [ _any(v, fields) for v in data ]
-                
+
         def _list(data):
             """
             Lists.
             """
             return [ _any(v) for v in data ]
-            
+
         def _dict(data):
             """
             Dictionaries.
             """
             return dict([ (k, _any(v)) for k, v in data.iteritems() ])
-            
+
         # Kickstart the seralizin'.
         return _any(self.data, self.fields)
-    
+
     def in_typemapper(self, model, anonymous):
         for klass, (km, is_anon) in self.typemapper.iteritems():
             if model is km and is_anon is anonymous:
                 return klass
-        
+
     def render(self):
         """
         This super emitter does not implement `render`,
         this is a job for the specific emitter below.
         """
         raise NotImplementedError("Please implement render.")
-        
+
     def stream_render(self, request, stream=True):
         """
         Tells our patched middleware not to look
         more memory friendly for large datasets.
         """
         yield self.render(request)
-        
+
     @classmethod
     def get(cls, format):
         """
             return cls.EMITTERS.get(format)
 
         raise ValueError("No emitters found for type %s" % format)
-    
+
     @classmethod
     def register(cls, name, klass, content_type='text/plain'):
         """
          - `content_type`: The content type to serve response as.
         """
         cls.EMITTERS[name] = (klass, content_type)
-        
+
     @classmethod
     def unregister(cls, name):
         """
         want to provide output in one of the built-in emitters.
         """
         return cls.EMITTERS.pop(name, None)
-    
+
 class XMLEmitter(Emitter):
     def _to_xml(self, xml, data):
         if isinstance(data, (list, tuple)):
 
     def render(self, request):
         stream = StringIO.StringIO()
-        
+
         xml = SimplerXMLGenerator(stream, "utf-8")
         xml.startDocument()
         xml.startElement("response", {})
-        
+
         self._to_xml(xml, self.construct())
-        
+
         xml.endElement("response")
         xml.endDocument()
-        
+
         return stream.getvalue()
 
 Emitter.register('xml', XMLEmitter, 'text/xml; charset=utf-8')
-Mimer.register(lambda *a: None, ('text/xml',))
+Mimer.register(lambda * a: None, ('text/xml',))
 
 class JSONEmitter(Emitter):
     """
             return '%s(%s)' % (cb, seria)
 
         return seria
-    
+
 Emitter.register('json', JSONEmitter, 'application/json; charset=utf-8')
 Mimer.register(simplejson.loads, ('application/json',))
-    
+
 class YAMLEmitter(Emitter):
     """
     YAML emitter, uses `safe_dump` to omit the
     """
     def render(self, request):
         return pickle.dumps(self.construct())
-        
+
 Emitter.register('pickle', PickleEmitter, 'application/python-pickle')
 Mimer.register(pickle.loads, ('application/python-pickle',))
 
             response = serializers.serialize(format, self.data, indent=True)
 
         return response
-        
+
 Emitter.register('django', DjangoEmitter, 'text/xml; charset=utf-8')

File piston/resource.py

 from handler import typemapper
 from doc import HandlerMethod
 from authentication import NoAuthentication
-from utils import coerce_put_post, FormValidationError, HttpStatusCode
+from utils import (coerce_put_post, FormValidationError, AdvFormValidationError,
+    HttpStatusCode)
 from utils import rc, format_error, translate_mime, MimerDataException
 
 class Resource(object):
     is an authentication handler. If not specified,
     `NoAuthentication` will be used by default.
     """
-    callmap = { 'GET': 'read', 'POST': 'create', 
+    callmap = { 'GET': 'read', 'POST': 'create',
                 'PUT': 'update', 'DELETE': 'delete' }
-    
+
     def __init__(self, handler, authentication=None):
         if not callable(handler):
             raise AttributeError, "Handler not callable."
-        
+
         self.handler = handler()
-        
+
         if not authentication:
             self.authentication = NoAuthentication()
         else:
             self.authentication = authentication
-            
+
         # Erroring
         self.email_errors = getattr(settings, 'PISTON_EMAIL_ERRORS', True)
         self.display_errors = getattr(settings, 'PISTON_DISPLAY_ERRORS', True)
         that as well.
         """
         em = kwargs.pop('emitter_format', None)
-        
+
         if not em:
             em = request.GET.get('format', 'json')
 
         return em
-    
+
     @vary_on_headers('Authorization')
     def __call__(self, request, *args, **kwargs):
         """
         else:
             handler = self.handler
             anonymous = handler.is_anonymous
-        
+
         # Translate nested datastructs into `request.data` here.
         if rm in ('POST', 'PUT'):
             try:
                 translate_mime(request)
             except MimerDataException:
                 return rc.BAD_REQUEST
-        
+
         if not rm in handler.allowed_methods:
             return HttpResponseNotAllowed(handler.allowed_methods)
-        
+
         meth = getattr(handler, self.callmap.get(rm), None)
-        
+
         if not meth:
             raise Http404
 
         em_format = self.determine_emitter(request, *args, **kwargs)
 
         kwargs.pop('emitter_format', None)
-        
+
         # Clean up the request object a bit, since we might
         # very well have `oauth_`-headers in there, and we
         # don't want to pass these along to the handler.
         request = self.cleanup_request(request)
-        
+
         try:
             result = meth(request, *args, **kwargs)
+
+        except AdvFormValidationError, e:
+            result = e
         except FormValidationError, e:
             # TODO: Use rc.BAD_REQUEST here
             return HttpResponse("Bad Request: %s" % e.form.errors, status=400)
             sig = hm.get_signature()
 
             msg = 'Method signature does not match.\n\n'
-            
+
             if sig:
                 msg += 'Signature should be: %s' % sig
             else:
                 msg += 'Resource does not expect any parameters.'
 
-            if self.display_errors:                
+            if self.display_errors:
                 msg += '\n\nException was: %s' % str(e)
-                
+
             result.content = format_error(msg)
         except HttpStatusCode, e:
             #result = e ## why is this being passed on and not just dealt with now?
                 isinstance(result, list) or isinstance(result, QuerySet)):
             fields = handler.list_fields
 
-        srl = emitter(result, typemapper, handler, fields, anonymous)
+        if isinstance(result, AdvFormValidationError):
+            errors = dict()
+            for key, value in result.form.errors.iteritems():
+                errors.update({key:[unicode(e) for e in value]})
+
+            if result.form.non_field_errors():
+                errors['non_field_errors'] = result.form.non_field_errors()
+            srl = emitter(errors, typemapper, handler, fields, anonymous)
+
+        else:
+            srl = emitter(result, typemapper, handler, fields, anonymous)
 
         try:
             """
 
             resp = HttpResponse(stream, mimetype=ct)
 
+            if isinstance(result, AdvFormValidationError):
+                resp.status_code = 400
+
             resp.streaming = self.stream
 
             return resp
 
             if True in [ k.startswith("oauth_") for k in block.keys() ]:
                 sanitized = block.copy()
-                
+
                 for k in sanitized.keys():
                     if k.startswith("oauth_"):
                         sanitized.pop(k)
-                        
+
                 setattr(request, method_type, sanitized)
 
         return request
-        
+
     # -- 
-    
+
     def email_exception(self, reporter):
         subject = "Piston crash report"
         html = reporter.get_traceback_html()
 
-        message = EmailMessage(settings.EMAIL_SUBJECT_PREFIX+subject,
+        message = EmailMessage(settings.EMAIL_SUBJECT_PREFIX + subject,
                                 html, settings.SERVER_EMAIL,
                                 [ admin[1] for admin in settings.ADMINS ])
-        
+
         message.content_subtype = 'html'
         message.send(fail_silently=True)

File piston/utils.py

     """
     Status codes.
     """
-    CODES = dict(ALL_OK = ('OK', 200),
-                 CREATED = ('Created', 201),
-                 DELETED = ('', 204), # 204 says "Don't send a body!"
-                 BAD_REQUEST = ('Bad Request', 400),
-                 FORBIDDEN = ('Forbidden', 401),
-                 NOT_FOUND = ('Not Found', 404),
-                 DUPLICATE_ENTRY = ('Conflict/Duplicate', 409),
-                 NOT_HERE = ('Gone', 410),
-                 NOT_IMPLEMENTED = ('Not Implemented', 501),
-                 THROTTLED = ('Throttled', 503))
+    CODES = dict(ALL_OK=('OK', 200),
+                 CREATED=('Created', 201),
+                 DELETED=('', 204), # 204 says "Don't send a body!"
+                 BAD_REQUEST=('Bad Request', 400),
+                 FORBIDDEN=('Forbidden', 401),
+                 NOT_FOUND=('Not Found', 404),
+                 DUPLICATE_ENTRY=('Conflict/Duplicate', 409),
+                 NOT_HERE=('Gone', 410),
+                 NOT_IMPLEMENTED=('Not Implemented', 501),
+                 THROTTLED=('Throttled', 503))
 
     def __getattr__(self, attr):
         """
             raise AttributeError(attr)
 
         return HttpResponse(r, content_type='text/plain', status=c)
-    
+
 rc = rc_factory()
-    
+
 class FormValidationError(Exception):
     def __init__(self, form):
         self.form = form
 
+class AdvFormValidationError(FormValidationError):
+    pass
+
 class HttpStatusCode(Exception):
     def __init__(self, response):
         self.response = response
     @decorator
     def wrap(f, self, request, *a, **kwa):
         form = v_form(getattr(request, operation))
-    
+
         if form.is_valid():
             return f(self, request, *a, **kwa)
         else:
             raise FormValidationError(form)
     return wrap
 
-def throttle(max_requests, timeout=60*60, extra=''):
+def adv_validate(v_form, operation='POST'):
+    """
+    Advanced validation decorator to return 
+    serialized errors on invalid forms.
+    """
+    @decorator
+    def wrap(func, self, request, *args, **kwargs):
+        form = v_form(getattr(request, operation))
+
+        if form.is_valid():
+            return func(self, request, *args, **kwargs)
+        else:
+            raise AdvFormValidationError(form)
+
+    return wrap
+
+def throttle(max_requests, timeout=60 * 60, extra=''):
     """
     Simple throttling decorator, caches
     the amount of requests made in cache.
             ident = request.user.username
         else:
             ident = request.META.get('REMOTE_ADDR', None)
-    
+
         if hasattr(request, 'throttle_extra'):
             """
             Since we want to be able to throttle on a per-
             object. If so, append the identifier name with it.
             """
             ident += ':%s' % str(request.throttle_extra)
-        
+
         if ident:
             """
             Preferrably we'd use incr/decr here, since they're
             stable, you can change it here.
             """
             ident += ':%s' % extra
-    
+
             now = datetime.now()
             ts_key = 'throttle:ts:%s' % ident
             timestamp = cache.get(ts_key)
             offset = now + timedelta(seconds=timeout)
-    
+
             if timestamp and timestamp < offset:
                 t = rc.THROTTLED
-                wait = timeout - (offset-timestamp).seconds
+                wait = timeout - (offset - timestamp).seconds
                 t.content = 'Throttled, wait %d seconds.' % wait
-                
+
                 return t
-                
+
             count = cache.get(ident, 1)
-            cache.set(ident, count+1)
-            
+            cache.set(ident, count + 1)
+
             if count >= max_requests:
                 cache.set(ts_key, offset, timeout)
                 cache.set(ident, 1)
-    
+
         return f(self, request, *args, **kwargs)
     return wrap
 
             request.META['REQUEST_METHOD'] = 'POST'
             request._load_post_and_files()
             request.META['REQUEST_METHOD'] = 'PUT'
-            
+
         request.PUT = request.POST
 
 
 
 class Mimer(object):
     TYPES = dict()
-    
+
     def __init__(self, request):
         self.request = request
-        
+
     def is_multipart(self):
         content_type = self.content_type()
 
         type_formencoded = "application/x-www-form-urlencoded"
 
         ctype = self.request.META.get('CONTENT_TYPE', type_formencoded)
-        
+
         if ctype.startswith(type_formencoded):
             return None
-        
+
         return ctype
-        
+
 
     def translate(self):
         """
         It will also set `request.content_type` so the handler has an easy
         way to tell what's going on. `request.content_type` will always be
         None for form-encoded and/or multipart form data (what your browser sends.)
-        """    
+        """
         ctype = self.content_type()
         self.request.content_type = ctype
-        
+
         if not self.is_multipart() and ctype:
             loadee = self.loader_for_type(ctype)
-            
+
             try:
                 self.request.data = loadee(self.request.raw_post_data)
-                
+
                 # Reset both POST and PUT from request, as its
                 # misleading having their presence around.
                 self.request.POST = self.request.PUT = dict()
                 raise MimerDataException
 
         return self.request
-                
+
     @classmethod
     def register(cls, loadee, types):
         cls.TYPES[loadee] = types
-        
+
     @classmethod
     def unregister(cls, loadee):
         return cls.TYPES.pop(loadee)
 
 def translate_mime(request):
     request = Mimer(request).translate()
-    
+
 def require_mime(*mimes):
     """
     Decorator requiring a certain mimetype. There's a nifty
     return wrap
 
 require_extended = require_mime('json', 'yaml', 'xml', 'pickle')
-    
+