# nx-web/nxweb/utils/ws/__init__.py

import json

from functools import wraps
from collections import namedtuple

from pyramid.request import Request

from pyramid.httpexceptions import (
    HTTPException,
    HTTPNotFound,
    HTTPMethodNotAllowed,
    HTTPNotAcceptable
)

from pyramid.exceptions import PredicateMismatch

from webob.multidict import MultiDict, NestedMultiDict

from .data import (
    decode_request_body,
    decode_body_field
)

from .response import error

from . import renderers


class WebService(object):
    """Entry point that ties a Pyramid configurator to a set of resources.

    Holds the service-wide defaults (representations and authentication
    decorator) that resources created through :meth:`make_resource` can
    fall back on.
    """

    def __init__(self, pyramid_configurator, default_representations=None, auth_decorator=None):
        """Store the configurator and the service-wide defaults.

        :param pyramid_configurator: configurator object kept as ``config``.
        :param default_representations: iterable of ``(accept, renderer)``
            pairs; when ``None`` a single JSON representation is used.
        :param auth_decorator: optional decorator kept as the service-wide
            authentication wrapper.
        """
        self.config = pyramid_configurator
        self.auth_decorator = auth_decorator

        # A fresh list is built for every instance so defaults are never shared.
        if default_representations is not None:
            self.reprs = default_representations
        else:
            self.reprs = [('application/json', 'json')]

    def make_resource(self, name, path, default_representations=None, auth_decorator=None):
        """Create a :class:`Resource` bound to this web service.

        :param name: route name of the resource.
        :param path: route pattern of the resource.
        :param default_representations: resource-level representations override.
        :param auth_decorator: resource-level authentication decorator override.
        :returns: the newly created resource.
        """
        resource = Resource(
            self,
            name,
            path,
            default_representations,
            auth_decorator,
        )
        return resource


class Resource(object):
    """One web-service resource: a route plus the views attached to it.

    Creating a resource registers the route ``name`` -> ``path`` on the web
    service's configurator. Views are then attached per HTTP method through
    :meth:`get`, :meth:`post`, :meth:`put` and :meth:`delete`.
    """

    def __init__(self, webservice, name, path, default_representations=None, auth_decorator=None):
        """Remember the resource settings and register its route.

        :param webservice: owning service; its ``config``, ``reprs`` and
            ``auth_decorator`` attributes are read by this class.
        :param name: route name, also used as ``route_name`` for every view.
        :param path: route pattern passed to ``config.add_route``.
        :param default_representations: optional iterable of
            ``(accept, renderer)`` pairs overriding the service default.
        :param auth_decorator: optional decorator overriding the service
            authentication decorator.
        """
        self.ws = webservice
        self.name = name
        self.path = path
        self.reprs = default_representations
        self.auth_decorator = auth_decorator

        self.ws.config.add_route(name, path)

    def _check_request_init_errors(self, view):
        """Wrap ``view`` so it is skipped when the request already has errors.

        This must be applied before any other wrapper that should run ahead
        of it (see the ordering in :meth:`_add_view`).
        """
        @wraps(view)
        def wrapper(req):
            if req.has_errors():
                return error(req, '400 Request init validation failed')
            return view(req)
        return wrapper

    def _validate(self, view, validators):
        """Wrap ``view`` so every validator runs before it.

        All validators are called with the request; the view only runs when
        ``req.has_errors()`` is still false afterwards.
        """
        @wraps(view)
        def wrapper(req):
            for validator in validators:
                validator(req)

            if req.has_errors():
                return error(req, '400 Request validation failed')

            return view(req)
        return wrapper

    def _add_view(self, view, request_method, representations=None, validators=None, auth_decorator=None, **kw):
        """Register ``view`` for ``request_method`` once per representation.

        :param view: view callable or dotted name (resolved with
            ``config.maybe_dotted``).
        :param request_method: HTTP method the view answers.
        :param representations: iterable of ``(accept, renderer)`` pairs;
            falls back to the resource's, then the service's defaults.
        :param validators: a single validator, a dotted name, or an iterable
            of either; validators found on ``view.validators`` are appended.
        :param auth_decorator: decorator applied outermost; falls back to
            the resource's, then the service's decorator.
        :param kw: extra keyword arguments forwarded to ``config.add_view``.
        :raises TypeError: if ``accept`` or ``renderer`` is passed in ``kw``.
        """
        for disallow_kw in ('accept', 'renderer'):
            if disallow_kw in kw:
                raise TypeError('`{0}` not allowed as a keyword argument. '
                    'Must be passed as part of `representations`.'.format(disallow_kw))

        maybe_dotted = self.ws.config.maybe_dotted
        view = maybe_dotted(view)

        # configure the view so the validators will be called for each request that the view will process
        if validators is None:
            validators = []
        elif isinstance(validators, (type(''), type(u''))) or not hasattr(validators, '__iter__'):
            # A lone validator (callable or dotted name). Strings are checked
            # explicitly because they are iterable on some interpreters and
            # must not be split into characters.
            validators = [validators]

        # Build a real list: it is extended and measured below, and the
        # caller's own sequence must stay untouched.
        validators = [maybe_dotted(validator) for validator in validators]

        view_validators_attr = getattr(view, 'validators', None)
        if view_validators_attr is not None:
            validators.extend(view_validators_attr)

        if validators:
            view = self._validate(view, validators)

        # prevent execution of view logic if there were errors in the request initialization
        view = self._check_request_init_errors(view)

        # choose and apply nearest specified authentication check decorator
        auth_deco = auth_decorator or self.auth_decorator or self.ws.auth_decorator

        if auth_deco is not None:
            # this has precedence over request initialization check
            view = auth_deco(view)

        # choose nearest specified representations
        reprs = representations or self.reprs or self.ws.reprs

        # add a view for each representation
        for accept, renderer in reprs:
            self.ws.config.add_view(
                view,
                route_name=self.name,
                request_method=request_method,
                accept=accept,
                renderer=renderer,
                **kw
            )

    def get(self, view, **kw):
        """Attach ``view`` as the GET handler (see :meth:`_add_view`)."""
        self._add_view(view, 'GET', **kw)

    def post(self, view, **kw):
        """Attach ``view`` as the POST handler (see :meth:`_add_view`)."""
        self._add_view(view, 'POST', **kw)

    def put(self, view, **kw):
        """Attach ``view`` as the PUT handler (see :meth:`_add_view`)."""
        self._add_view(view, 'PUT', **kw)

    def delete(self, view, **kw):
        """Attach ``view`` as the DELETE handler (see :meth:`_add_view`)."""
        self._add_view(view, 'DELETE', **kw)


class Errors(list):
    """List of request errors, each stored as an ``Errors.Error`` tuple.

    Besides the list content, the instance keeps a reference to the request
    it belongs to and an optional HTTP exception object.
    """

    # Record type of a single entry: where the problem was found, which
    # item it concerns, and a human-readable explanation.
    Error = namedtuple('Error', ['location', 'name', 'description'])

    def __init__(self, request, httpexception=None):
        """Create an empty error list for ``request``.

        :param request: stored as ``self.request``.
        :param httpexception: stored as ``self.httpexception``.
        """
        super(Errors, self).__init__()

        self.httpexception = httpexception
        self.request = request

    def add(self, location, name=None, description=None):
        """Append one error entry for the given ``location``."""
        entry = Errors.Error(location, name, description)
        self.append(entry)

    def add_querystring(self, name, description=None):
        """Append an error located in the ``querystring``."""
        return self.add('querystring', name, description)

    def add_headers(self, name, description=None):
        """Append an error located in the ``headers``."""
        return self.add('headers', name, description)

    def add_body(self, name, description=None):
        """Append an error located in the ``body``."""
        return self.add('body', name, description)

    def add_path(self, name, description=None):
        """Append an error located in the ``path``."""
        return self.add('path', name, description)


class WSRequest(Request):
    """Request subclass that decodes its body on construction.

    Decoding problems are not raised; they are recorded in ``self.errors``
    so that view wrappers can check ``has_errors()`` later.
    """

    def __init__(self, environ, **kw):
        """Build the request, then decode the body and its encoded fields.

        After construction:

        * ``decoded_body`` holds the decoded body (``self.POST`` for form
          submissions), or ``None`` when nothing was decoded.
        * ``DECODED`` holds the key/value pairs of a non-form decoded body.
        * ``errors`` lists everything that went wrong while decoding.
        * ``valid`` is an empty dict; nothing in this class fills it
          (presumably validators do -- confirm against callers).
        """
        super(WSRequest, self).__init__(environ, **kw)

        self.decoded_body = None

        self.valid = {}
        self.errors = Errors(self)

        self.DECODED = MultiDict()

        # Bodies with these content types are left undecoded.
        raw_content_types = (
            'application/octet-stream',
        )

        content_type = self.headers.get('content-type', None)

        if self.body and content_type not in raw_content_types:
            # decode whole body based on content-type; ignores raw content types
            try:
                if not content_type:
                    self.errors.add_body('**decoding**', 'Content-Type header not found')
                elif content_type.startswith((
                    'application/x-www-form-urlencoded',
                    'multipart/form-data',
                )):
                    # Form submissions reuse the form parsing behind ``POST``.
                    self.decoded_body = self.POST
                else:
                    self.decoded_body = decode_request_body(self)
                    self.DECODED.update(self.decoded_body)
            except Exception as exc:
                if isinstance(exc, HTTPException):
                    self.errors.httpexception = exc
                self.errors.add_body('**decoding**', str(exc))

        if len(self.DECODED) > 0:
            # do a per-field decoding based on a special http header
            try:
                field_spec = self.headers.get('x-body-field-encoding', None)

                if field_spec is not None:
                    # Header format: comma-separated ``key=encoding`` items.
                    for item in field_spec.split(','):
                        # Unpacking fails for an item without exactly one
                        # ``=``; the handler below records that as an error.
                        field_key, field_encoding = (
                            part.strip() for part in item.strip().split('=')
                        )

                        if field_key not in self.DECODED:
                            self.errors.add_body(
                                '**body field decoding**',
                                'Body field not found: `{0}`'.format(field_key)
                            )
                        else:
                            decode_body_field(self, field_key, field_encoding)

            except Exception as exc:
                if isinstance(exc, HTTPException):
                    self.errors.httpexception = exc
                self.errors.add_body('**body field decoding**', str(exc))

    @property
    def params(self):
        """Combined view over ``GET``, ``POST`` and ``DECODED``, in that order."""
        merged = NestedMultiDict(self.GET, self.POST, self.DECODED)
        return merged

    def has_errors(self):
        """Return ``True`` when at least one error has been recorded."""
        return len(self.errors) != 0


def notfound(req):
    """Not-found view that tells apart 404, 405 and 406 situations.

    When a route matched but no view did, the views related to that route
    are inspected through the registry introspector:

    * no route matched -> ``HTTPNotFound``;
    * no view answers the request method -> ``HTTPMethodNotAllowed`` with
      the ``Allow`` header listing the methods that are registered;
    * no view for this method offers an acceptable media type ->
      ``HTTPNotAcceptable`` whose JSON body lists the acceptable types
      (no body for ``HEAD`` requests);
    * otherwise the mismatch has another cause and ``PredicateMismatch``
      is raised.

    :param req: the current request.
    :returns: an HTTP exception response for the 404/405/406 cases.
    :raises PredicateMismatch: when method and accept both match.
    """
    route = req.matched_route
    if route is None:
        return HTTPNotFound()

    introspector = req.registry.introspector
    route_intr = introspector.get('routes', route.name)

    # request method -> set of media types offered for that method
    views_info = {}

    for view_intr in introspector.related(route_intr):
        if 'request_methods' not in view_intr:
            # Related entry without method information; nothing to learn.
            continue

        # ``None`` means no request-method predicate was given for the view
        # (presumably it answers every method -- confirm against Pyramid's
        # introspection docs), so it is counted for the current method
        # instead of failing on iteration.
        request_methods = view_intr['request_methods'] or (req.method,)
        accept = view_intr.get('accept')

        for request_method in request_methods:
            offered = views_info.setdefault(request_method, set())
            if accept is not None:
                offered.add(accept)

    if req.method not in views_info:
        resp = HTTPMethodNotAllowed()
        # Sorted list: stable header value, and a plain list on any interpreter.
        resp.allow = sorted(views_info)
        return resp

    if req.accept.best_match(views_info[req.method]) is None:
        resp = HTTPNotAcceptable()
        if req.method != 'HEAD':
            resp.body = json.dumps(error(req, acceptable=sorted(views_info[req.method])))
            resp.content_type = 'application/json'
        return resp

    raise PredicateMismatch(route.name)


def includeme(config):
    """Wire this package into a Pyramid configurator.

    Installs ``WSRequest`` as the request factory, ``notfound`` as the
    not-found view, and registers this package's JSON renderer under the
    name ``json``.
    """
    config.set_request_factory(WSRequest)
    config.add_notfound_view(notfound)

    json_renderer = renderers.JSON()
    config.add_renderer('json', json_renderer)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.