Snippets

Anteru Device proc address loader

Created by Anteru

File vkdirect.py Added

  • Ignore whitespace
  • Hide word diff
+#!/usr/bin/env python3
+# SPDX-License-Identifier: BSD-2-Clause
+
+from xml.etree import ElementTree
+import argparse
+import sys
+from collections import OrderedDict
+
+def FindDeviceDispatchableTypes (tree):
+    '''Return a set of all types that are children of VkDevice (including VkDevice
+    itself.)'''
+    # We search for all types where the category = handle
+    handleTypes = tree.findall ('./types/type[@category="handle"]')
+
+    # Ordered dict for determinism
+    typeParents = OrderedDict ()
+
+    # for each handle type, we will store the type as the key, and the set of
+    # the parents as the value
+    for handleType in handleTypes:
+        # if it's an alias, we just duplicate
+        if 'alias' in handleType.attrib:
+            name = handleType.get ('name')
+            alias = handleType.get ('alias')
+
+            # This assumes aliases come after the actual type,
+            # which is true for vk.xml
+            typeParents [name] = typeParents [alias]
+        else:
+            name = handleType.find ('name').text
+            parent = handleType.get ('parent')
+
+            # There can be more than one parent
+            if parent:
+                typeParents [name] = set (parent.split (','))
+            else:
+                typeParents [name] = set ()
+
+    def IsVkDeviceOrDerivedFromVkDevice (handleType, typeParents):
+        '''Check if VkDevice shows up in the parent list (or if the type itself
+        is VkDevice).'''
+        if handleType == 'VkDevice':
+            return True
+        else:
+            parents = typeParents [handleType]
+            if parents is None:
+                return False
+            else:
+                # If we derive from VkDevice through any path, we're set
+                return any ([IsVkDeviceOrDerivedFromVkDevice (parent, typeParents) for parent in parents])
+
+    deviceTypes = {t for t in typeParents.keys () if IsVkDeviceOrDerivedFromVkDevice (t, typeParents)}
+
+    return deviceTypes
+
+def FindAllDeviceFunctions (tree, deviceTypes):
+    '''Find all device functions where the first parameter is in the
+    ``deviceTypes`` set. Returns a list of functions, each function consisting
+    of a dictionary with three entries:
+
+      - ``name`` is the function name
+      - ``return_type`` is the return type
+      - ``parameters`` is the list of parameters, i.e. ``param_type name``
+    '''
+    functions = []
+
+    for command in tree.findall ('./commands/command'):
+        parameters = command.findall ('param')
+        if parameters:
+            firstParameter = parameters [0]
+            if firstParameter.find ('type').text in deviceTypes:
+                function = {
+                    'return_type' : command.find ('proto/type').text,
+                    'name' : command.find ('proto/name').text,
+                    'parameters' : []
+                }
+
+                for parameter in parameters:
+                    # This flattens ``<param>const <type>T</type> <name>N</name></param>``
+                    # to ``const T N``
+                    function ['parameters'].append (''.join (parameter.itertext ()))
+
+                functions.append (function)
+
+    return functions
+
+def GetFunctionProtection (tree):
+    '''Return a dictionary of protected functions, and the definitions guarding them.'''
+    # Find all extensions which have some protection set
+    extensions = tree.findall (f'./extensions/extension[@protect]')
+
+    result = {}
+
+    for extension in extensions:
+        protection = extension.get ('protect').split (',')
+        for command in extension.findall ('./require/command[@name]'):
+            result [command.get ('name')] = protection
+
+    return result
+
+def GenerateHeader (tree, functions, protection, outputStream):
+    '''Generate a header based on the provided functions.'''
+    import hashlib
+    def Write (s=''):
+        print (s, file=outputStream)
+
+    # Same tree will always result in the same hash
+    includeUuid = hashlib.sha256(ElementTree.tostring (tree)).hexdigest().upper ()
+
+    Write (f'#ifndef VK_DIRECT_{includeUuid}')
+    Write (f'#define VK_DIRECT_{includeUuid} 1')
+    Write ()
+    Write ('#include <vulkan/vulkan.h>')
+    Write ()
+
+    Write ('struct VkDirect')
+    Write ('{')
+
+    def UnpackFunction (function):
+        return (function ['name'], function ['return_type'], function ['parameters'])
+
+    for function in functions:
+        name, return_type, parameters = UnpackFunction (function)
+
+        if name == 'vkGetDeviceProcAddr':
+            continue
+
+        protect = protection.get (name, None)
+
+        if protect:
+            Write (f'#ifdef {" && ".join (protect)}')
+
+        Write (f'\tusing FT_{name} = {return_type} ({", ".join (parameters)});')
+        Write (f'\tFT_{name}* {name} = nullptr;')
+        if protect:
+            Write ('#endif')
+        Write ()
+
+    Write ('\tvoid Bind (VkDevice device)')
+    Write ('\t{')
+    for function in functions:
+        name, return_type, parameters = UnpackFunction (function)
+
+        if name == 'vkGetDeviceProcAddr':
+            continue
+
+        protect = protection.get (name, None)
+
+        if protect:
+            Write (f'#ifdef {" && ".join (protect)}')
+
+        Write (f'\t\t{name} = (FT_{name}*)vkGetDeviceProcAddr (device, "{name}");')
+        if protect:
+            Write ('#endif')
+
+    Write ('\t}')
+    Write ('};')
+    Write ()
+    Write ('#endif')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser ()
+    parser.add_argument ('specfile', type=argparse.FileType ('r'),
+        default=sys.stdin, help='A Vulkan specification file (typically vk.xml)')
+    parser.add_argument ('output', type=argparse.FileType ('w'),
+        default=sys.stdout, help='Output file path')
+
+    args = parser.parse_args ()
+
+    document = ElementTree.parse (args.specfile)
+    tree = document.getroot ()
+
+    deviceTypes = FindDeviceDispatchableTypes (tree)
+    functions = FindAllDeviceFunctions (tree, deviceTypes)
+    protection = GetFunctionProtection (tree)
+
+    GenerateHeader (tree, functions, protection, args.output)