Commits

Christophe de Vienne committed 72e7ac4

Add a rule type that transforms arguments

Comments (0)

Files changed (2)

wsmorph/__init__.py

 import re
+import os.path
 import six
 
 import logging
     import json  # noqa
 
 from webob.dec import wsgify
+from webob.multidict import MultiDict
 
 log = logging.getLogger(__name__)
 
 
     def iter_calls(self, context, request):
         fname = request.path[len(self.webpath):]
-        context.fname = fname
-        yield Call(request, fname)
+        context.fname, context.fname_ext = os.path.splitext(fname)
+        yield Call(request, context.fname)
 
     def patch_request(self, context, request):
-        newpath = self.target_webpath + context.fname
+        newpath = self.target_webpath + context.fname + context.fname_ext
         if hasattr(request, 'upath_info'):
             request.upath_info = newpath
         else:
             request.path = newpath
+        if hasattr(request, '_wsmorph_GET'):
+            request.query_string = urlencode([
+                (k, v) for k, v in request._wsmorph_GET.items()
+            ])
+        if getattr(request, '_wsmorph_POST', None):
+            body = urlencode([
+                (k, v) for k, v in request._wsmorph_POST.items()
+            ])
+            if six.PY3:
+                body = body.encode('ascii')
+            request.body = body
         self.patch_request_body(context, request)
 
     def patch_response(self, request, response):
     def change_fname(self, call, newname):
         call.context.fname = newname
 
+    def GET(self, request):
+        if not hasattr(request, '_wsmorph_GET'):
+            request._wsmorph_GET = MultiDict(parse_qsl(request.query_string))
+        return request._wsmorph_GET
+
+    def POST(self, request):
+        if not hasattr(request, '_wsmorph_POST'):
+            if request.content_type == 'application/x-www-form-urlencoded':
+                body = request.body
+                if six.PY3:
+                    body = body.decode('ascii')
+                request._wsmorph_POST = MultiDict(parse_qsl(body))
+            else:
+                request._wsmorph_POST = None
+        return request._wsmorph_POST
+
     def rename_argument(self, call, oldname, newname):
-        if oldname in (
-                key for key, value in parse_qsl(call.request.query_string)):
-            new_args = [
-                (newname if name == oldname else name, value)
-                for name, value in parse_qsl(call.request.query_string)
-            ]
-            call.request.query_string = urlencode(new_args)
-        elif call.request.content_type == 'application/x-www-form-urlencoded' \
-                and oldname in (
-                    key for key, value in
-                    parse_qsl(call.request.query_string)):
-            new_args = [
-                (newname if name == oldname else name, value)
-                for name, value in parse_qsl(call.request.body)
-            ]
-            call.request.body = urlencode(new_args)
+        GET = self.GET(call.request)
+        POST = self.POST(call.request)
+        if oldname in GET:
+            for i in range(len(GET._items)):
+                key, value = GET._items[i]
+                if key == oldname:
+                    GET._items[i] = (newname, value)
+        elif POST and oldname in POST:
+            for i in range(len(POST._items)):
+                key, value = POST._items[i]
+                if key == oldname:
+                    POST._items[i] = (newname, value)
         else:
             self.rename_body_argument(call, oldname, newname)
 
+    def get_arg_value(self, call, argname):
+        ARGS = self.GET(call.request)
+        if argname not in ARGS:
+            ARGS = self.POST(call.request)
+        if ARGS and argname in ARGS:
+            value = ARGS.getall(argname)
+            if len(value) == 1:
+                value = value[0]
+            return value
+        return self.get_body_arg_value(call, argname)
+
+    def set_arg_value(self, call, argname, value):
+        ARGS = self.GET(call.request)
+        if argname not in ARGS:
+            ARGS = self.POST(call.request)
+        if ARGS and argname in ARGS:
+            if isinstance(value, list):
+                del ARGS[argname]
+                for item in value:
+                    ARGS.add(argname, str(item))
+            else:
+                ARGS[argname] = str(value)
+        else:
+            self.set_body_arg_value(call, argname, value)
+
 
 class RestJson(Rest):
     def accept(self, request):
             body = call.request.body
             if six.PY3:
                 body = body.decode('utf8')
-            call.context.jsondata = jsondata = json.loads(body)
+            jsondata = json.loads(body)
+            call.context.jsondata = jsondata
         return jsondata
 
     def setjsondata(self, call, jsondata):
             for name, value in self.getjsondata(call).items()
         )))
 
+    def get_body_arg_value(self, call, argname):
+        return self.getjsondata(call)[argname]
+
+    def set_body_arg_value(self, call, argname, value):
+        data = self.getjsondata(call)
+        data[argname] = value
+        self.setjsondata(call, data)
+
+    def get_return_value(self, call):
+        pass
+
+    def set_return_value(self, call, value):
+        pass
+
 
 class RestXml(Rest):
     def accept(self, request):
         if hasattr(context, 'xml'):
             request.body = ET.tostring(context.xml)
 
+    def get_body_arg_value(self, call, argname):
+        xml = self.getxml(call)
+        return loadxml(xml.find(argname))
+
+    def set_body_arg_value(self, call, argname, value):
+        xml = self.getxml(call)
+        node = xml.find(argname)
+        node.clear()
+        dumpxml(node, value)
+        self.setxml(call, xml)
+
+    def get_return_value(self, call):
+        pass
+
+    def set_return_value(self, call, value):
+        pass
+
 
 uppercase_re = '([A-Z])'
 
         protocol.rename_argument(call, self.oldname, self.newname)
 
 
+class TransArgumentRule(Rule):
+    def __init__(self, argname, trans):
+        self.argname = argname
+        self.trans = trans
+
+    def precall(self, match_values, call, protocol):
+        value = protocol.get_arg_value(call, self.argname)
+        value = self.trans(value)
+        protocol.set_arg_value(call, self.argname, value)
+
+
 class RuleSet(object):
     def __init__(self, match_fname):
         if isinstance(match_fname, str):
         self.rulesets.append(RuleSet(fname_match))
         self.rulesets[-1].rules.append(RenameArgumentRule(oldname, newname))
 
+    def add_trans_argument(self, fname_match, argname, trans):
+        self.rulesets.append(RuleSet(fname_match))
+        self.rulesets[-1].rules.append(TransArgumentRule(argname, trans))
+
     def get_rulesets(self, call):
         rulesets = []
         for ruleset in self.rulesets:

wsmorph/tests/test_all.py

         res = self.app.get(path, headers=headers)
         return json.loads(res.text)
 
-    def post_json(self, path, body=None, **args):
+    def post_json(self, path, body=None, form=False, **args):
         headers = {
             'Accept': 'application/json',
             'Content-Type': 'application/json',
         }
+        if form:
+            headers['Content-Type'] = 'application/x-www-form-urlencoded'
         if args and body is None:
-            body = json.dumps(args)
-        res = self.app.post(path, body, headers=headers)
+            if form:
+                body = urlencode(args)
+            else:
+                body = json.dumps(args)
+        res = self.app.post(path + '.json', body, headers=headers)
         return json.loads(res.text)
 
     def get_xml(self, path, **args):
     def test_rename_argument_post_xml(self):
         self.morph.add_rename_argument('misc/multiply', 'c', 'a')
         assert self.post_xml('/v1/misc/multiply', c=5, b=15) == '75'
+
+    def test_trans_arg(self):
+        def double_value(value):
+            return int(value) * 2
+
+        self.morph.add_trans_argument(
+            'misc/multiply', 'a', double_value
+        )
+        value = self.get_json('/v1/misc/multiply', a=2, b=4)
+        assert value == 16, value
+
+        value = self.post_json('/v1/misc/multiply', a=2, b=4, form=True)
+        assert value == 16, value
+
+    def test_trans_body_arg_json(self):
+        def inc(value):
+            return int(value) + 1
+
+        self.morph.add_trans_argument(
+            'misc/multiply', 'b', inc
+        )
+        value = self.post_json('/v1/misc/multiply', a=2, b=4)
+        assert value == 10
+
+    def test_trans_body_arg_xml(self):
+        def inc(value):
+            return int(value) + 1
+
+        self.morph.add_trans_argument(
+            'misc/multiply', 'b', inc
+        )
+        value = self.post_xml('/v1/misc/multiply', a=2, b=4)
+        assert value == '10'
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.