Snippets

Anteru Device proc address loader

Created by Anteru
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-2-Clause

from xml.etree import ElementTree
import argparse
import sys
from collections import OrderedDict

def FindDeviceDispatchableTypes (tree):
    '''Return the set of handle type names that are VkDevice or derive from it.

    ``tree`` is the ``<registry>`` root element of the Vulkan specification.
    Every ``<type category="handle">`` is inspected; a handle counts as
    device-level if VkDevice is reachable through its (possibly
    comma-separated) ``parent`` attribute chain. VkDevice itself is included.

    Handle aliases (``<type category="handle" name=".." alias=".."/>``)
    inherit the parents of the type they alias. Aliases are resolved after
    all real handle types have been read, so they may appear anywhere in the
    document, and alias-of-alias chains are followed.

    Raises KeyError if a ``parent`` or ``alias`` refers to a handle that is
    not defined in ``tree``.'''
    # We search for all types where the category = handle
    handleTypes = tree.findall ('./types/type[@category="handle"]')

    # Ordered dicts for determinism
    typeParents = OrderedDict ()
    aliases = OrderedDict ()

    # For each real handle type, store the type as the key and the set of its
    # parents as the value. Aliases are collected separately and resolved below.
    for handleType in handleTypes:
        if 'alias' in handleType.attrib:
            aliases [handleType.get ('name')] = handleType.get ('alias')
        else:
            name = handleType.find ('name').text
            parent = handleType.get ('parent')

            # There can be more than one parent
            typeParents [name] = set (parent.split (',')) if parent else set ()

    for name, target in aliases.items ():
        # Follow alias chains down to a real handle type; ``seen`` stops a
        # malformed alias cycle from looping forever (it then raises KeyError).
        seen = {name}
        while target in aliases and target not in seen:
            seen.add (target)
            target = aliases [target]
        typeParents [name] = typeParents [target]

    def IsVkDeviceOrDerivedFromVkDevice (handleType):
        '''Check if VkDevice shows up in the parent chain (or if the type
        itself is VkDevice).'''
        if handleType == 'VkDevice':
            return True
        # If we derive from VkDevice through any path, we're set; the
        # generator lets any() stop at the first matching path.
        return any (IsVkDeviceOrDerivedFromVkDevice (parent)
                    for parent in typeParents [handleType])

    return {t for t in typeParents if IsVkDeviceOrDerivedFromVkDevice (t)}

def FindAllDeviceFunctions (tree, deviceTypes, api='vulkan'):
    '''Find all device functions where the first parameter is in the
    ``deviceTypes`` set. Returns a list of functions in document order, each
    function consisting of a dictionary with three entries:

      - ``name`` is the function name
      - ``return_type`` is the return type
      - ``parameters`` is the list of parameters, i.e. ``param_type name``

    ``api`` selects the API variant to emit. ``<command>`` and ``<param>``
    elements that carry an ``api`` attribute (a comma-separated list such as
    ``vulkansc``) are used only if that list contains ``api``. Elements
    without the attribute apply to every API. Pass ``api=None`` to disable
    the filter.

    Command alias entries (``<command name=".." alias=".."/>``) have no
    parameters and are therefore not included.
    '''
    def IsForApi (element):
        '''True if ``element`` applies to the selected API.'''
        supported = element.get ('api')
        return api is None or supported is None or api in supported.split (',')

    functions = []

    for command in tree.findall ('./commands/command'):
        if not IsForApi (command):
            continue

        parameters = [p for p in command.findall ('param') if IsForApi (p)]
        if not parameters:
            continue

        if parameters [0].find ('type').text not in deviceTypes:
            continue

        functions.append ({
            'return_type' : command.find ('proto/type').text,
            'name' : command.find ('proto/name').text,
            # This flattens ``<param>const <type>T</type> <name>N</name></param>``
            # to ``const T N``
            'parameters' : [''.join (p.itertext ()) for p in parameters],
        })

    return functions

def GetFunctionProtection (tree):
    '''Map command names to the list of preprocessor macros that guard them.

    Only ``<extension>`` elements with a ``protect`` attribute contribute.
    That attribute is split on commas into a list of macro names, and every
    ``<require><command name=".."/>`` of the extension maps to that list. If
    a command is required by several protected extensions, the one appearing
    last in the document wins.'''
    guardedExtensions = tree.findall ('./extensions/extension[@protect]')

    return {
        required.get ('name'): extension.get ('protect').split (',')
        for extension in guardedExtensions
        for required in extension.findall ('./require/command[@name]')
    }

def GenerateHeader (tree, functions, protection, outputStream):
    '''Write a C++ header declaring a ``VkDirect`` struct to ``outputStream``.

    For every entry of ``functions`` except ``vkGetDeviceProcAddr``, the
    struct gets a function type alias ``FT_<name>`` and a member pointer
    ``<name>`` initialised to ``nullptr``. Its ``Bind (VkDevice)`` method
    loads each pointer through ``vkGetDeviceProcAddr``. Functions listed in
    ``protection`` (name -> list of macro names) are wrapped in preprocessor
    guards that require all of their macros to be defined. The include guard
    is a SHA-256 of the serialized ``tree``, so the same spec always yields
    the same guard.'''
    import hashlib

    def Write (s=''):
        print (s, file=outputStream)

    def OpenGuard (name):
        '''Emit the opening guard for ``name`` if it is protected; return
        True if a matching ``#endif`` must be written afterwards.'''
        macros = [m.strip () for m in protection.get (name) or () if m.strip ()]
        if not macros:
            return False
        if len (macros) == 1:
            Write (f'#ifdef {macros [0]}')
        else:
            # ``#ifdef`` takes exactly one identifier; ``#ifdef A && B`` would
            # only test ``A`` (compilers at most warn about the extra tokens).
            Write ('#if ' + ' && '.join (f'defined({m})' for m in macros))
        return True

    # vkGetDeviceProcAddr is what Bind () uses to load everything else, so it
    # does not get a member of its own.
    boundFunctions = [f for f in functions if f ['name'] != 'vkGetDeviceProcAddr']

    # Same tree will always result in the same hash
    includeUuid = hashlib.sha256 (ElementTree.tostring (tree)).hexdigest ().upper ()

    Write (f'#ifndef VK_DIRECT_{includeUuid}')
    Write (f'#define VK_DIRECT_{includeUuid} 1')
    Write ()
    Write ('#include <vulkan/vulkan.h>')
    Write ()

    Write ('struct VkDirect')
    Write ('{')

    for function in boundFunctions:
        name = function ['name']
        returnType = function ['return_type']
        parameterList = ', '.join (function ['parameters'])

        guarded = OpenGuard (name)
        Write (f'\tusing FT_{name} = {returnType} ({parameterList});')
        Write (f'\tFT_{name}* {name} = nullptr;')
        if guarded:
            Write ('#endif')
        Write ()

    Write ('\tvoid Bind (VkDevice device)')
    Write ('\t{')
    for function in boundFunctions:
        name = function ['name']

        guarded = OpenGuard (name)
        Write (f'\t\t{name} = (FT_{name}*)vkGetDeviceProcAddr (device, "{name}");')
        if guarded:
            Write ('#endif')

    Write ('\t}')
    Write ('};')
    Write ()
    Write ('#endif')

if __name__ == '__main__':
    parser = argparse.ArgumentParser (
        description='Generate a C++ header with a VkDirect struct that loads '
                    'all device-level Vulkan functions via vkGetDeviceProcAddr.')
    # nargs='?' is what makes the defaults usable: a plain positional argument
    # is mandatory, so its ``default`` would never be applied.
    # The spec is opened in binary mode so ElementTree honours the encoding
    # declared in the XML prolog rather than the locale's default encoding.
    parser.add_argument ('specfile', nargs='?', type=argparse.FileType ('rb'),
        default=sys.stdin.buffer,
        help='A Vulkan specification file (typically vk.xml); defaults to stdin')
    parser.add_argument ('output', nargs='?', type=argparse.FileType ('w'),
        default=sys.stdout, help='Output file path; defaults to stdout')

    args = parser.parse_args ()

    document = ElementTree.parse (args.specfile)
    tree = document.getroot ()

    deviceTypes = FindDeviceDispatchableTypes (tree)
    functions = FindAllDeviceFunctions (tree, deviceTypes)
    protection = GetFunctionProtection (tree)

    GenerateHeader (tree, functions, protection, args.output)

    # Close an explicitly named output file so buffered data is flushed and
    # any write error surfaces here rather than at interpreter shutdown.
    if args.output is not sys.stdout:
        args.output.close ()

Comments (0)