import copy import os from collections import OrderedDict import yaml from torchgen.code_template import CodeTemplate from yaml.constructor import ConstructorError from yaml.nodes import MappingNode try: from yaml import CLoader as Loader except ImportError: from yaml import Loader # type: ignore[misc] # https://gist.github.com/pypt/94d747fe5180851196eb class UniqueKeyLoader(Loader): def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] if not isinstance(node, MappingNode): raise ConstructorError( None, None, "expected a mapping node, but found %s" % node.id, node.start_mark, ) mapping = {} for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] try: hash(key) except TypeError as e: raise ConstructorError( "while constructing a mapping", node.start_mark, "found unacceptable key ", key_node.start_mark, ) from e # check for duplicate keys if key in mapping: raise ConstructorError( "while constructing a mapping", node.start_mark, "found duplicate key", key_node.start_mark, ) value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call] mapping[key] = value return mapping class GLSLGenerator(object): standard_header = """ #version 450 core #define PRECISION $precision #define FORMAT $format """ def __init__(self): # type: ignore[no-untyped-def] self.ops_template_params = {} def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def] all_template_params = OrderedDict() with open(parameters_yaml_file, "r") as f: contents = yaml.load(f, Loader=UniqueKeyLoader) for key in contents: all_template_params[key] = contents[key] self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call] def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def] for op in all_template_params: if op in self.ops_template_params: raise KeyError(f"{op} params file has already been parsed") op_params_default_vals = all_template_params[op][ "parameter_names_with_default_values" ] template_params_set = set(op_params_default_vals.keys()) self.ops_template_params[op] = [] self.ops_template_params[op].append(op_params_default_vals) op_template_params_values = all_template_params[op]["parameter_values"] for param_vals in op_template_params_values: param_vals_set = set(param_vals.keys()) invalid_keys = param_vals_set - template_params_set if (len(invalid_keys)) > 0: raise KeyError(f"Invalid keys {invalid_keys} are found") param_vals_copy = copy.deepcopy(op_params_default_vals) for key in param_vals: param_vals_copy[key] = param_vals[key] self.ops_template_params[op].append(param_vals_copy) def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def] glsl_template_name = os.path.basename(glsl_template_in) op_name, extension_name = glsl_template_name.split(".") if extension_name != "glslt": raise TypeError(f"invalid file type for glsl template {extension_name}") if op_name not in self.ops_template_params: raise KeyError(f"{op_name} params have not been populated") code_template = CodeTemplate.from_file(glsl_template_in) for template_params in self.ops_template_params[op_name]: content = GLSLGenerator.standard_header param_vals_string = "x".join([str(i) for i in template_params.values()]) output_file_name = op_name + "_" + param_vals_string + ".glsl" content += code_template.substitute(template_params) output_file = os.path.join(out_dir, output_file_name) with open(output_file, "w") as f: f.write(content) # Remove this if __name__ == "__main__": pass