Commits

Anonymous committed 37d6029

Added support for schema imports

Comments (0)

Files changed (1)

                 detail = detail.text
             raise Fault(method.location, method.name, code, string, detail)
 
+    @property
+    def _resolver(self):
+        return self
+
+    def _find(self, type_):
+        return self.wsdl.resolve(type_, allow_ref=False)
+
 
 class Fault(Exception):
     """
         # Similar to AnyType()(element), but just finds the
         # class. AnyType() calls cls(), which if done here,
         # results in obj.__init__() getting called twice.
-        if cls._abstract and cls._client and element is not None:
+        if cls._abstract and cls._resolver and element is not None:
             valtype = xsi_type(element)
             if valtype:
-                cls = cls._client.wsdl.resolve(valtype, allow_ref=False)
+                # FIXME this doesn't work for generated clients
+                cls = cls._resolver._find(valtype)
         return object.__new__(cls)
 
     def __init__(self, element=None, **kw):
             for child in self._children:
                 if not hasattr(child.type, '_substitutions'):
                     continue
+                if not child.type._substitutions:
+                    continue
                 for sub_name, sub in child.type._substitutions.items():
                     subs[sub_name] = (sub, child.name)
             setattr(self.__class__, key, subs)
 
     def _items(self):
         items = []
-        for a in self._attributes + self._children:
+        for a in itertools.chain(self._attributes, self._children):
             # access the internal key not the real attribute to
             # avoid autovivification
             val = getattr(self, "_%s_" % a.name, None)
                 header_fmt.append((name, cls(val)))
             # complex headers
             hkids = getattr(cls, '_children', ())
-            hkids = hkids + getattr(cls, '_attributes', ())
+            hkids = itertools.chain(hkids, getattr(cls, '_attributes', ()))
             hkw = {}
             for kcls in hkids:
                 val = kw.pop(kcls.name, None)
         'positiveInteger': IntType,
         'short': IntType,
         'unsignedInt': IntType,
+        'unsignedByte': IntType,
+        'unsignedLong': LongType,
+        'unsignedShort': IntType,
+        'char': StringType,
         'long': LongType,
         'byte': StringType,
         'double': FloatType,
         'hexBinary': StringType,
         # FIXME: probably timedelta, but needs parsing.
         # It looks like P29DT23H54M58S
-        'duration': StringType
+        'duration': StringType,
+        'QName': StringType,
+        'ID': StringType,
+        'IDREF': StringType,
         })
 
     _simple_tag = '{%s}simpleType' % NS_XSD
     _list_tag = '{%s}list' % NS_XSD
     _types_tag = '{%s}types' % NS_WSDL
     _schema_tag = '{%s}schema' % NS_XSD
+    _import_tag = '{%s}import' % NS_XSD
     _cplx_type_tag = '{%s}complexType' % NS_XSD
 
     def __init__(self, wsdl_file):
         self.wsdl = etree.parse(wsdl_file).getroot()
         self.nsmap = NSStack(self.wsdl)
+        self._imports = {}
         self._lock = RLock()
         self._lock.acquire()
         try:
         for child in schema.getchildren():
             if self._is_type(child):
                 self._process_type(client, child)
+            elif self._is_import(child):
+                self._process_import(client, child)
         self.nsmap.pop_schema()
 
     def _process_type(self, client, child):
                 typel = child.get('type', None)
             if typel is None:
                 raise ValueError("Could not find type for element %s" % child)
-            self._make_class(client, typel, name=name, force_name=True)
+            self._make_class(client, typel, name=name, force_name=True,
+                             name_from=child)
         else:
             name = child.get('name', None)
             if name is None:
                 return
             self._make_class(client, child, name=name, force_name=False)
 
+    def _process_import(self, client, child):
+        # IMPORTS CAN BE CIRCULAR! BEWARE!
+        url = child.get('schemaLocation')
+        if not url:
+            return
+        if url in self._imports:
+            log.debug("Already imported %s", url)
+            return
+        schema = etree.parse(urlopen(url)).getroot()
+        self._imports[url] = schema
+        self._process_schema(client, schema)
+
     def _resolve_refs(self):
         for client_type, name in self._refs:
             ref = getattr(client_type, name)
             header_parts.append((h_name, self._make_class(client, h_type)))
         return OutputMessage(name, namespace, parts, header_parts)
 
-    def _make_class(self, client, type_, name=None, force_name=False):
+    def _make_class(self, client, type_, name=None, force_name=False,
+                    name_from=None):
         client_type = client.type
         ref = None
 
             # catch known types
             key = local_attr(type_)
             if key in self._typemap:
-                # FIXME this fails for instances of types that
+                # FIXME (both branches) this fails for instances of types that
                 # have their own namespaces
+                if force_name and name != key:
+                    # handle simple rename-only "subclasses" of existing types
+                    cls = self._make_subclass(client, name, name_from,
+                                              self._typemap[key])
+                    self._typemap[name] = cls
+                    setattr(client_type, name, cls)
+                    return cls
                 return self._typemap[key]
             if key == 'anyType':
                 # anyType must be bound to client, since
             if etype is None:
                 ref, etype = self._find_element_ref(type_)
             type_ = etype
+
+        if name is None:
+            name = type_.get('name')
+
+        # can't build a class without a name
+        if name is None:
+            return
+
         # this case handles weirdness like an element with
         # one namespace having a type attrib that points
         # to an element with another namespace.
             parent = ref.getparent()
         else:
             parent = type_.getparent()
-        if self._is_schema(parent):
-            schema = Schema(parent)
-            namespace = schema.targetNamespace
-        else:
-            schema = self.nsmap.top()
-            namespace = schema.targetNamespace
-
-        if name is None:
-            name = type_.get('name')
-
-        # can't build a class without a name
-        if name is None:
-            return
+        schema, namespace = self._find_schema(parent)
 
         # in cache?
         cls = self._typemap.get(name, None)
             if not hasattr(client_type, name):
                 setattr(client_type, name, cls)
             if ref is not None:
+                # FIXME this is similar to _make_subclass: consolidate?
                 # need to subclass, set ref's name, namespace and schema
                 # because the referring element wraps the type
                 rname = ref.get('name')
-                cls = type(rname, (cls,), {'_tag': ref.get('name'),
+                cls = type(rname, (cls,), {'_tag': rname,
                                            '_namespace': namespace,
                                            '_schema': schema})
                 self._typemap[rname] = cls
                 '_namespace': namespace,
                 '_schema': schema,
                 '_client': client,
+                '_resolver': client,
                 'xsd_type': (namespace, name)}
         # this handles cases where the xml element name is always
         # forced to the name in the wsdl
     def _is_schema(self, element):
         return element.tag == self._schema_tag
 
+    def _is_import(self, element):
+        return element.tag == self._import_tag
+
     def _is_type(self, element):
         return element.tag in (
             self._element_tag, self._cplx_type_tag, self._simple_tag)
                 return True
         return False
 
+    def _find_schema(self, element):
+        if self._is_schema(element):
+            schema = Schema(element)
+            namespace = schema.targetNamespace
+        else:
+            schema = self.nsmap.top()
+            namespace = schema.targetNamespace
+        return schema, namespace
+
     def _make_enum(self, element, client, namespace, name):
         vals = []
         data = {}
         cls = type(name, (Pickleable, base_cls,), data)
         return cls
 
+    def _make_subclass(self, client, name, element, base):
+        schema, namespace = self._find_schema(element)
+        data = {'xsd_type': (namespace, name),
+                '_tag': name,
+                '_client': client,
+                '_namespace': namespace,
+                '_schema': schema}
+        if Pickleable in base.__mro__:
+            bases = (base,)
+        else:
+            bases = (Pickleable, base)
+        cls = type(name, bases, data)
+        return cls
+
     def _children(self, element):
         # FIXME for element.tag == _all_tag,
         # force minOccurs -> 0, maxOccurs -> 1
                     yield se
 
     def _find_type(self, name):
-        name = local_attr(name)
-        # print "Find definition for class %s" % name
-        type_nodes = itertools.chain(
-                self.wsdl.findall(
-                    ".//%s[@name='%s']" % (self._cplx_type_tag, name)),
-                self.wsdl.findall(
-                    ".//%s[@name='%s']" % (self._simple_tag, name)))
-        for node in type_nodes:
-            return node # return first node found, if any
+        lname = local_attr(name)
+        sources = [self.wsdl] + self._imports.values()
+        for source in sources:
+            type_nodes = itertools.chain(
+                    source.findall(
+                        ".//%s[@name='%s']" % (self._cplx_type_tag, lname)),
+                    source.findall(
+                        ".//%s[@name='%s']" % (self._simple_tag, lname)))
+            for node in type_nodes:
+                return node # return first node found, if any
 
     def _find_element_ref(self, name):
-        name = local_attr(name)
-        elem_refs = self.wsdl.findall(
-            ".//%s[@name='%s']" % (self._element_tag, name))
-        if elem_refs:
-            # FIXME elem_refs[0] contains information about
-            # the eventual wanted type -- its name and namespace
-            # will actually wrap the wanted type. Ack.
-            type_name = elem_refs[0].get('type')
-            return elem_refs[0], self._find_type(type_name)
-        raise Exception("Could not find definition of class %s" % name)
+        sources = [self.wsdl] + self._imports.values()
+        for source in sources:
+            elem_refs = source.findall(
+                ".//%s[@name='%s']" % (self._element_tag, local_attr(name)))
+            if elem_refs:
+                # elem_refs[0] contains information about
+                # the eventual wanted type -- its name and namespace
+                # will actually wrap the wanted type.
+                type_name = elem_refs[0].get('type')
+                if type_name:
+                    return elem_refs[0], self._find_type(type_name)
+                else:
+                    # self-contained type declaration
+                    return elem_refs[0], elem_refs[0][0]
+        raise UnknownType("Could not find definition of class %s" % name)
 
     def _find_enumerations(self, element):
         restr = self._find_restriction(element)
 
 def backmap(dct):
     return dict(zip(dct.values(), dct.keys()))
+
+
+class UnknownType(TypeError):
+    pass