1. Christophe de Vienne
  2. wsmorph

Commits

Christophe de Vienne  committed 68ebf09

Add result transformation rules. Rename pre/postcall to before/after_call.

  • Participants
  • Parent commits 94d8769
  • Branches default

Comments (0)

Files changed (2)

File wsmorph/__init__.py

View file
  • Ignore whitespace
             request.body = body
         self.patch_request_body(context, request)
 
-    def patch_response(self, request, response):
-        pass
+    def patch_response(self, context, response):
+        self.patch_response_body(context, response)
 
     def change_fname(self, call, newname):
         call.context.fname = newname
     def getjsondata(self, call):
         jsondata = getattr(call.context, 'jsondata', None)
         if jsondata is None:
-            print call.request.__class__.__name__
             body = call.request.body
             if not body:
                 return Rest.EmptyBody
     def setjsondata(self, call, jsondata):
         call.context.jsondata = jsondata
 
+    def getresultdata(self, call):
+        resultdata = getattr(call.context, 'resultdata', Unset)
+        if resultdata is Unset:
+            body = call.response.body
+            if not body:
+                return Rest.EmptyBody
+            if six.PY3:
+                body = body.decode('utf8')
+            resultdata = json.loads(body)
+            call.context.resultdata = resultdata
+        return resultdata
+
+    def setresultdata(self, call, value):
+        call.context.resultdata = value
+
     def patch_request_body(self, context, request):
         if hasattr(context, 'jsondata'):
             s = json.dumps(context.jsondata)
                 s = s.encode('utf8')
             request.body = s
 
+    def patch_response_body(self, context, response):
+        if hasattr(context, 'resultdata'):
+            s = json.dumps(context.resultdata)
+            if six.PY3:
+                s = s.encode('utf8')
+            response.body = s
+
     def rename_body_argument(self, call, oldname, newname):
         self.setjsondata(call, dict((
             (newname if name == oldname else name, value)
         self.setjsondata(call, data)
 
     def get_return_value(self, call):
-        pass
+        data = self.getresultdata(call)
+        return data
 
     def set_return_value(self, call, value):
-        pass
+        self.setresultdata(call, value)
 
 
 class RestXml(Rest):
         if hasattr(context, 'xml'):
             request.body = ET.tostring(context.xml)
 
+    def patch_response_body(self, context, response):
+        if hasattr(context, 'resultdata'):
+            e = ET.Element('result')
+            dumpxml(e, context.resultdata)
+            response.body = ET.tostring(e)
+
     def get_body_arg_value(self, call, argname):
         xml = self.getxml(call)
         return loadxml(xml.find(argname))
         self.setxml(call, xml)
 
     def get_return_value(self, call):
-        pass
+        resultdata = getattr(call.context, 'resultdata', Unset)
+        if resultdata is Unset:
+            resultdata = loadxmlstring(call.response.body)
+            call.context.resultdata = resultdata
+        return resultdata
 
     def set_return_value(self, call, value):
-        pass
+        call.context.resultdata = value
 
 
 uppercase_re = '([A-Z])'
 
 
 class Rule(object):
-    def precall(self, match_values, call, protocol):
+    def before_call(self, match_values, call, protocol):
         pass
 
-    def postcall(self, match_values, call, protocol):
+    def after_call(self, match_values, call, protocol):
         pass
 
 
         elif six.callable(self.trans):
             return self.trans(match_values)
 
-    def precall(self, match_values, call, protocol):
+    def before_call(self, match_values, call, protocol):
         newname = self.getnewname(match_values)
         log.debug("Renaming %s to %s", call.fname, newname)
         protocol.change_fname(call, newname)
         self.oldname = oldname
         self.newname = newname
 
-    def precall(self, match_values, call, protocol):
+    def before_call(self, match_values, call, protocol):
         protocol.rename_argument(call, self.oldname, self.newname)
 
 
         self.argname = argname
         self.trans = trans
 
-    def precall(self, match_values, call, protocol):
+    def before_call(self, match_values, call, protocol):
         value = protocol.get_arg_value(call, self.argname)
         value = self.trans(value)
         protocol.set_arg_value(call, self.argname, value)
 
 
+class TransResultRule(Rule):
+    def __init__(self, trans):
+        self.trans = trans
+
+    def after_call(self, match_values, call, protocol):
+        value = protocol.get_return_value(call)
+        value = self.trans(value)
+        protocol.set_return_value(call, value)
+
+
 class RuleSet(object):
     def __init__(self, match_fname):
+        self._orig_match_fname = match_fname
         if isinstance(match_fname, str):
             if match_fname.startswith('re:'):
                 match_fname = re.compile(match_fname[3:])
         self.match_fname = match_fname
         self.rules = []
 
+    def append(self, rule):
+        self.rules.append(rule)
+
     def match(self, fname):
         if six.callable(self.match_fname):
             return self.match_fname(fname)
             return False, None
         return False, None
 
-    def precall(self, match_values, call, protocol):
+    def before_call(self, match_values, call, protocol):
         for rule in self.rules:
-            rule.precall(match_values, call, protocol)
+            rule.before_call(match_values, call, protocol)
 
-    def postcall(self, match_values, call, protocol):
+    def after_call(self, match_values, call, protocol):
         for rule in self.rules:
-            rule.postcall(match_values, call, protocol)
+            rule.after_call(match_values, call, protocol)
 
 
 class RequestSettings(object):
         self.protocols = []
         self.rulesets = []
 
+    def _add_ruleset(self, fname_match):
+        if not self.rulesets or \
+                self.rulesets[-1]._orig_match_fname != fname_match:
+            self.rulesets.append(RuleSet(fname_match))
+        return self.rulesets[-1]
+
     def add_rename_function(self, fname_match, trans):
-        self.rulesets.append(RuleSet(fname_match))
-        self.rulesets[-1].rules.append(RenameFunctionRule(trans))
+        self._add_ruleset(fname_match).append(RenameFunctionRule(trans))
 
     def add_rename_argument(self, fname_match, oldname, newname):
-        self.rulesets.append(RuleSet(fname_match))
-        self.rulesets[-1].rules.append(RenameArgumentRule(oldname, newname))
+        self._add_ruleset(fname_match).append(
+            RenameArgumentRule(oldname, newname)
+        )
 
     def add_trans_argument(self, fname_match, argname, trans):
-        self.rulesets.append(RuleSet(fname_match))
-        self.rulesets[-1].rules.append(TransArgumentRule(argname, trans))
+        self._add_ruleset(fname_match).append(
+            TransArgumentRule(argname, trans)
+        )
 
-    def get_rulesets(self, call):
+    def add_trans_result(self, fname_match, trans):
+        self._add_ruleset(fname_match).append(TransResultRule(trans))
+
+    def get_rulesets(self, fname):
         rulesets = []
         for ruleset in self.rulesets:
-            match, match_values = ruleset.match(call.fname)
+            match, match_values = ruleset.match(fname)
             if match:
                 rulesets.append((ruleset, match_values))
         return rulesets
 
         for call in context.protocol.iter_calls(context, request):
             call.context = weakref.proxy(context)
-            call.rulesets = self.get_rulesets(call)
+            call.rulesets = self.get_rulesets(call.fname)
             context.calls.append(call)
 
         for call in context.calls:
             for ruleset, match_values in call.rulesets:
-                ruleset.precall(match_values, call, context.protocol)
+                ruleset.before_call(match_values, call, context.protocol)
 
         context.protocol.patch_request(context, request)
 
         return context
 
     def morph_response(self, context, response):
+        context.response = response
         for call in context.calls:
             call.response = response
 
         for call in context.calls:
             for ruleset, match_values in call.rulesets:
-                ruleset.postcall(match_values, call, context.protocol)
+                ruleset.after_call(match_values, call, context.protocol)
+
+        context.protocol.patch_response(context, response)
 
     def _handle_request(self, request, app):
         context = self.morph_request(request)

File wsmorph/tests/test_all.py

View file
  • Ignore whitespace
         )
         value = self.post_xml('/v1/misc/multiply', a=2, b=4)
         assert value == '10'
+
+    def test_trans_result_json(self):
+        def inc(value):
+            return int(value) - 1
+        self.morph.add_trans_result(
+            'misc/multiply', inc
+        )
+        value = self.get_json('/v1/misc/multiply', a=2, b=4)
+        assert value == 7, value
+
+    def test_trans_result_xml(self):
+        def inc(value):
+            return int(value) - 1
+        self.morph.add_trans_result(
+            'misc/multiply', inc
+        )
+        value = self.get_xml('/v1/misc/multiply', a=2, b=4)
+        assert value == '7', value