diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 644069395e1..38ab7743d82 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -1,8 +1,6 @@ -from __future__ import annotations - import re from dataclasses import dataclass -from typing import cast, Sequence +from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple from torchgen import local from torchgen.api import cpp @@ -50,16 +48,16 @@ class Derivative: original_formula: str # Names of the arguments for which this formula calculates derivatives. - var_names: tuple[str, ...] + var_names: Tuple[str, ...] # Saved inputs that are referenced by the formula. - saved_inputs: tuple[SavedAttribute, ...] + saved_inputs: Tuple[SavedAttribute, ...] # Saved outputs that are referenced by the formula. - saved_outputs: tuple[SavedAttribute, ...] + saved_outputs: Tuple[SavedAttribute, ...] # Gradients that are referenced by name in the formula. - named_gradients: set[str] + named_gradients: Set[str] # Represents a forward formula that calculates forward derivatives @@ -73,17 +71,17 @@ class ForwardDerivative: # Name of the output arguments for which this formula calculates forward # derivatives - var_names: tuple[str, ...] + var_names: Tuple[str, ...] # Type of the output arguments for which this formula calculates forward # derivatives - var_types: tuple[Type, ...] + var_types: Tuple[Type, ...] # Inputs for which the forward derivatives are required for this formula - required_inputs_fw_grad: tuple[str, ...] | None + required_inputs_fw_grad: Optional[Tuple[str, ...]] # Inputs for which the primal is required for this formula - required_inputs_primal: tuple[str, ...] | None + required_inputs_primal: Optional[Tuple[str, ...]] # Flag to specify if this formula requires the original value of self # This is only used by inplace operations @@ -118,7 +116,7 @@ class DifferentiabilityInfo: # The name of the generated autograd function. # It's set only if we will calculate a derivative, i.e. # 'args_with_derivatives' is not empty. - op: str | None + op: Optional[str] # The derivatives formulae for this function. # Note that the length of this sequence is the number of differentiable inputs @@ -140,7 +138,7 @@ class DifferentiabilityInfo: # The named gradients that are used in any of the derivatives. # Invariant: all(name in available_named_gradients for name in used_named_gradients) - used_named_gradients: set[str] + used_named_gradients: Set[str] # The function's input arguments for which it calculates derivatives. # It's the union of 'var_names' of all 'derivatives', sorted by the @@ -151,7 +149,7 @@ class DifferentiabilityInfo: non_differentiable_arg_names: Sequence[str] # Raw data read from derivatives.yaml. - output_differentiability: list[bool] | None + output_differentiability: Optional[List[bool]] # output_differentiability in derivatives.yaml can be a list of # conditions that express if the output is differentiable. In this case, @@ -159,7 +157,7 @@ class DifferentiabilityInfo: # (NB: we only support one condition right now). 
# output_differentiability gets populated with True for each condition, # while output_differentiability_conditions gets populated with the conditions - output_differentiability_conditions: list[str] | None + output_differentiability_conditions: Optional[List[str]] @property def has_derivatives(self) -> bool: @@ -172,7 +170,7 @@ class DifferentiabilityInfo: # See Note [Codegen'd {view}_copy Operators] def create_view_copy_from_view_derivative( self, g: NativeFunctionsViewGroup - ) -> DifferentiabilityInfo | None: + ) -> Optional["DifferentiabilityInfo"]: if g.view_copy is None: return None f = g.view_copy @@ -203,7 +201,7 @@ class DifferentiabilityInfo: ) -def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool: +def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool: if info is None: return False for derivative in info.derivatives: @@ -213,11 +211,11 @@ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool: return False -def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool: +def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool: return uses_ident(info, "retain_variables") -def uses_single_grad(info: DifferentiabilityInfo | None) -> bool: +def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool: return uses_ident(info, "grad") @@ -255,8 +253,8 @@ class DifferentiableOutput: @dataclass(frozen=True) class NativeFunctionWithDifferentiabilityInfo: func: NativeFunction - info: dict[str, DifferentiabilityInfo] | None - fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None + info: Optional[Dict[str, DifferentiabilityInfo]] + fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]] # TODO: Update comment below since it is out of date. @@ -365,19 +363,19 @@ def is_reference_for_foreach( # TODO(crcrpar): Avoid hard coding "Default" ideally. def gen_foreach_derivativeinfo( foreach_function: NativeFunction, - functional_info_by_signature: dict[ - FunctionSchema, dict[str, DifferentiabilityInfo] + functional_info_by_signature: Dict[ + FunctionSchema, Dict[str, DifferentiabilityInfo] ], - non_functional_info_by_signature: dict[ - FunctionSchema, dict[str, DifferentiabilityInfo] + non_functional_info_by_signature: Dict[ + FunctionSchema, Dict[str, DifferentiabilityInfo] ], dispatch_key: str = "Default", -) -> tuple[DifferentiabilityInfo | None, bool]: +) -> Tuple[Optional[DifferentiabilityInfo], bool]: """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place. The second return value indicates whether the info is generated in this function. 
""" - ref_diff_info: DifferentiabilityInfo | None = None + ref_diff_info: Optional[DifferentiabilityInfo] = None for function_schema, diff_info in functional_info_by_signature.items(): if not is_reference_for_foreach(foreach_function, function_schema): @@ -487,13 +485,13 @@ def gen_foreach_derivativeinfo( if arg.name in all_var_names ] - forward_derivatives: list[ForwardDerivative] = [] + forward_derivatives: List[ForwardDerivative] = [] fw_derivative: ForwardDerivative for fw_derivative in ref_diff_info.forward_derivatives: - var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef] - var_types: list[Type] = list(fw_derivative.var_types) - required_inputs_fw_grad: list[str] = [] - required_inputs_primal: list[str] = [] + var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef] + var_types: List[Type] = list(fw_derivative.var_types) + required_inputs_fw_grad: List[str] = [] + required_inputs_primal: List[str] = [] if fw_derivative.required_inputs_fw_grad is not None: required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad) if fw_derivative.required_inputs_primal: @@ -580,9 +578,9 @@ def gen_foreach_derivativeinfo( def match_differentiability_info( - native_functions: list[NativeFunction], - differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], -) -> list[NativeFunctionWithDifferentiabilityInfo]: + native_functions: List[NativeFunction], + differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], +) -> List[NativeFunctionWithDifferentiabilityInfo]: """Sets the "derivative" key on declarations to matching autograd function In-place functions will use the out-of-place derivative definition if there is no in-place specific derivative. @@ -601,7 +599,7 @@ def match_differentiability_info( def find_info( f: NativeFunction, - ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: + ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]: # Don't bother matching info to generated out= variants if "generated" in f.tags and f.func.kind() == SchemaKind.out: return None, False @@ -655,7 +653,7 @@ Attempted to convert a derivative formula for a mutable operator return None, False - result: list[NativeFunctionWithDifferentiabilityInfo] = [] + result: List[NativeFunctionWithDifferentiabilityInfo] = [] for f in native_functions: info_dict, is_exact_match = find_info(f) @@ -679,7 +677,7 @@ Attempted to convert a derivative formula for a mutable operator ) continue - fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {} + fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {} for key, info in info_dict.items(): if not info.forward_derivatives: fw_derivative_dict[key] = [] @@ -715,7 +713,7 @@ Attempted to convert a derivative formula for a mutable operator formula = fw_info.formula def replace_self_with_original_self(formula: str, postfix: str) -> str: - def repl(m: re.Match[str]) -> str: + def repl(m: Match[str]) -> str: return f"{m.group(1)}original_self{postfix}{m.group(2)}" return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula) @@ -736,7 +734,7 @@ Attempted to convert a derivative formula for a mutable operator formula = replace_self_with_original_self(formula, "_t") # replace "result" from the formula by "self_p" - def repl(m: re.Match[str]) -> str: + def repl(m: Match[str]) -> str: return f"{m.group(1)}self_p{m.group(2)}" formula = re.sub(IDENT_REGEX.format("result"), repl, formula) @@ -760,8 +758,8 @@ Attempted to convert a derivative formula for a mutable 
operator # If there is a need, we can relax (2) to allow any op that has an in-place variant is_single_method_on_self_t = False directly_do_inplace = False - op_name: str | None = None - between_parens: str | None = None + op_name: Optional[str] = None + between_parens: Optional[str] = None match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula) if match: op_name, between_parens = match.group(1), match.group(2) @@ -825,7 +823,7 @@ Attempted to convert a derivative formula for a mutable operator def is_differentiable( - name: str, type: Type, info: DifferentiabilityInfo | None + name: str, type: Type, info: Optional[DifferentiabilityInfo] ) -> bool: return type.is_tensor_like() and ( info is None or name not in info.non_differentiable_arg_names @@ -834,10 +832,10 @@ def is_differentiable( def gen_differentiable_outputs( fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" -) -> list[DifferentiableOutput]: +) -> List[DifferentiableOutput]: f = fn.func info = fn.info[key] if fn.info else None - outputs: list[DifferentiableOutput] = [ + outputs: List[DifferentiableOutput] = [ DifferentiableOutput( name=name, type=ret.type, @@ -852,7 +850,7 @@ def gen_differentiable_outputs( f"The length of output_differentiability ({len(output_differentiability)}), " f"does not match the number of outputs ({len(outputs)})." ) - differentiable_outputs: list[DifferentiableOutput] = [] + differentiable_outputs: List[DifferentiableOutput] = [] if False in output_differentiability and f.func.kind() == SchemaKind.inplace: raise RuntimeError( "output_differentiability=False for inplace operation (version_counter won't get updated)" diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index c709364e92b..0e9d67375c7 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import Sequence +from typing import List, Optional, Sequence, Set, Union from torchgen import local from torchgen.api.types import ( @@ -96,7 +94,7 @@ def valuetype_type( binds: ArgName, remove_non_owning_ref_types: bool = False, symint: bool = False, -) -> NamedCType | None: +) -> Optional[NamedCType]: if isinstance(t, BaseType): if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: return None @@ -281,7 +279,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType: def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: - returns: list[str] = [] + returns: List[str] = [] for i, r in enumerate(f.func.returns): # If we have an inplace function, the return argument is # implicitly named self. 
@@ -370,17 +368,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str: def argument( - a: Argument | TensorOptionsArguments | SelfArgument, + a: Union[Argument, TensorOptionsArguments, SelfArgument], *, - cpp_no_default_args: set[str], + cpp_no_default_args: Set[str], method: bool, faithful: bool, symint: bool = False, has_tensor_options: bool, -) -> list[Binding]: +) -> List[Binding]: def sub_argument( - a: Argument | TensorOptionsArguments | SelfArgument, - ) -> list[Binding]: + a: Union[Argument, TensorOptionsArguments, SelfArgument] + ) -> List[Binding]: return argument( a, cpp_no_default_args=cpp_no_default_args, @@ -396,7 +394,7 @@ def argument( binds = SpecialArgName.possibly_redundant_memory_format else: binds = a.name - default: str | None = None + default: Optional[str] = None if a.name not in cpp_no_default_args and a.default is not None: default = default_expr(a.default, a.type, symint=symint) return [ @@ -447,9 +445,9 @@ def arguments( faithful: bool, symint: bool = False, method: bool, - cpp_no_default_args: set[str], -) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + cpp_no_default_args: Set[str], +) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] if faithful: args.extend(arguments.non_out) args.extend(arguments.out) diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index 103e6cf4299..aa3c97b2d34 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -1,7 +1,5 @@ -from __future__ import annotations - import itertools -from typing import Sequence +from typing import List, Sequence, Union from torchgen.api import cpp from torchgen.api.types import ArgName, Binding, CType, NamedCType @@ -78,10 +76,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType: return cpp.returns_type(rs, symint=symint) -def jit_arguments(func: FunctionSchema) -> list[Argument]: +def jit_arguments(func: FunctionSchema) -> List[Argument]: def to_argument( - a: Argument | TensorOptionsArguments | SelfArgument, - ) -> list[Argument]: + a: Union[Argument, TensorOptionsArguments, SelfArgument] + ) -> List[Argument]: if isinstance(a, Argument): return [a] elif isinstance(a, SelfArgument): @@ -116,5 +114,5 @@ def argument( ) -def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]: +def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]: return [argument(a, symint=symint) for a in jit_arguments(func)] diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 93667e39b17..cc492588e60 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import List, Optional from torchgen.api import dispatcher from torchgen.api.types import ( @@ -93,7 +93,7 @@ def name( *, is_reverse: bool, include_namespace: bool, - reapply_views: bool | None = None, + reapply_views: Optional[bool] = None, ) -> str: if reapply_views is None: # reapply_views is only important for the fwd lambda, @@ -124,7 +124,7 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str: return f"{api_name}_inverse" -def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: +def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]: # capture arguments include all arguments except `self`. 
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), # So any reference types (IntArrayRef) need to be converted to value types (vector) @@ -152,14 +152,14 @@ def returns_type(func: FunctionSchema) -> CType: return BaseCType(tensorT) -def outer_arguments(*, is_reverse: bool) -> list[Binding]: +def outer_arguments(*, is_reverse: bool) -> List[Binding]: if is_reverse: return [base_binding, mutated_view_binding, mutated_view_idx_binding] else: return [base_binding, mutated_view_idx_binding] -def inner_call_index(func: FunctionSchema) -> Binding | None: +def inner_call_index(func: FunctionSchema) -> Optional[Binding]: # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. # When we replay a view op that returns multiple tensors, we need to index into the output appropriately if len(func.returns) > 1 or ( @@ -169,7 +169,7 @@ def inner_call_index(func: FunctionSchema) -> Binding | None: return None -def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]: args = func.arguments.flat_all assert args[0].type == BaseType(BaseTy.Tensor) non_self_args = args[1:] diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index cfffa516b65..166c2fc8b53 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import Any +from typing import Any, Dict, List, Optional, Tuple, Union from torchgen.api.types import ( BaseCppType, @@ -36,7 +34,7 @@ from torchgen.model import ( ) -_valueT: BaseCppType | None = None +_valueT: Optional[BaseCppType] = None # A ValueT is an IR type which represents the computation of a Tensor. In other @@ -68,8 +66,8 @@ tensorListValueT = BaseCppType("torch::lazy", "Value") def process_ir_type( - typ: Type, properties: LazyIrProperties, *, symint: bool -) -> BaseCType | VectorCType | OptionalCType | ListCType: + typ: Type, properties: "LazyIrProperties", *, symint: bool +) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]: """ This function takes a type from NativeFunctions and converts it for use with lazy tensor codegen. @@ -149,7 +147,7 @@ def process_ir_type( # # Invariant: passed typ should be an *owning* CType (e.g., we will report # that ArrayRef is NOT a value type) -def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool: +def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool: """ Given a type, determine if it is a Value-like type. This is equivalent to being Tensor-like, but assumes the type has already been transformed. @@ -204,7 +202,7 @@ def isGeneratorType(typ: Type) -> bool: class LazyArgument: name: str orig_type: Type - lazy_type_: CType | None + lazy_type_: Optional[CType] is_wrapped_scalar: bool is_generator: bool # TODO: this is lies, it is false for symint list @@ -216,9 +214,7 @@ class LazyArgument: # true if this argument is or contains a lazy IR value is_lazy_value: bool - def __init__( - self, arg: Argument, properties: LazyIrProperties, *, symint: bool - ) -> None: + def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool): self.name = arg.name self.orig_type = arg.type self.symint = symint @@ -252,7 +248,7 @@ class LazyIrProperties: attributes. The mutual exclusivity is automatically handled. """ - Properties: tuple[tuple[str, ...], ...] = ( + Properties: Tuple[Tuple[str, ...], ...] 
= ( ( "ShapePrecompute", # Assume shape has been precomputed "ShapeCompute", # Need to compute the shape on construction @@ -275,8 +271,8 @@ class LazyIrProperties: ), ) - def __init__(self, *default_properties: str) -> None: - properties: dict[tuple[str, ...], str | None] = dict.fromkeys( + def __init__(self, *default_properties: str): + properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys( LazyIrProperties.Properties ) self.__dict__["properties"] = properties @@ -309,17 +305,17 @@ class LazyIrProperties: # TODO: This is not idiomatic with how other torchgen APIs transform on schema. class LazyIrSchema: # The name of the operator this function schema describes. - name: OperatorName + name: "OperatorName" - positional_args: tuple[LazyArgument, ...] - keyword_args: tuple[LazyArgument, ...] + positional_args: Tuple[LazyArgument, ...] + keyword_args: Tuple[LazyArgument, ...] # TODO: Need to handle collisions with argument names at some point - returns: tuple[Return, ...] + returns: Tuple["Return", ...] # if this schema has a Generator arg, list its orig ctype/name but don't # build a LazyArgument since lazy IR doesn't support it - generator_arg: NamedCType | None = None + generator_arg: Optional[NamedCType] = None # original function schema func: FunctionSchema @@ -333,21 +329,21 @@ class LazyIrSchema: "Lower", "CanBeReused", ) - opkind: str | None = None + opkind: Optional[str] = None def __init__( self, func: FunctionSchema, - properties: LazyIrProperties | None = None, + properties: Optional[LazyIrProperties] = None, *, symint: bool, - ) -> None: + ): if properties: self.properties = properties self.func = func self.symint = symint - positional_args: list[LazyArgument] = [] + positional_args: List[LazyArgument] = [] for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: if arg_field == "self_arg" and func.arguments.self_arg is not None: arg = func.arguments.self_arg.argument @@ -361,7 +357,7 @@ class LazyIrSchema: ) self.positional_args = tuple(positional_args) - keyword_args: list[LazyArgument] = [] + keyword_args: List[LazyArgument] = [] for arg_field in [ "pre_tensor_options_kwarg_only", "tensor_options", @@ -415,13 +411,13 @@ class LazyIrSchema: values: bool = True, scalars: bool = True, generator: bool = True, - ) -> list[LazyArgument]: + ) -> List[LazyArgument]: # This function maintains the sorted order of arguments but provides different filtered views. # Some parts of the code care about kwargs vs args (TS lowerings), # other parts care about whether they need to wrap the arg in a lazy value or leave it alone. # Generators are special cased, as they are needed for fallback/shape-inference but not supported # in TS lowerings and therefore also omitted from lazy IR. 
- args: list[LazyArgument] = [] + args: List[LazyArgument] = [] if positional: args.extend(self.positional_args) if keyword: @@ -443,25 +439,25 @@ class LazyIrSchema: return [] @property - def positional_values(self) -> list[LazyArgument]: + def positional_values(self) -> List[LazyArgument]: return self.filtered_args( positional=True, keyword=False, values=True, scalars=False ) @property - def positional_scalars(self) -> list[LazyArgument]: + def positional_scalars(self) -> List[LazyArgument]: return self.filtered_args( positional=True, keyword=False, values=False, scalars=True ) @property - def keyword_values(self) -> list[LazyArgument]: + def keyword_values(self) -> List[LazyArgument]: return self.filtered_args( positional=False, keyword=True, values=True, scalars=False ) @property - def keyword_scalars(self) -> list[LazyArgument]: + def keyword_scalars(self) -> List[LazyArgument]: return self.filtered_args( positional=False, keyword=True, values=False, scalars=True ) diff --git a/torchgen/api/native.py b/torchgen/api/native.py index a00e8266b8d..df06b539d5e 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import Sequence +from typing import List, Optional, Sequence, Union from torchgen import local from torchgen.api import cpp @@ -83,11 +81,11 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType: def argument( - a: Argument | SelfArgument | TensorOptionsArguments, + a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool, symint: bool, -) -> list[Binding]: +) -> List[Binding]: # Ideally, we NEVER default native functions. However, there are a number # of functions that call native:: directly and rely on the defaulting # existing. So for BC, we generate defaults for non-out variants (but not @@ -95,7 +93,7 @@ def argument( # default) should_default = not is_out if isinstance(a, Argument): - default: str | None = None + default: Optional[str] = None if should_default and a.default is not None: default = cpp.default_expr(a.default, a.type, symint=symint) return [ @@ -146,8 +144,8 @@ def argument( assert_never(a) -def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] +def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] args.extend(func.arguments.non_out) args.extend(func.arguments.out) return [ diff --git a/torchgen/api/python.py b/torchgen/api/python.py index aa942572e98..2026c40f08b 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union from torchgen.api import cpp from torchgen.api.types import Binding, CppSignature, CppSignatureGroup @@ -199,14 +197,14 @@ from torchgen.model import ( @dataclass(frozen=True) class PythonReturns: - returns: tuple[Return, ...] + returns: Tuple[Return, ...] 
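Every hunk in this patch makes the same substitution: the PEP 585/604 spellings (list[str], str | None, tuple[X, ...]) become their typing-module equivalents (List[str], Optional[str], Tuple[X, ...]) once from __future__ import annotations is removed from each file. Without that future import, annotations are evaluated eagerly at definition time, and the newer syntax raises TypeError on Python 3.8. A minimal sketch of the failure mode, not taken from the patch:

from typing import List, Optional

# Fine on Python 3.8+: typing generics are subscriptable at runtime.
def head(xs: List[str]) -> Optional[str]:
    return xs[0] if xs else None

# Without `from __future__ import annotations`, the PEP 585/604 form
# below is evaluated when the `def` executes and raises TypeError on
# Python 3.8 ("'type' object is not subscriptable"):
#
#     def head(xs: list[str]) -> str | None: ...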
@dataclass(frozen=True) class PythonArgument: name: str type: Type - default: str | None + default: Optional[str] # Used to generate the default init expr for some PythonArgParser outputs, e.g.: # @@ -214,7 +212,7 @@ class PythonArgument: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ^ # +--- default_init str - default_init: str | None + default_init: Optional[str] # Compute argument formal for python argument parsing. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. @@ -302,10 +300,12 @@ class PythonOutArgument(PythonArgument): # 'auto out = _r.tensorlist_n<2>(2);', # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc. # TODO: maybe don't need keep scattered out fields for python signature? - outputs: tuple[PythonArgument, ...] + outputs: Tuple[PythonArgument, ...] @staticmethod - def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None: + def from_outputs( + outputs: Tuple[PythonArgument, ...] + ) -> Optional["PythonOutArgument"]: if not outputs: return None @@ -339,13 +339,13 @@ class PythonSignature: # Positional arguments. # TODO: create a dedicated SelfArgument type for 'self'? - input_args: tuple[PythonArgument, ...] + input_args: Tuple[PythonArgument, ...] # Keyword arguments excluding the 'out' argument and scattered kwargs belonging # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc). - input_kwargs: tuple[PythonArgument, ...] + input_kwargs: Tuple[PythonArgument, ...] - output_args: PythonOutArgument | None + output_args: Optional[PythonOutArgument] # Return types, which are only used by pyi returns: PythonReturns @@ -356,7 +356,7 @@ class PythonSignature: # for out variant), in which case they will be used as scattered fields without # being packed into 'options'. # TODO: maybe create a PythonTensorOptionsArgument? - tensor_options_args: tuple[PythonArgument, ...] + tensor_options_args: Tuple[PythonArgument, ...] # method or function signature? method: bool @@ -367,8 +367,8 @@ class PythonSignature: def arguments( self, *, skip_outputs: bool = False, skip_tensor_options: bool = False - ) -> tuple[PythonArgument | PythonOutArgument, ...]: - result: list[PythonArgument | PythonOutArgument] = [] + ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]: + result: List[Union[PythonArgument, PythonOutArgument]] = [] result.extend(self.input_args) result.extend(self.input_kwargs) if self.output_args is not None and not skip_outputs: @@ -394,7 +394,7 @@ class PythonSignature: # signature_str_pyi(). def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: args = self.arguments(skip_outputs=skip_outputs) - schema_formals: list[str] = [ + schema_formals: List[str] = [ a.argument_str(method=self.method, symint=symint) for a in args ] positional_argc = len(self.input_args) @@ -405,7 +405,7 @@ class PythonSignature: def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: args = self.arguments(skip_outputs=skip_outputs) - schema_formals: list[str] = [ + schema_formals: List[str] = [ a.argument_str_pyi(method=self.method) for a in args ] positional_argc = len(self.input_args) @@ -419,10 +419,10 @@ class PythonSignature: schema_formals.insert(0, "self") return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' 
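The quoted class names appearing above (e.g. Optional["PythonOutArgument"]) are the other consequence of eager annotation evaluation: a class body or method signature cannot name the class still being defined unless the reference is a string. A minimal sketch with a hypothetical Node class, not taken from the patch:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Node:
    value: int
    # "Node" is not bound yet while the class body executes, so the
    # self-reference must be quoted once annotations are eager.
    next: Optional["Node"] = None

    def push(self, value: int) -> "Node":
        return Node(value, self)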
- def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: # only pyi uses vararg signatures args = self.arguments(skip_outputs=skip_outputs) - schema_formals: list[str] = [ + schema_formals: List[str] = [ a.argument_str_pyi(method=self.method) for a in args ] # vararg only applies to pyi signatures. vararg variants are not generated for all signatures @@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature): # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor # [func call]: self.addmm(mat1, mat2, beta, 1) # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case. - deprecated_args_exprs: tuple[str, ...] + deprecated_args_exprs: Tuple[str, ...] @property def deprecated(self) -> bool: @@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature): def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: args = self.arguments(skip_outputs=skip_outputs) - schema_formals: list[str] = [ + schema_formals: List[str] = [ a.argument_str_pyi(method=self.method, deprecated=True) for a in args ] positional_argc = len(self.input_args) @@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature): returns_str = returns_str_pyi(self) return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' - def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: # the codegen doesn't include vararg variants for deprecated signatures return None @@ -530,14 +530,14 @@ class PythonSignatureGroup: base: NativeFunction # The out variant (e.g. conv2d_out) - outplace: NativeFunction | None + outplace: Optional[NativeFunction] @classmethod def from_pairs( cls, functional: PythonSignatureNativeFunctionPair, - out: PythonSignatureNativeFunctionPair | None, - ) -> PythonSignatureGroup: + out: Optional[PythonSignatureNativeFunctionPair], + ) -> "PythonSignatureGroup": if out is None: return PythonSignatureGroup( signature=functional.signature, @@ -716,7 +716,7 @@ def argument_type_str( raise RuntimeError(f"unrecognized type {repr(t)}") -def argument_type_size(t: Type) -> int | None: +def argument_type_size(t: Type) -> Optional[int]: l = t.is_list_like() if l is not None and str(l.elem) != "bool": return l.size @@ -750,11 +750,11 @@ def signature( def signature_from_schema( func: FunctionSchema, *, - category_override: str | None, + category_override: Optional[str], method: bool = False, pyi: bool = False, ) -> PythonSignature: - args: list[Argument] = [] + args: List[Argument] = [] args.extend(func.arguments.pre_self_positional) # Skip SelfArgument if this is method. 
if not method and func.arguments.self_arg is not None: @@ -807,10 +807,10 @@ def signature_from_schema( ) is_dummy_function = category_override == "dummy" - tensor_options_args: list[PythonArgument] = [] + tensor_options_args: List[PythonArgument] = [] if (is_factory_function or is_like_or_new_function) and not is_dummy_function: - def topt_default_init(name: str) -> str | None: + def topt_default_init(name: str) -> Optional[str]: topt_args = func.arguments.tensor_options if topt_args is None: return None @@ -891,7 +891,7 @@ def signature_from_schema( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: +def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]: if len(returns) <= 1 or all(r.name is None for r in returns): return [] else: @@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str: return argument_type_str_pyi(t) -def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: +def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]: python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] structseq_name = signature.name field_names = structseq_fieldnames(signature.returns.returns) @@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str: def dispatch_lambda_args( ps: PythonSignature, f: NativeFunction, symint: bool = True -) -> tuple[DispatchLambdaArgument, ...]: +) -> Tuple[DispatchLambdaArgument, ...]: if isinstance(ps, PythonSignatureDeprecated): schema = ps.deprecated_schema else: @@ -1118,7 +1118,7 @@ def dispatch_lambda_args( method=False, cpp_no_default_args=f.cpp_no_default_args, ) - out_args: set[str] = {a.name for a in schema.arguments.out} + out_args: Set[str] = {a.name for a in schema.arguments.out} # Convert from cpp argument to lambda argument def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: @@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str: def cpp_dispatch_exprs( f: NativeFunction, *, - python_signature: PythonSignature | None = None, -) -> tuple[str, ...]: + python_signature: Optional[PythonSignature] = None, +) -> Tuple[str, ...]: cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() - exprs: tuple[str, ...] = tuple() + exprs: Tuple[str, ...] = tuple() if not isinstance(python_signature, PythonSignatureDeprecated): # By default the exprs are consistent with the C++ signature. exprs = tuple(a.name for a in cpp_args) @@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs( # For certain cases it is intentionally more restrictive than necessary, # e.g.: it doesn't accepts doublelist with definite size. def arg_parser_unpack_method( - t: Type, default: str | None, default_init: str | None, *, symint: bool = True + t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True ) -> str: has_default_init = default_init is not None if has_default_init and str(t) not in ( @@ -1377,7 +1377,7 @@ def arg_parser_output_expr( # Returns a map with key = arg_name and value = PythonArgParserOutputExpr. def arg_parser_output_exprs( ps: PythonSignature, f: NativeFunction, *, symint: bool = True -) -> dict[str, PythonArgParserOutputExpr]: +) -> Dict[str, PythonArgParserOutputExpr]: return { e.name: e for i, a in enumerate(ps.arguments()) @@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs( # outputs. 
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) lambda_args = dispatch_lambda_args(ps, f, symint=symint) - inits: list[str] = [] - lambda_args_exprs: dict[str, str] = {} + inits: List[str] = [] + lambda_args_exprs: Dict[str, str] = {} has_toptions = has_tensor_options(f) diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index a93d666114d..e3be72189bb 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import List, Union from torchgen.api import cpp from torchgen.api.types import ( @@ -97,7 +97,7 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: # Structured kernels are never defaulted -def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]: +def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]: if isinstance(a, Argument): return [ Binding( @@ -115,15 +115,15 @@ def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Bindin assert_never(a) -def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] +def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] if g.out.precomputed: # A list of parameters for the impl function with # certain parameters replaced with precomputed counterparts # as specified in native_functions.yaml. - non_out_args_replaced: list[ - Argument | TensorOptionsArguments | SelfArgument + non_out_args_replaced: List[ + Union[Argument, TensorOptionsArguments, SelfArgument] ] = [] for a in g.out.func.arguments.non_out: if isinstance(a, Argument) and a.name in g.out.precomputed.replace: @@ -145,13 +145,13 @@ def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: return [r for arg in args for r in argument(arg)] -def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] +def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] args.extend(g.functional.func.arguments.non_out) return [r for arg in args for r in argument(arg)] -def out_arguments(g: NativeFunctionsGroup) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] +def out_arguments(g: NativeFunctionsGroup) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] args.extend(g.out.func.arguments.out) return [r for arg in args for r in argument(arg)] diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 761fb3c7c2b..87fc3348b69 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import NoReturn, Sequence +from typing import Dict, List, NoReturn, Sequence, Union from torchgen.api.types import ( ArrayRefCType, @@ -97,13 +95,13 @@ class UnsatError(RuntimeError): # something more complicated, e.g., tracking the set of bindings in a context, # you may find using these smaller types more convenient. 
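The comment block above describes the flexibility that translate(), whose updated signature follows, gets from its Union-typed parameters: callers may pass bare Exprs or full Bindings, and goals may be NamedCTypes or Bindings. A hedged sketch of a direct call, with the constructor arguments inferred from their uses elsewhere in this patch rather than copied from it:

from torchgen.api.translate import translate
from torchgen.api.types import BaseCType, Expr, NamedCType, longT

dim = NamedCType("dim", BaseCType(longT))
bindings = [Expr("dim", dim)]      # what we have: a bare Expr is accepted
[e] = translate(bindings, [dim])   # what we want: a NamedCType goal
assert e.expr == "dim"             # the goal was satisfied verbatim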
def translate( - bindings: Sequence[Expr | Binding], - goals: Sequence[NamedCType | Binding], + bindings: Sequence[Union[Expr, Binding]], + goals: Sequence[Union[NamedCType, Binding]], *, method: bool = False, allow_expensive_conversions: bool = False, -) -> list[Expr]: - binding_exprs: list[Expr] = [] +) -> List[Expr]: + binding_exprs: List[Expr] = [] for b in bindings: if isinstance(b, Binding): binding_exprs.append( @@ -115,7 +113,7 @@ def translate( else: binding_exprs.append(b) - goal_ctypes: list[NamedCType] = [] + goal_ctypes: List[NamedCType] = [] for g in goals: if isinstance(g, Binding): goal_ctypes.append(g.nctype) @@ -123,7 +121,7 @@ def translate( goal_ctypes.append(g) # Add all the bindings to the context - ctx: dict[NamedCType, str] = {} + ctx: Dict[NamedCType, str] = {} for b in binding_exprs: ctx[b.type] = b.expr diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index f7d85ca6e2f..0b7abe00012 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -1,19 +1,14 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import Iterator, Sequence, TYPE_CHECKING +from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union from torchgen.api.types.types_base import Binding, CType, Expr - - -if TYPE_CHECKING: - from torchgen.model import ( - BackendIndex, - FunctionSchema, - NativeFunction, - NativeFunctionsGroup, - NativeFunctionsViewGroup, - ) +from torchgen.model import ( + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, +) @dataclass(frozen=True) @@ -43,7 +38,7 @@ class CppSignature: symint: bool # The set of C++ arguments which should not have defaults applied to them - cpp_no_default_args: set[str] + cpp_no_default_args: Set[str] # Is this a fallback C++ binding? 
Fallback bindings are enabled by # manual_cpp_binding: True and are alternate, non-public API that @@ -77,7 +72,7 @@ class CppSignature: def decl( self, *, - name: str | None = None, + name: Optional[str] = None, prefix: str = "", is_redispatching_fn: bool = False, suppress_symint_suffix: bool = False, @@ -98,7 +93,7 @@ class CppSignature: def defn( self, *, - name: str | None = None, + name: Optional[str] = None, prefix: str = "", is_redispatching_fn: bool = False, ) -> str: @@ -131,9 +126,9 @@ class CppSignature: class CppSignatureGroup: func: FunctionSchema signature: CppSignature - faithful_signature: CppSignature | None - symint_signature: CppSignature | None - symint_faithful_signature: CppSignature | None + faithful_signature: Optional[CppSignature] + symint_signature: Optional[CppSignature] + symint_faithful_signature: Optional[CppSignature] def most_faithful_signature(self) -> CppSignature: if self.faithful_signature: @@ -154,7 +149,7 @@ class CppSignatureGroup: @staticmethod def from_native_function( f: NativeFunction, *, method: bool, fallback_binding: bool = False - ) -> CppSignatureGroup: + ) -> "CppSignatureGroup": func = f.func def make_sig(*, faithful: bool, symint: bool) -> CppSignature: @@ -167,16 +162,16 @@ class CppSignatureGroup: cpp_no_default_args=f.cpp_no_default_args, ) - def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]: - faithful_signature: CppSignature | None = None + def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]: + faithful_signature: Optional[CppSignature] = None if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: faithful_signature = make_sig(faithful=True, symint=symint) signature = make_sig(faithful=False, symint=symint) return signature, faithful_signature signature, faithful_signature = make_sigs(symint=False) - symint_signature: CppSignature | None = None - symint_faithful_signature: CppSignature | None = None + symint_signature: Optional[CppSignature] = None + symint_faithful_signature: Optional[CppSignature] = None if func.has_symint(): symint_signature, symint_faithful_signature = make_sigs(symint=True) @@ -201,20 +196,20 @@ class DispatcherSignature: symint: bool = True - def arguments(self) -> list[Binding]: + def arguments(self) -> List[Binding]: return dispatcher.arguments(self.func, symint=self.symint) def name(self) -> str: return self.prefix + dispatcher.name(self.func) - def decl(self, name: str | None = None) -> str: + def decl(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.decl() for a in self.arguments()) if name is None: name = self.name() return f"{self.returns_type().cpp_type()} {name}({args_str})" def defn( - self, name: str | None = None, *, is_redispatching_fn: bool = False + self, name: Optional[str] = None, *, is_redispatching_fn: bool = False ) -> str: args = [a.defn() for a in self.arguments()] if is_redispatching_fn: @@ -224,7 +219,7 @@ class DispatcherSignature: name = self.name() return f"{self.returns_type().cpp_type()} {name}({args_str})" - def exprs(self) -> list[Expr]: + def exprs(self) -> List[Expr]: return [Expr(a.name, a.nctype) for a in self.arguments()] def returns_type(self) -> CType: @@ -242,7 +237,7 @@ class DispatcherSignature: @staticmethod def from_schema( func: FunctionSchema, *, prefix: str = "", symint: bool = True - ) -> DispatcherSignature: + ) -> "DispatcherSignature": return DispatcherSignature(func, prefix, symint) @@ -258,13 +253,13 @@ class NativeSignature: def name(self) -> str: return self.prefix + 
native.name(self.func) - def decl(self, name: str | None = None) -> str: + def decl(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.decl() for a in self.arguments()) if name is None: name = self.name() return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" - def defn(self, name: str | None = None) -> str: + def defn(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.defn() for a in self.arguments()) if name is None: name = self.name() @@ -275,13 +270,13 @@ class NativeSignature: args_str = ", ".join(a.defn() for a in self.arguments()) return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" - def arguments(self) -> list[Binding]: + def arguments(self) -> List[Binding]: return native.arguments(self.func, symint=self.symint) def returns_type(self) -> CType: return native.returns_type(self.func.returns, symint=self.symint) - def dispatcher_exprs(self) -> list[Expr]: + def dispatcher_exprs(self) -> List[Expr]: return translate.translate( self.arguments(), dispatcher.arguments(self.func), method=False ) @@ -312,7 +307,7 @@ class FunctionalizationLambda: # are we generating the forward lambda or the reverse lambda? is_reverse: bool - def captures(self) -> list[Expr]: + def captures(self) -> List[Expr]: # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, # and plumb it into the lambda. @@ -341,7 +336,7 @@ class FunctionalizationLambda: ] return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" - def inner_call(self, *, reapply_views: bool | None = None) -> str: + def inner_call(self, *, reapply_views: Optional[bool] = None) -> str: inner_call_name = functionalization.name( self.g, is_reverse=self.is_reverse, @@ -371,7 +366,7 @@ class FunctionalizationLambda: @staticmethod def from_func( g: NativeFunctionsViewGroup, *, is_reverse: bool - ) -> FunctionalizationLambda: + ) -> "FunctionalizationLambda": return FunctionalizationLambda(g, is_reverse) @@ -380,11 +375,11 @@ class StructuredImplSignature: g: NativeFunctionsGroup name: str - def defn(self, name: str | None = None) -> str: + def defn(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.defn() for a in self.arguments()) return f"TORCH_IMPL_FUNC({self.name})({args_str})" - def arguments(self) -> list[Binding]: + def arguments(self) -> List[Binding]: return structured.impl_arguments(self.g) @@ -393,7 +388,7 @@ class StructuredImplSignature: def kernel_signature( f: NativeFunction, backend_index: BackendIndex, *, prefix: str = "" -) -> NativeSignature | DispatcherSignature: +) -> Union["NativeSignature", "DispatcherSignature"]: # Note [External Backends Follow Dispatcher API] # Kernel signatures for in-tree backends follow the "native" API, # while kernels for out-of-tree backends follow the dispatcher API. diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py index 30e027a6312..3f0a90c634f 100644 --- a/torchgen/api/types/types.py +++ b/torchgen/api/types/types.py @@ -12,10 +12,8 @@ if we want to generate code for another C++ library. Add new types to `types.py` if these types are ATen/c10 related. Add new types to `types_base.py` if they are basic and not attached to ATen/c10. 
""" - -from __future__ import annotations - from dataclasses import dataclass +from typing import Dict from torchgen.api.types.types_base import ( BaseCppType, @@ -85,7 +83,7 @@ symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef") scalar_t = BaseCppType("", "scalar_t") opmath_t = BaseCppType("", "opmath_t") -ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { +ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = { ScalarType.Byte: byteT, ScalarType.Char: charT, ScalarType.Short: shortT, @@ -104,7 +102,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, } -BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { +BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTy.int: longT, BaseTy.float: doubleT, BaseTy.bool: boolT, @@ -130,7 +128,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { @dataclass(frozen=True) class OptionalCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. @@ -139,13 +137,13 @@ class OptionalCType(CType): def cpp_type_registration_declarations(self) -> str: return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return OptionalCType(self.elem.remove_const_ref()) @dataclass(frozen=True) class ListCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. @@ -154,13 +152,13 @@ class ListCType(CType): def cpp_type_registration_declarations(self) -> str: return f"c10::List<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return ListCType(self.elem.remove_const_ref()) @dataclass(frozen=True) class ArrayRefCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. @@ -169,7 +167,7 @@ class ArrayRefCType(CType): def cpp_type_registration_declarations(self) -> str: return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return ArrayRefCType(self.elem.remove_const_ref()) @@ -187,5 +185,5 @@ class VectorizedCType(CType): def cpp_type_registration_declarations(self) -> str: raise NotImplementedError - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return self diff --git a/torchgen/api/types/types_base.py b/torchgen/api/types/types_base.py index e031b79485e..e59a4b3d820 100644 --- a/torchgen/api/types/types_base.py +++ b/torchgen/api/types/types_base.py @@ -12,17 +12,12 @@ if we want to generate code for another C++ library. Add new types to `types.py` if these types are ATen/c10 related. Add new types to `types_base.py` if they are basic and not attached to ATen/c10. """ - -from __future__ import annotations - from abc import ABC, abstractmethod from dataclasses import dataclass from enum import auto, Enum -from typing import TYPE_CHECKING, Union +from typing import List, Optional, Union - -if TYPE_CHECKING: - from torchgen.model import Argument, SelfArgument, TensorOptionsArguments +from torchgen.model import Argument, SelfArgument, TensorOptionsArguments # An ArgName is just the str name of the argument in schema; @@ -41,7 +36,7 @@ ArgName = Union[str, SpecialArgName] # This class shouldn't be created directly; instead, use/create one of the singletons below. 
@dataclass(frozen=True) class BaseCppType: - ns: str | None + ns: Optional[str] name: str def __str__(self) -> str: @@ -76,7 +71,7 @@ class CType(ABC): raise NotImplementedError @abstractmethod - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return self @@ -92,13 +87,13 @@ class BaseCType(CType): def cpp_type_registration_declarations(self) -> str: return str(self.type).replace("at::", "") - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return self @dataclass(frozen=True) class ConstRefCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: if strip_ref: @@ -108,13 +103,13 @@ class ConstRefCType(CType): def cpp_type_registration_declarations(self) -> str: return f"const {self.elem.cpp_type_registration_declarations()} &" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return self.elem.remove_const_ref() @dataclass(frozen=True) class VectorCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. @@ -123,13 +118,13 @@ class VectorCType(CType): def cpp_type_registration_declarations(self) -> str: return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return VectorCType(self.elem.remove_const_ref()) @dataclass(frozen=True) class ArrayCType(CType): - elem: CType + elem: "CType" size: int def cpp_type(self, *, strip_ref: bool = False) -> str: @@ -139,13 +134,13 @@ class ArrayCType(CType): def cpp_type_registration_declarations(self) -> str: return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return ArrayCType(self.elem.remove_const_ref(), self.size) @dataclass(frozen=True) class TupleCType(CType): - elems: list[CType] + elems: List["CType"] def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. 
@@ -154,13 +149,13 @@ class TupleCType(CType): def cpp_type_registration_declarations(self) -> str: return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>' - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return TupleCType([e.remove_const_ref() for e in self.elems]) @dataclass(frozen=True) class MutRefCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: if strip_ref: @@ -170,7 +165,7 @@ class MutRefCType(CType): def cpp_type_registration_declarations(self) -> str: return f"{self.elem.cpp_type_registration_declarations()} &" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return self.elem.remove_const_ref() @@ -195,10 +190,10 @@ class NamedCType: def cpp_type_registration_declarations(self) -> str: return self.type.cpp_type_registration_declarations() - def remove_const_ref(self) -> NamedCType: + def remove_const_ref(self) -> "NamedCType": return NamedCType(self.name, self.type.remove_const_ref()) - def with_name(self, name: str) -> NamedCType: + def with_name(self, name: str) -> "NamedCType": return NamedCType(name, self.type) @@ -213,11 +208,11 @@ class NamedCType: class Binding: name: str nctype: NamedCType - argument: Argument | TensorOptionsArguments | SelfArgument + argument: Union[Argument, TensorOptionsArguments, SelfArgument] # TODO: maybe don't represent default here - default: str | None = None + default: Optional[str] = None - def rename(self, name: str) -> Binding: + def rename(self, name: str) -> "Binding": return Binding( name=name, nctype=self.nctype, @@ -229,7 +224,7 @@ class Binding: def type(self) -> str: return self.nctype.cpp_type() - def no_default(self) -> Binding: + def no_default(self) -> "Binding": return Binding( name=self.name, nctype=self.nctype, @@ -260,7 +255,7 @@ class Binding: def defn(self) -> str: return f"{self.type} {self.name}" - def with_name(self, name: str) -> Binding: + def with_name(self, name: str) -> "Binding": return Binding( name=name, nctype=self.nctype, argument=self.argument, default=self.default ) diff --git a/torchgen/api/ufunc.py b/torchgen/api/ufunc.py index 17adcccecab..7981c2b29d7 100644 --- a/torchgen/api/ufunc.py +++ b/torchgen/api/ufunc.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass +from typing import List, Optional import torchgen.api.types as api_types from torchgen.api import cpp, structured @@ -39,7 +38,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: # argument registers) # # NB: used for CPU only -def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None: +def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]: # Dispatch stubs are always plain ints r = cpp.valuetype_type(t, binds=binds, symint=False) if r is not None: @@ -135,8 +134,8 @@ def ufunc_argument(a: Argument, compute_t: CType) -> Binding: @dataclass(frozen=True) class UfunctorBindings: - ctor: list[Binding] - apply: list[Binding] + ctor: List[Binding] + apply: List[Binding] # ufunctors are a CUDA-only concept representing functors that take some of @@ -157,7 +156,7 @@ class UfunctorBindings: # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers # to the operator() definition def ufunctor_arguments( - g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType + g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType ) -> UfunctorBindings: 
ctor = [] apply = [] @@ -186,7 +185,7 @@ def ufunctor_arguments( # } # # In this file, we refer to T as compute_t which is bound by caller -def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]: +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]: return [ ufunc_argument(a, compute_t=compute_t) for a in g.functional.func.arguments.flat_non_out @@ -198,7 +197,7 @@ def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Bindin # # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); -def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]: +def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]: # stubs drop all tensor arguments (they are implicit in the TensorIterator # argument and keep everything else) return [ diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index 1e649b75178..70128b1845b 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import List, Tuple from torchgen.api import cpp from torchgen.api.types import Binding, CppSignatureGroup, CType @@ -103,7 +103,7 @@ def name(f: NativeFunction) -> str: # Convert all the arguments in a NativeFunction to C++ code -def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: +def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]: # we need the 'self' argument so method needs to be False args = ( CppSignatureGroup.from_native_function(f, method=False) @@ -138,7 +138,7 @@ def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: # (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType def argumenttype_ivalue_convert( t: Type, arg_name: str, *, mutable: bool = False -) -> tuple[str, CType, list[str], list[str]]: +) -> Tuple[str, CType, List[str], List[str]]: # Unboxing is for mobile, which doesn't care about SymInts ctype = cpp.argumenttype_type( t=t, mutable=mutable, binds=arg_name, symint=False @@ -172,7 +172,7 @@ def argumenttype_ivalue_convert( def _gen_code_base_type( arg_name: str, out_name: str, ctype: CType -) -> tuple[list[str], list[str]]: +) -> Tuple[List[str], List[str]]: return [ f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" ], [] @@ -180,7 +180,7 @@ def _gen_code_base_type( def _gen_code_optional_type( arg_name: str, out_name: str, t: OptionalType, ctype: CType -) -> tuple[list[str], list[str]]: +) -> Tuple[List[str], List[str]]: in_name = f"{arg_name}_opt_in" res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) return ( @@ -203,7 +203,7 @@ if ({arg_name}_opt.has_value()) {{ def _gen_code_list_type( arg_name: str, out_name: str, t: ListType, ctype: CType -) -> tuple[list[str], list[str]]: +) -> Tuple[List[str], List[str]]: in_name = f"{arg_name}_list_in" elem_name = f"{arg_name}_elem" code = [f"const c10::List {in_name} = {arg_name}.toList();"] diff --git a/torchgen/code_template.py b/torchgen/code_template.py index cdb86a48064..b4afde2d7be 100644 --- a/torchgen/code_template.py +++ b/torchgen/code_template.py @@ -1,7 +1,5 @@ -from __future__ import annotations - import re -from typing import Mapping, Sequence +from typing import Mapping, Match, Optional, Sequence # match $identifier or ${identifier} and replace with value in env @@ -22,7 +20,7 @@ class CodeTemplate: 
filename: str @staticmethod - def from_file(filename: str) -> CodeTemplate: + def from_file(filename: str) -> "CodeTemplate": with open(filename) as f: return CodeTemplate(f.read(), filename) @@ -31,7 +29,7 @@ class CodeTemplate: self.filename = filename def substitute( - self, env: Mapping[str, object] | None = None, **kwargs: object + self, env: Optional[Mapping[str, object]] = None, **kwargs: object ) -> str: if env is None: env = {} @@ -45,7 +43,7 @@ class CodeTemplate: [indent + l + "\n" for e in v for l in str(e).splitlines()] ).rstrip() - def replace(match: re.Match[str]) -> str: + def replace(match: Match[str]) -> str: indent = match.group(1) key = match.group(2) comma_before = "" diff --git a/torchgen/context.py b/torchgen/context.py index a2031049816..40e765a97ec 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -1,8 +1,6 @@ -from __future__ import annotations - import contextlib import functools -from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union import torchgen.local as local from torchgen.model import ( @@ -40,7 +38,7 @@ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction]) @contextlib.contextmanager def native_function_manager( - g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction, + g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction] ) -> Iterator[None]: if isinstance(g, NativeFunctionsGroup): # By default, we associate all errors with structured native functions @@ -120,10 +118,10 @@ def with_native_function_and_index( # Convenience decorator for functions that explicitly take in a Dict of BackendIndices def with_native_function_and_indices( - func: Callable[[F, dict[DispatchKey, BackendIndex]], T] -) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]: + func: Callable[[F, Dict[DispatchKey, BackendIndex]], T] +) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]: @functools.wraps(func) - def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T: + def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T: with native_function_manager(f): return func(f, backend_indices) diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index 976c823a165..9cd3dd419fc 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import itertools from abc import ABC from dataclasses import dataclass -from typing import Any +from typing import Any, Dict, List, Optional, Tuple, Union import torchgen.api.dispatcher as dispatcher from torchgen.api.lazy import ( @@ -111,7 +109,7 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str: def gen_fallback_code( schema: LazyIrSchema, - sig: DispatcherSignature | NativeSignature, + sig: Union[DispatcherSignature, NativeSignature], overload_name: str, ) -> str: """ @@ -149,9 +147,9 @@ def aten_symbol(schema: LazyIrSchema) -> str: # converts all tensor-like arguments to meta tensors. Returns: # (1) a string containing all of the logic that does the conversions. # (2) a context, to be used by translate(), with all of the relevant bindings. 
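`replace(match: re.Match[str])` becomes `Match[str]` because `re.Match` only supports subscripting at runtime on Python 3.9+ (PEP 585); once the future import is gone the annotation is actually evaluated, so the file falls back to `typing.Match`. A toy `$identifier` substitution in the same spirit (the real `CodeTemplate` regex and indentation handling are richer):

```python
import re
from typing import Mapping, Match

IDENT = re.compile(r"\$\{?(\w+)\}?")  # matches $name and ${name}


def substitute(template: str, env: Mapping[str, object]) -> str:
    def replace(match: Match[str]) -> str:
        return str(env[match.group(1)])

    return IDENT.sub(replace, template)


assert substitute("Tensor ${name}(${args});", {"name": "add", "args": "Tensor a"}) == "Tensor add(Tensor a);"
```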
-def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]: - context: list[Binding] = [] - unwrapped_tensor_args: list[str] = [] +def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: + context: List[Binding] = [] + unwrapped_tensor_args: List[str] = [] for arg in sig.arguments(): if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): unwrapped_name = f"{arg.name}_meta" @@ -173,7 +171,7 @@ class GenLazyIR(ABC): use_lazy_shape: bool @method_with_native_function - def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func metadata = self.backend_index.get_kernel( f.functional if isinstance(f, NativeFunctionsGroup) else f @@ -238,7 +236,7 @@ class GenLazyIR(ABC): /* num_outputs */ {len(schema.returns)}, torch::lazy::MHash({scalar_hashes}))""" - def gen(self, schema: LazyIrSchema) -> list[str]: + def gen(self, schema: LazyIrSchema) -> List[str]: opkind = schema.opkind or aten_symbol(schema) # for now, we just want one IR class decl and soon after also the method defs @@ -415,7 +413,7 @@ class GenLazyNativeFuncDefinition: def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str: value_args = schema.filtered_args(values=True, scalars=False) # Generates lazy_{name} variables for LazyTensors wrapping input tensors - lazy_tensor_decls: list[str] = [] + lazy_tensor_decls: List[str] = [] for arg in value_args: if arg.is_wrapped_scalar: if isinstance(arg.lazy_type, OptionalCType): @@ -462,7 +460,7 @@ class GenLazyNativeFuncDefinition: func: NativeFunction, schema: LazyIrSchema, metadata: BackendMetadata, - sig: DispatcherSignature | NativeSignature, + sig: Union[DispatcherSignature, NativeSignature], ) -> str: if self.gen_forced_fallback_code: return gen_fallback_code( @@ -576,7 +574,7 @@ std::vector shapes{torch::lazy::Shape(out_meta.scalar_type() }} """ - def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str: + def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str: # xla uses an instance method for tensor creation, for the time being if self.create_from_first_tensor: # TODO(whc) remove this if XLA switches to using static method for creation @@ -617,7 +615,7 @@ std::vector shapes{torch::lazy::Shape(out_meta.scalar_type() return bridge_str @method_with_native_function - def __call__(self, func: NativeFunction) -> list[str]: + def __call__(self, func: NativeFunction) -> List[str]: sig = kernel_signature(func, self.backend_index) metadata = self.backend_index.get_kernel(func) assert metadata is not None @@ -641,7 +639,7 @@ class ComputeShapeSignature: Here we use the base name as the suffix of the signature to avoid generating for in-place variants. 
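`GenLazyIR.__call__` shows the recurring `Union` pattern of this patch: accept either a group or a single function and normalize with `isinstance`, since the `A | B` spelling needs Python 3.10 at runtime (PEP 604) once annotations stop being strings. Sketch with stand-in types:

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Fn:
    name: str


@dataclass(frozen=True)
class FnGroup:
    functional: Fn


def op_name(f: Union[FnGroup, Fn]) -> str:
    # normalize to the single-function case before doing any work
    func = f.functional if isinstance(f, FnGroup) else f
    return func.name


assert op_name(Fn("add")) == op_name(FnGroup(Fn("add"))) == "add"
```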
""" - def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None: + def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool): self.__schema = LazyIrSchema(f.func, symint=symint) self.__dispatch_args = ", ".join( [a.decl() for a in dispatcher.arguments(f.func, symint=symint)] @@ -672,7 +670,7 @@ class GenLazyShapeInferenceDefinition: tensor_class: str @method_with_native_function - def __call__(self, f: NativeFunction) -> list[str]: + def __call__(self, f: NativeFunction) -> List[str]: metadata = self.backend_index.get_kernel(f) assert metadata is not None @@ -689,8 +687,8 @@ class GenLazyShapeInferenceDefinition: def generate_non_native_lazy_ir_nodes( - non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR -) -> list[str]: + non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR +) -> List[str]: """Generate the non-native lazy IR node classes""" nodes = [] for op in non_native: diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py index a93405555bc..531c01b699f 100644 --- a/torchgen/dest/native_functions.py +++ b/torchgen/dest/native_functions.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import List, Optional, Union import torchgen.api.meta as meta import torchgen.api.structured as structured @@ -9,7 +9,7 @@ from torchgen.utils import mapMaybe @with_native_function_and_index -def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None: +def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]: sig = kernel_signature(f, backend_index) metadata = backend_index.get_kernel(f) if metadata is None: @@ -22,7 +22,7 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | No @with_native_function_and_index -def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]: +def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]: meta_name = meta.name(g) out_args = structured.impl_arguments(g) metadata = backend_index.get_kernel(g) @@ -42,8 +42,8 @@ void impl({', '.join(a.decl() for a in out_args)}); # actual kernel definitions we keep in aten/src/ATen/native/ @with_native_function_and_index def compute_native_function_declaration( - g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex -) -> list[str]: + g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex +) -> List[str]: metadata = backend_index.get_kernel(g) if isinstance(g, NativeFunctionsGroup): if metadata is not None and metadata.structured: diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 75aa165590d..4e24345883b 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import itertools import textwrap from dataclasses import dataclass -from typing import Literal, TYPE_CHECKING +from typing import List, Literal, Optional, Tuple, Union import torchgen.api.cpp as cpp import torchgen.api.meta as meta @@ -36,18 +34,15 @@ from torchgen.model import ( SchemaKind, TensorOptionsArguments, ) +from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import assert_never, mapMaybe, Target -if TYPE_CHECKING: - from torchgen.selective_build.selector import SelectiveBuilder - - def gen_registration_headers( backend_index: BackendIndex, per_operator_headers: bool, rocm: bool, -) -> list[str]: +) -> List[str]: if per_operator_headers: 
headers = ["#include "] else: @@ -78,7 +73,7 @@ def gen_registration_headers( def gen_empty_impl_names( backend_index: BackendIndex, -) -> tuple[str | None, str | None]: +) -> Tuple[Optional[str], Optional[str]]: empty_impl = None empty_strided_impl = None @@ -102,7 +97,7 @@ def gen_empty_impl_names( return empty_impl, empty_strided_impl -def gen_create_out_helper(backend_index: BackendIndex) -> list[str]: +def gen_create_out_helper(backend_index: BackendIndex) -> List[str]: if backend_index.dispatch_key == DispatchKey.Meta: empty_options = "options.device(at::kMeta)" else: @@ -125,7 +120,7 @@ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &o ] -def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]: +def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]: _, empty_strided_impl = gen_empty_impl_names(backend_index) return ( [] @@ -143,7 +138,7 @@ std::optional maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I ) -def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]: +def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]: if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: # The function isn't used by this key (since only functional ops have a kernel for this key), # so we need to not include it to avoid a defined-but-not-used error. @@ -173,7 +168,7 @@ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const ] -def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]: +def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]: return [ """ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { @@ -196,7 +191,7 @@ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &o ] -def gen_registration_helpers(backend_index: BackendIndex) -> list[str]: +def gen_registration_helpers(backend_index: BackendIndex) -> List[str]: return [ 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")', *gen_create_out_helper(backend_index), @@ -254,7 +249,7 @@ class RegisterDispatchKey: # Finally, this field is currently Optional because it is only used by external backends. # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating # all of the existing kernel signatures scattered across aten/src/ATen/native. - class_method_name: str | None + class_method_name: Optional[str] # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher. @@ -262,7 +257,7 @@ class RegisterDispatchKey: @staticmethod def gen_device_check( - type: DeviceCheckType, args: list[Argument], method_name: str + type: DeviceCheckType, args: List[Argument], method_name: str ) -> str: if type == DeviceCheckType.NoCheck: return " // No device check\n" @@ -277,7 +272,7 @@ class RegisterDispatchKey: return device_check @method_with_native_function - def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: if isinstance(f, NativeFunctionsGroup): g: NativeFunctionsGroup = f # Note: We call gen_structured() if the operator is marked structured, regardless of the backend. 
@@ -296,7 +291,7 @@ class RegisterDispatchKey: def wrapper_kernel_sig( self, f: NativeFunction - ) -> NativeSignature | DispatcherSignature: + ) -> Union[NativeSignature, DispatcherSignature]: # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. return DispatcherSignature.from_schema( f.func, @@ -305,8 +300,8 @@ class RegisterDispatchKey: ) def gen_out_inplace_wrapper( - self, f: NativeFunction, g: NativeFunctionsGroup | None - ) -> str | None: + self, f: NativeFunction, g: Optional[NativeFunctionsGroup] + ) -> Optional[str]: if g is None: return None k = f.func.kind() @@ -355,7 +350,7 @@ class RegisterDispatchKey: }} """ - def gen_structured(self, g: NativeFunctionsGroup) -> list[str]: + def gen_structured(self, g: NativeFunctionsGroup) -> List[str]: metadata = self.backend_index.get_kernel(g) if self.backend_index.dispatch_key == DispatchKey.Meta: assert not self.backend_index.has_kernel(g.out), ( @@ -385,8 +380,8 @@ class RegisterDispatchKey: return list(mapMaybe(structured_gen.gen_one, g.functions())) def gen_unstructured( - self, f: NativeFunction, g: NativeFunctionsGroup | None = None - ) -> str | None: + self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None + ) -> Optional[str]: with native_function_manager(f): inplace_meta = False gets_out_inplace_wrapper = False @@ -737,7 +732,7 @@ resize_out(out, sizes, strides, options); return "\n".join(line for line in lines if line) @method_with_native_function - def gen_one(self, f: NativeFunction) -> str | None: + def gen_one(self, f: NativeFunction) -> Optional[str]: assert not f.manual_kernel_registration if ( @@ -811,7 +806,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si sig_body = [] # We'll use context to keep track of any variables we've brought # into scope while generating code - context: list[Binding | Expr] = list(sig.arguments()) + context: List[Union[Binding, Expr]] = list(sig.arguments()) # Initialize the class corresponding to this structured # operator; feeding it the output argument(s) if it is known diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 073df2eb184..8c90160fa69 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import Sequence, TYPE_CHECKING +from typing import Dict, List, Optional, Sequence, Tuple, Union import torchgen.api.ufunc as ufunc from torchgen.api.translate import translate @@ -16,6 +14,7 @@ from torchgen.api.types import ( StructuredImplSignature, VectorizedCType, ) +from torchgen.api.ufunc import UfunctorBindings from torchgen.context import with_native_function from torchgen.model import ( Argument, @@ -29,10 +28,6 @@ from torchgen.model import ( from torchgen.utils import OrderedSet -if TYPE_CHECKING: - from torchgen.api.ufunc import UfunctorBindings - - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # CUDA STUFF @@ -65,7 +60,7 @@ if TYPE_CHECKING: @dataclass(frozen=True) class UfunctorSignature: g: NativeFunctionsGroup - scalar_tensor_idx: int | None + scalar_tensor_idx: Optional[int] name: str def arguments(self) -> UfunctorBindings: @@ -73,7 +68,7 @@ class UfunctorSignature: self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t ) - def fields(self) -> list[Binding]: + def fields(self) -> List[Binding]: # fields are renamed to have a trailing underscore, as is conventional return [b.rename(f"{b.name}_") for b in 
self.arguments().ctor] @@ -103,10 +98,10 @@ class UfuncSignature: name: str compute_t: CType - def arguments(self) -> list[Binding]: + def arguments(self) -> List[Binding]: return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) - def call(self, ctx: Sequence[Binding | Expr]) -> str: + def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str: return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" @@ -137,10 +132,10 @@ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: def compute_ufunc_cuda_functors( g: NativeFunctionsGroup, -) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]: +) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]: # First, build the functors. - ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {} - ufunctors: list[str] = [] + ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: List[str] = [] loops = g.out.ufunc_inner_loop scalar_tensor_idx_lookup = { UfuncKey.CUDAFunctorOnSelf: 1, @@ -242,7 +237,7 @@ BinaryScalarSpecializationConfigs = [ def compute_ufunc_cuda_dtype_body( g: NativeFunctionsGroup, dtype: ScalarType, - inner_loops: dict[UfuncKey, UfunctorSignature], + inner_loops: Dict[UfuncKey, UfunctorSignature], parent_ctx: Sequence[Binding], ) -> str: body = "using opmath_t = at::opmath_type;" @@ -254,7 +249,7 @@ def compute_ufunc_cuda_dtype_body( scalar_idx = config.scalar_idx + 1 # Make a copy and at the same time widen the type (not permissible # without copy; we don't want to mutate the input argument anyway) - ctx: list[Expr | Binding] = list(parent_ctx) + ctx: List[Union[Expr, Binding]] = list(parent_ctx) ctx.append( Expr( expr=f"iter.scalar_value({scalar_idx})", @@ -351,7 +346,7 @@ class StubSignature: def type_name(self) -> str: return f"{str(self.g.functional.func.name.name)}_fn" - def arguments(self) -> list[Binding]: + def arguments(self) -> List[Binding]: return ufunc.stub_arguments(self.g) def type(self) -> str: @@ -398,7 +393,7 @@ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: def compute_ufunc_cpu_dtype_body( g: NativeFunctionsGroup, dtype: ScalarType, - inner_loops: dict[UfuncKey, UfuncSignature], + inner_loops: Dict[UfuncKey, UfuncSignature], parent_ctx: Sequence[Binding], ) -> str: assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" @@ -464,8 +459,8 @@ def compute_ufunc_cpu_dtype_body( ) ) - def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]: - r: list[Expr | Binding] = [] + def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: + r: List[Union[Expr, Binding]] = [] r.extend(ctx) r.extend(b) return r @@ -494,7 +489,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: # Reindex the ufunc by dtypes; processing generic/scalaronly as well loops = g.out.ufunc_inner_loop - ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {} + ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {} for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: lks = [] # ORDER MATTERS: this specifies overriding precedence diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py index c74af600d4d..e4eec9e3fb2 100644 --- a/torchgen/executorch/api/custom_ops.py +++ b/torchgen/executorch/api/custom_ops.py @@ -1,29 +1,24 @@ -from __future__ import annotations - from collections import defaultdict from dataclasses import dataclass -from typing import Sequence, TYPE_CHECKING +from typing import Dict, List, Optional, Sequence, 
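The `ctx: List[Union[Expr, Binding]] = list(parent_ctx)` line above is the copy-and-widen idiom its comment describes: `List` is invariant, so a `List[Binding]` cannot be passed off as a `List[Union[Expr, Binding]]`, and the input must not be mutated anyway. Reduced to its essentials:

```python
from typing import List, Sequence, Union


class Binding:
    pass


class Expr:
    pass


def widen_ctx(parent_ctx: Sequence[Binding], extra: Expr) -> List[Union[Expr, Binding]]:
    # fresh copy with the wider element type; leaves the caller's list untouched
    ctx: List[Union[Expr, Binding]] = list(parent_ctx)
    ctx.append(extra)
    return ctx


ctx = widen_ctx([Binding(), Binding()], Expr())
assert len(ctx) == 3 and isinstance(ctx[-1], Expr)
```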
Tuple from torchgen import dest # disable import sorting to avoid circular dependency. from torchgen.api.types import DispatcherSignature # usort:skip from torchgen.context import method_with_native_function +from torchgen.executorch.model import ETKernelIndex from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant +from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import concatMap, Target -if TYPE_CHECKING: - from torchgen.executorch.model import ETKernelIndex - from torchgen.selective_build.selector import SelectiveBuilder - - # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at # model authoring side. @dataclass(frozen=True) class ComputeNativeFunctionStub: @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: if Variant.function not in f.variants: return None @@ -85,7 +80,7 @@ def gen_custom_ops_registration( selector: SelectiveBuilder, kernel_index: ETKernelIndex, rocm: bool, -) -> tuple[str, str]: +) -> Tuple[str, str]: """ Generate custom ops registration code for dest.RegisterDispatchKey. @@ -102,7 +97,7 @@ def gen_custom_ops_registration( dispatch_key = DispatchKey.CPU backend_index = kernel_index._to_backend_index() static_init_dispatch_registrations = "" - ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) + ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) for native_function in native_functions: ns_grouped_native_functions[native_function.namespace].append(native_function) diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index 0bdf28acfa6..18574f472fc 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import Sequence +from typing import List, Optional, Sequence, Set, Union from torchgen import local from torchgen.api.types import ( @@ -65,7 +63,7 @@ def valuetype_type( *, binds: ArgName, remove_non_owning_ref_types: bool = False, -) -> NamedCType | None: +) -> Optional[NamedCType]: if isinstance(t, BaseType): if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: return None @@ -211,7 +209,7 @@ def returns_type(rs: Sequence[Return]) -> CType: def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: - returns: list[str] = [] + returns: List[str] = [] for i, r in enumerate(f.func.returns): # If we have an inplace function, the return argument is # implicitly named self. 
@@ -297,16 +295,16 @@ def default_expr(d: str, t: Type) -> str: def argument( - a: Argument | TensorOptionsArguments | SelfArgument, + a: Union[Argument, TensorOptionsArguments, SelfArgument], *, - cpp_no_default_args: set[str], + cpp_no_default_args: Set[str], method: bool, faithful: bool, has_tensor_options: bool, -) -> list[Binding]: +) -> List[Binding]: def sub_argument( - a: Argument | TensorOptionsArguments | SelfArgument, - ) -> list[Binding]: + a: Union[Argument, TensorOptionsArguments, SelfArgument] + ) -> List[Binding]: return argument( a, cpp_no_default_args=cpp_no_default_args, @@ -321,7 +319,7 @@ def argument( binds = SpecialArgName.possibly_redundant_memory_format else: binds = a.name - default: str | None = None + default: Optional[str] = None if a.name not in cpp_no_default_args and a.default is not None: default = default_expr(a.default, a.type) return [ @@ -349,9 +347,9 @@ def arguments( *, faithful: bool, method: bool, - cpp_no_default_args: set[str], -) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + cpp_no_default_args: Set[str], +) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] if faithful: args.extend(arguments.non_out) args.extend(arguments.out) diff --git a/torchgen/executorch/api/types/signatures.py b/torchgen/executorch/api/types/signatures.py index ac3477cede6..3449b2b9a52 100644 --- a/torchgen/executorch/api/types/signatures.py +++ b/torchgen/executorch/api/types/signatures.py @@ -1,15 +1,10 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import List, Optional, Set import torchgen.api.cpp as aten_cpp +from torchgen.api.types import Binding, CType from torchgen.executorch.api.types.types import contextArg - - -if TYPE_CHECKING: - from torchgen.api.types import Binding, CType - from torchgen.model import FunctionSchema, NativeFunction +from torchgen.model import FunctionSchema, NativeFunction @dataclass(frozen=True) @@ -25,14 +20,14 @@ class ExecutorchCppSignature: func: FunctionSchema # The set of C++ arguments which should not have defaults applied to them - cpp_no_default_args: set[str] + cpp_no_default_args: Set[str] # Allows you to prepend an arbitrary prefix to the signature name. # This is useful for parts of the codegen that generate wrappers around kernels, # and need to avoid naming collisions. 
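The `default: Optional[str] = None` block in `argument()` encodes one small rule: a binding only carries a C++ default when the argument has one in the schema and its name is not suppressed via `cpp_no_default_args`. Isolated into a hypothetical helper:

```python
from typing import Optional, Set


def resolved_default(
    name: str, schema_default: Optional[str], cpp_no_default_args: Set[str]
) -> Optional[str]:
    default: Optional[str] = None
    if name not in cpp_no_default_args and schema_default is not None:
        default = schema_default
    return default


assert resolved_default("alpha", "1", set()) == "1"
assert resolved_default("alpha", "1", {"alpha"}) is None   # suppressed
assert resolved_default("alpha", None, set()) is None      # no schema default
```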
prefix: str = "" - def arguments(self, *, include_context: bool = True) -> list[Binding]: + def arguments(self, *, include_context: bool = True) -> List[Binding]: return ([contextArg] if include_context else []) + et_cpp.arguments( self.func.arguments, faithful=True, # always faithful, out argument at the end @@ -46,7 +41,7 @@ class ExecutorchCppSignature: faithful_name_for_out_overloads=True, ) - def decl(self, name: str | None = None, *, include_context: bool = True) -> str: + def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str: args_str = ", ".join( a.decl() for a in self.arguments(include_context=include_context) ) @@ -54,7 +49,7 @@ class ExecutorchCppSignature: name = self.name() return f"{self.returns_type().cpp_type()} {name}({args_str})" - def defn(self, name: str | None = None) -> str: + def defn(self, name: Optional[str] = None) -> str: args = [a.defn() for a in self.arguments()] args_str = ", ".join(args) if name is None: @@ -67,7 +62,7 @@ class ExecutorchCppSignature: @staticmethod def from_native_function( f: NativeFunction, *, prefix: str = "" - ) -> ExecutorchCppSignature: + ) -> "ExecutorchCppSignature": return ExecutorchCppSignature( func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args ) diff --git a/torchgen/executorch/api/types/types.py b/torchgen/executorch/api/types/types.py index b3a960a8246..6ec48c803ae 100644 --- a/torchgen/executorch/api/types/types.py +++ b/torchgen/executorch/api/types/types.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass +from typing import Dict from torchgen.api.types import ( BaseCppType, @@ -41,7 +40,7 @@ contextArg = Binding( default=None, ) -BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { +BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTy.int: longT, BaseTy.float: doubleT, BaseTy.bool: boolT, @@ -55,7 +54,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { @dataclass(frozen=True) class OptionalCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. @@ -64,13 +63,13 @@ class OptionalCType(CType): def cpp_type_registration_declarations(self) -> str: return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return OptionalCType(self.elem.remove_const_ref()) @dataclass(frozen=True) class ArrayRefCType(CType): - elem: CType + elem: "CType" def cpp_type(self, *, strip_ref: bool = False) -> str: # Do not pass `strip_ref` recursively. 
@@ -79,5 +78,5 @@ class ArrayRefCType(CType): def cpp_type_registration_declarations(self) -> str: return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>" - def remove_const_ref(self) -> CType: + def remove_const_ref(self) -> "CType": return ArrayRefCType(self.elem.remove_const_ref()) diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py index 86e19cd6320..a81e6d11fea 100644 --- a/torchgen/executorch/api/unboxing.py +++ b/torchgen/executorch/api/unboxing.py @@ -1,8 +1,7 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import Callable, Sequence, TYPE_CHECKING +from typing import Callable, List, Sequence, Tuple +from torchgen.api.types import Binding, CType, NamedCType from torchgen.model import ( Argument, BaseTy, @@ -14,10 +13,6 @@ from torchgen.model import ( ) -if TYPE_CHECKING: - from torchgen.api.types import Binding, CType, NamedCType - - connector = "\n\t" @@ -57,7 +52,7 @@ class Unboxing: # Convert all the arguments in a NativeFunction to C++ code def convert_arguments( self, args: Sequence[Binding] - ) -> tuple[list[Binding], list[str]]: + ) -> Tuple[List[Binding], List[str]]: code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))] binding_list = [] for arg in args: @@ -77,7 +72,7 @@ class Unboxing: def argumenttype_evalue_convert( self, t: Type, arg_name: str, *, mutable: bool = False - ) -> tuple[str, CType, list[str], list[str]]: + ) -> Tuple[str, CType, List[str], List[str]]: """ Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: (1) the C++ code necessary to unbox the argument @@ -112,14 +107,14 @@ class Unboxing: def _gen_code_base_type( self, arg_name: str, out_name: str, ctype: CType - ) -> tuple[list[str], list[str]]: + ) -> Tuple[List[str], List[str]]: return [ f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" ], [] def _gen_code_optional_type( self, arg_name: str, out_name: str, t: OptionalType, ctype: CType - ) -> tuple[list[str], list[str]]: + ) -> Tuple[List[str], List[str]]: in_name = f"{arg_name}_opt_in" res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( t.elem, in_name @@ -135,7 +130,7 @@ class Unboxing: def _gen_code_list_type( self, arg_name: str, out_name: str, t: ListType, ctype: CType - ) -> tuple[list[str], list[str]]: + ) -> Tuple[List[str], List[str]]: in_name = f"{arg_name}_list_in" elem_name = f"{arg_name}_elem" code = [] diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py index 6aadfe41dae..a7d5f1ceb16 100644 --- a/torchgen/executorch/model.py +++ b/torchgen/executorch/model.py @@ -1,12 +1,11 @@ # Represents all kernels used by an Executorch model. # It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure. -from __future__ import annotations - import itertools from collections import defaultdict, namedtuple from dataclasses import dataclass from enum import IntEnum +from typing import Dict, List, Tuple, Union from torchgen.model import ( BackendIndex, @@ -42,7 +41,7 @@ class ETKernelKeyOpArgMeta: arg_name: str dtype: str # The order of the dimensions if entry is a Tensor - dim_order: tuple[int, ...] + dim_order: Tuple[int, ...] def to_native_string(self) -> str: dtype_str = ScalarType[self.dtype].value @@ -53,7 +52,7 @@ class ETKernelKeyOpArgMeta: @dataclass(frozen=True) class ETKernelKey: # Field undefined is default = True - arg_meta: tuple[ETKernelKeyOpArgMeta, ...] 
= () + arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = () # Indicator for this kernel being used as a catch all default: bool = False @@ -62,10 +61,10 @@ class ETKernelKey: @staticmethod def gen_from_yaml( - args: dict[str, tuple[str, str]], - type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val - dim_order_alias_map: dict[str, list[int]], - ) -> list[ETKernelKey]: + args: Dict[str, Tuple[str, str]], + type_alias_map: Dict[str, List[str]], # TODO: Support unwrapped str val + dim_order_alias_map: Dict[str, List[int]], + ) -> List["ETKernelKey"]: """Generate ETKernelKeys from arg kernel specs Multiple ETKernelKeys are returned due to dtype permutations from utilizing type_alias_map (actualizing each potential type permutation as a KernelKey) @@ -138,15 +137,15 @@ class ETKernelKey: @dataclass(frozen=True) class ETKernelIndex: - index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] + index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] - def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool: + def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool: m = self.get_kernels(g) return m is not None def get_kernels( - self, g: NativeFunction | NativeFunctionsGroup - ) -> dict[ETKernelKey, BackendMetadata]: + self, g: Union[NativeFunction, NativeFunctionsGroup] + ) -> Dict[ETKernelKey, BackendMetadata]: if isinstance(g, NativeFunction): f = g elif isinstance(g, NativeFunctionsGroup): @@ -159,8 +158,8 @@ class ETKernelIndex: @staticmethod def grow_from_backend_indices( - kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]], - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], + kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]], + backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ) -> None: for dk in backend_indices: index = backend_indices[dk] @@ -172,17 +171,17 @@ class ETKernelIndex: @staticmethod def from_backend_indices( - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] - ) -> ETKernelIndex: - kernel_index: dict[ - OperatorName, dict[ETKernelKey, BackendMetadata] + backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] + ) -> "ETKernelIndex": + kernel_index: Dict[ + OperatorName, Dict[ETKernelKey, BackendMetadata] ] = defaultdict(dict) ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices) return ETKernelIndex(kernel_index) def grow( - self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] - ) -> ETKernelIndex: + self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] + ) -> "ETKernelIndex": ETKernelIndex.grow_from_backend_indices(self.index, backend_indices) return self @@ -190,7 +189,7 @@ class ETKernelIndex: """ WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex. 
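`arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()` can use a plain literal default because tuples are immutable; a mutable default on a dataclass field must go through `field(default_factory=...)` instead (dataclasses rejects raw `list`/`dict` defaults outright). Quick demonstration with toy fields:

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class KernelKey:
    arg_meta: Tuple[str, ...] = ()                  # immutable: a shared default is safe
    notes: List[str] = field(default_factory=list)  # mutable: needs a per-instance factory


a, b = KernelKey(), KernelKey()
assert a.arg_meta is b.arg_meta   # the one shared tuple, harmless
assert a.notes is not b.notes     # distinct lists per instance
```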
""" - index: dict[OperatorName, BackendMetadata] = {} + index: Dict[OperatorName, BackendMetadata] = {} for op in self.index: kernel_dict = self.index[op] assert ( @@ -210,7 +209,9 @@ class ETKernelIndex: # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a @staticmethod - def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex: + def merge_indices( + index_a: "ETKernelIndex", index_b: "ETKernelIndex" + ) -> "ETKernelIndex": combined = defaultdict(dict, index_a.index.copy()) for op, entry in index_b.index.items(): diff --git a/torchgen/executorch/parse.py b/torchgen/executorch/parse.py index 8095abd5b6b..94acb5c2115 100644 --- a/torchgen/executorch/parse.py +++ b/torchgen/executorch/parse.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from collections import defaultdict, namedtuple -from typing import Any +from typing import Any, Dict, List, Optional, Set, Tuple import yaml @@ -24,7 +22,7 @@ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indice ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"] -def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]: +def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]: """Given a loaded yaml representing kernel assignment information, extract the mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance) @@ -36,11 +34,11 @@ def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata] if (kernels := e.pop("kernels", None)) is None: return {} - type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment] - dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] + type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment] + dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] dim_order_alias.pop("__line__", None) - kernel_mapping: dict[ETKernelKey, BackendMetadata] = {} + kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {} for entry in kernels: # type: ignore[attr-defined] arg_meta = entry.get("arg_meta") @@ -78,7 +76,7 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex: of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance that should be used by the kernel key). """ - indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {} + indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {} for ei in es: # type: ignore[attr-defined] e = ei.copy() @@ -97,11 +95,11 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex: return ETKernelIndex(indices) -def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]: +def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]: """Given a loaded yaml representing a list of operators, extract the kernel key related fields indexed by the operator name. 
""" - fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict) + fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict) for ei in es: # type: ignore[attr-defined] funcs = ei.get("func") assert isinstance(funcs, str), f"not a str: {funcs}" @@ -120,9 +118,9 @@ def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]: def parse_et_yaml( path: str, tags_yaml_path: str, - ignore_keys: set[DispatchKey] | None = None, + ignore_keys: Optional[Set[DispatchKey]] = None, skip_native_fns_gen: bool = False, -) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]: +) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]: """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict of fields to persist from native_functions.yaml to functions.yaml """ diff --git a/torchgen/gen.py b/torchgen/gen.py index ef4867a9504..e9dc04d0b9b 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -1,13 +1,23 @@ -from __future__ import annotations - import argparse import functools import json import os +import pathlib from collections import defaultdict, namedtuple, OrderedDict from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Literal, Sequence, TypeVar +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) import yaml @@ -138,20 +148,20 @@ class LineLoader(YamlLoader): ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"]) -_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {} -_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {} +_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {} +_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {} def parse_native_yaml_struct( es: object, - valid_tags: set[str], - ignore_keys: set[DispatchKey] | None = None, + valid_tags: Set[str], + ignore_keys: Optional[Set[DispatchKey]] = None, path: str = "", skip_native_fns_gen: bool = False, ) -> ParsedYaml: assert isinstance(es, list) - rs: list[NativeFunction] = [] - bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict) + rs: List[NativeFunction] = [] + bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) for e in es: assert isinstance(e, dict), f"expected to be dict: {e}" assert isinstance(e.get("__line__"), int), e @@ -164,7 +174,7 @@ def parse_native_yaml_struct( BackendIndex.grow_index(bs, m) error_check_native_functions(rs) # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. 
- indices: dict[DispatchKey, BackendIndex] = defaultdict( + indices: Dict[DispatchKey, BackendIndex] = defaultdict( lambda: BackendIndex( dispatch_key=DispatchKey.Undefined, use_out_as_primary=True, @@ -190,9 +200,9 @@ def parse_native_yaml_struct( return ParsedYaml(rs, indices) -def parse_tags_yaml_struct(es: object, path: str = "") -> set[str]: +def parse_tags_yaml_struct(es: object, path: str = "") -> Set[str]: assert isinstance(es, list) - rs: set[str] = set() + rs: Set[str] = set() for e in es: assert isinstance(e.get("__line__"), int), e loc = Location(path, e["__line__"]) @@ -208,7 +218,7 @@ def parse_tags_yaml_struct(es: object, path: str = "") -> set[str]: @functools.lru_cache(maxsize=None) -def parse_tags_yaml(path: str) -> set[str]: +def parse_tags_yaml(path: str) -> Set[str]: global _GLOBAL_PARSE_TAGS_YAML_CACHE if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: with open(path) as f: @@ -221,10 +231,10 @@ def parse_tags_yaml(path: str) -> set[str]: def parse_native_yaml( path: str, tags_yaml_path: str, - ignore_keys: set[DispatchKey] | None = None, + ignore_keys: Optional[Set[DispatchKey]] = None, *, skip_native_fns_gen: bool = False, - loaded_yaml: object | None = None, + loaded_yaml: Optional[object] = None, ) -> ParsedYaml: global _GLOBAL_PARSE_NATIVE_YAML_CACHE if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: @@ -251,8 +261,8 @@ def parse_native_yaml( # Some assertions are already performed during parsing, but those are only within a single NativeFunction. # Assertions here are meant to be performed across NativeFunctions. def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: - func_map: dict[OperatorName, NativeFunction] = {} - base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) + func_map: Dict[OperatorName, NativeFunction] = {} + base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) for f in funcs: func_map[f.func.name] = f base_func_map[f.func.name.name].append(f) @@ -319,7 +329,7 @@ def cpp_string(s: str) -> str: # and similar functional combinators. -def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]: +def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]: if len(backends) == 0: return [] else: @@ -333,7 +343,7 @@ def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]: def get_static_dispatch_backend( f: NativeFunction, backend_index: BackendIndex -) -> DispatchKey | None: +) -> Optional[DispatchKey]: if f.structured_delegate is not None or backend_index.has_kernel(f): # TODO: for ops with structured_delegate it should check the dispatch table of # the out variant instead. 
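`parse_tags_yaml` is memoized twice: `functools.lru_cache` on the function itself plus the module-level `_GLOBAL_PARSE_TAGS_YAML_CACHE` dict, so each tags file is read and validated at most once per process. A stripped-down version of that layering, assuming a toy one-tag-per-line format rather than the real YAML schema:

```python
import functools
from typing import Dict, Set

_TAGS_CACHE: Dict[str, Set[str]] = {}


@functools.lru_cache(maxsize=None)
def parse_tags(path: str) -> Set[str]:
    # toy format: one tag per line (the real parser validates YAML entries)
    if path not in _TAGS_CACHE:
        with open(path) as f:
            _TAGS_CACHE[path] = {line.strip() for line in f if line.strip()}
    return _TAGS_CACHE[path]
```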
For now, these structured ops all have CPU/CUDA kernels @@ -352,8 +362,8 @@ def get_static_dispatch_backend( def static_dispatch_ops_header( - f: NativeFunction, backend_index: list[BackendIndex] -) -> str | None: + f: NativeFunction, backend_index: List[BackendIndex] +) -> Optional[str]: if backend_index is None or f.manual_kernel_registration: return None @@ -367,7 +377,7 @@ def static_dispatch_ops_header( return "\n".join(output) -def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]: +def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]: return [ f"#include " for dispatch_key in static_dispatch_keys(backends) @@ -378,12 +388,12 @@ def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]: # Note that we have a special case for `memory_format` argument and this case is not covered by # tools.codegen.api.translate() yet as its application is limited to static dispatch. def translate_args( - sig: CppSignature | DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], cpp_sig: CppSignature, ) -> str: # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings - def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]: - output_bindings: list[Binding] = [] + def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]: + output_bindings: List[Binding] = [] for binding in input_bindings: if binding.name == "memory_format": spl_mem_format_binding = Binding( @@ -413,7 +423,7 @@ def translate_args( def generate_static_dispatch_backend_call( - sig: CppSignature | DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, backend_index: BackendIndex, ) -> str: @@ -431,9 +441,9 @@ def generate_static_dispatch_backend_call( def generate_static_dispatch_fallback_call( - sig: CppSignature | DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, - backend_indices: list[BackendIndex], + backend_indices: List[BackendIndex], ) -> str: cpp_sigs = CppSignatureGroup.from_native_function( f, method=False, fallback_binding=False @@ -460,9 +470,9 @@ def generate_static_dispatch_fallback_call( def static_dispatch( - sig: CppSignature | DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, - backend_indices: list[BackendIndex], + backend_indices: List[BackendIndex], ) -> str: """ For a given `NativeFunction`, find out the corresponding backend and dispatch to it. 
If more than one @@ -502,7 +512,7 @@ def static_dispatch( tensor_opts = f.func.arguments.tensor_options stmts = [] - subexprs: list[str] = [] + subexprs: List[str] = [] if tensor_opts is not None: subexprs.append( "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))" @@ -538,10 +548,10 @@ def static_dispatch( @dataclass(frozen=True) class RegisterSchema: selector: SelectiveBuilder - known_tags: dict[str, int] = field(default_factory=dict) + known_tags: Dict[str, int] = field(default_factory=dict) @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: if not self.selector.is_native_function_selected(f): return None tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" @@ -563,7 +573,7 @@ class RegisterSchema: @dataclass(frozen=True) class ComputeOperators: target: Literal[Target.DECLARATION, Target.DEFINITION] - static_dispatch_backend_indices: list[BackendIndex] + static_dispatch_backend_indices: List[BackendIndex] @method_with_native_function def __call__(self, f: NativeFunction) -> str: @@ -660,7 +670,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed @dataclass(frozen=True) class ComputeFunction: @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: sig_group = CppSignatureGroup.from_native_function( f, method=False, fallback_binding=f.manual_cpp_binding ) @@ -708,10 +718,10 @@ namespace symint {{ @dataclass(frozen=True) class ComputeTensorMethod: target: Literal[Target.DECLARATION, Target.DEFINITION] - static_dispatch_backend_indices: list[BackendIndex] + static_dispatch_backend_indices: List[BackendIndex] @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: if Variant.method not in f.variants: return None @@ -754,7 +764,7 @@ inline {sig.defn(prefix="Tensor::")} const {{ @dataclass(frozen=True) class ComputeRedispatchFunction: @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: # We unconditionally generate function variants of the redispatch API. 
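Reading `static_dispatch` together with its fallback companion: when exactly one backend provides a kernel, the codegen can emit a direct namespaced call and skip the dispatcher; with zero or several candidates it must emit the runtime path that computes a dispatch key from the inputs. A toy decision function (the `at::cpu::`-style namespacing mirrors how the generated headers are organized, but treat the emitted strings as illustrative):

```python
from typing import List


def static_dispatch_call(op: str, backends_with_kernel: List[str]) -> str:
    if len(backends_with_kernel) == 1:
        # single candidate: call it directly, bypassing the dispatcher
        return f"return at::{backends_with_kernel[0].lower()}::{op}(/*args*/);"
    # zero or several candidates: defer to the runtime fallback path
    return f"// fallback: compute dispatch key from inputs for {op}"


assert static_dispatch_call("add", ["CPU"]) == "return at::cpu::add(/*args*/);"
```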
# This is mainly because we can namespace functions separately, but not methods, sig_group = CppSignatureGroup.from_native_function( @@ -788,7 +798,7 @@ def compute_aten_op(f: NativeFunction) -> str: # Generates MetaFunctions.h -def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None: +def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]: if not g.structured: return None with native_function_manager(g.out): @@ -933,7 +943,7 @@ class ComputeBackendSelect: selector: SelectiveBuilder @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: if not needs_backend_select(f, self.selector): return None @@ -949,7 +959,7 @@ class ComputeBackendSelect: dispatcher_sig = DispatcherSignature.from_schema(f.func) - sig: NativeSignature | DispatcherSignature + sig: Union[NativeSignature, DispatcherSignature] sig = dispatcher_sig dispatcher_exprs = dispatcher_sig.exprs() dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" @@ -1049,7 +1059,7 @@ def dynamic_type(t: Type) -> str: ).cpp_type() -def compute_method_of_yaml(variants: set[Variant]) -> list[str]: +def compute_method_of_yaml(variants: Set[Variant]) -> List[str]: # This is written out explicitly to ensure that Tensor and # namespace are put into the list in the right order method_of = ["Type"] @@ -1062,7 +1072,7 @@ def compute_method_of_yaml(variants: set[Variant]) -> list[str]: def compute_returns_yaml( f: NativeFunction, -) -> tuple[list[dict[str, str]], dict[str, str]]: +) -> Tuple[List[Dict[str, str]], Dict[str, str]]: # Note [name and field_name] # ~~~~~~~~~~~~~~~~~~~~~~~~~~ # To understand name_to_field_name, we must first talk about this @@ -1102,7 +1112,7 @@ def compute_returns_yaml( # schema itself. 
# # See also https://github.com/pytorch/pytorch/issues/43114 - name_to_field_name: dict[str, str] = {} + name_to_field_name: Dict[str, str] = {} # Compute the returns field of the YAML entry names = cpp.return_names(f) @@ -1131,12 +1141,12 @@ def compute_cpp_argument_yaml( cpp_a: Binding, *, schema_order: bool, - kwarg_only_set: set[str], - out_arg_set: set[str], - name_to_field_name: dict[str, str], + kwarg_only_set: Set[str], + out_arg_set: Set[str], + name_to_field_name: Dict[str, str], ) -> object: if isinstance(cpp_a.argument, TensorOptionsArguments): - arg: dict[str, object] = { + arg: Dict[str, object] = { "annotation": None, "dynamic_type": "at::TensorOptions", "is_nullable": False, @@ -1163,11 +1173,11 @@ def compute_argument_yaml( a: Argument, *, schema_order: bool, - kwarg_only_set: set[str], - out_arg_set: set[str], - name_to_field_name: dict[str, str], + kwarg_only_set: Set[str], + out_arg_set: Set[str], + name_to_field_name: Dict[str, str], ) -> object: - arg: dict[str, object] = { + arg: Dict[str, object] = { "annotation": str(a.annotation) if a.annotation else None, "dynamic_type": dynamic_type(a.type), "is_nullable": a.type.is_nullable(), @@ -1293,7 +1303,7 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: @with_native_function_and_indices def compute_registration_declarations( - f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex] + f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex] ) -> str: name = dispatcher.name(f.func) returns_type = dispatcher.returns_type( @@ -1301,7 +1311,7 @@ def compute_registration_declarations( ).cpp_type_registration_declarations() args = dispatcher.arguments(f.func) args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args) - comment_data: dict[str, str] = { + comment_data: Dict[str, str] = { "schema": f"aten::{f.func}", # TODO: What exactly is the semantics of the 'dispatch' field? "dispatch": str( @@ -1327,8 +1337,8 @@ def compute_registration_declarations( def get_custom_build_selector( - provided_op_registration_allowlist: list[str] | None, - op_selection_yaml_path: str | None, + provided_op_registration_allowlist: Optional[List[str]], + op_selection_yaml_path: Optional[str], ) -> SelectiveBuilder: assert not ( provided_op_registration_allowlist is not None @@ -1339,7 +1349,7 @@ def get_custom_build_selector( + "same time." 
) - op_registration_allowlist: set[str] | None = None + op_registration_allowlist: Optional[Set[str]] = None if provided_op_registration_allowlist is not None: op_registration_allowlist = set(provided_op_registration_allowlist) @@ -1359,11 +1369,11 @@ def get_custom_build_selector( def get_grouped_by_view_native_functions( native_functions: Sequence[NativeFunction], -) -> Sequence[NativeFunction | NativeFunctionsViewGroup]: +) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]: def maybe_create_view_group( - d: dict[ViewSchemaKind | SchemaKind, NativeFunction] - ) -> list[NativeFunction | NativeFunctionsViewGroup]: - funcs: list[NativeFunction | NativeFunctionsViewGroup] = [] + d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction] + ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]: + funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = [] if ViewSchemaKind.aliasing in d: view = d.pop(ViewSchemaKind.aliasing) view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) @@ -1381,8 +1391,8 @@ def get_grouped_by_view_native_functions( funcs.extend(d.values()) return funcs - grouped_by_views: dict[ - FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction] + grouped_by_views: Dict[ + FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction] ] = defaultdict(dict) for f in native_functions: schema = f.func.view_signature() @@ -1406,10 +1416,10 @@ def get_grouped_by_view_native_functions( def get_grouped_native_functions( native_functions: Sequence[NativeFunction], -) -> Sequence[NativeFunction | NativeFunctionsGroup]: +) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: def flatten_pre_group( - d: dict[SchemaKind, NativeFunction] - ) -> Sequence[NativeFunction | NativeFunctionsGroup]: + d: Dict[SchemaKind, NativeFunction] + ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: r = NativeFunctionsGroup.from_dict(d) if r is None: # Invariant: any NativeFunctions that are code-generated @@ -1428,13 +1438,13 @@ def get_grouped_native_functions( def get_ns_grouped_kernels( *, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], - backend_indices: dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], native_function_decl_gen: Callable[ - [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str] + [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] ] = dest.compute_native_function_declaration, -) -> dict[str, list[str]]: - ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) +) -> Dict[str, List[str]]: + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) for f in grouped_native_functions: native_function_namespaces = set() dispatch_keys = set() @@ -1457,9 +1467,9 @@ def get_ns_grouped_kernels( def get_native_function_declarations_from_ns_grouped_kernels( *, - ns_grouped_kernels: dict[str, list[str]], -) -> list[str]: - declarations: list[str] = [] + ns_grouped_kernels: Dict[str, List[str]], +) -> List[str]: + declarations: List[str] = [] newline = "\n" for namespace, kernels in ns_grouped_kernels.items(): ns_helper = NamespaceHelper( @@ -1485,12 +1495,12 @@ def get_native_function_declarations_from_ns_grouped_kernels( # Return native function declarations grouped by their namespaces. 
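`grouped_by_views` is the two-level grouping idiom used throughout this file: an outer `defaultdict(dict)` keyed by schema, an inner dict keyed by kind, filled in a single pass. With tuples standing in for `NativeFunction`:

```python
from collections import defaultdict
from typing import Dict, Tuple

fns = [
    ("add(Tensor, Tensor)", "functional"),
    ("add(Tensor, Tensor)", "inplace"),
    ("mul(Tensor, Tensor)", "functional"),
]

grouped: Dict[str, Dict[str, Tuple[str, str]]] = defaultdict(dict)
for schema, kind in fns:
    grouped[schema][kind] = (schema, kind)

assert sorted(grouped["add(Tensor, Tensor)"]) == ["functional", "inplace"]
```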
def get_native_function_declarations( *, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], - backend_indices: dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], native_function_decl_gen: Callable[ - [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str] + [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] ] = dest.compute_native_function_declaration, -) -> list[str]: +) -> List[str]: """ Generate kernel declarations, in `NativeFunction(s).h`. :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`. @@ -1510,7 +1520,7 @@ def get_native_function_declarations( def get_kernel_namespace( - *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex + *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex ) -> str: backend_metadata = backend_idx.get_kernel(f) assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( @@ -1528,7 +1538,7 @@ def get_kernel_namespace( def get_native_function_definitions( *, fm: FileManager, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], dispatch_key: DispatchKey, backend_idx: BackendIndex, selector: SelectiveBuilder, @@ -1536,11 +1546,11 @@ def get_native_function_definitions( symint: bool, skip_dispatcher_op_registration: bool, gen_dispatch_helpers: bool, -) -> list[str]: - definitions: list[str] = [] - ns_definitions: dict[str, list[str]] = defaultdict(list) - anonymous_definitions: dict[str, list[str]] = defaultdict(list) - registrations: dict[str, dict[str, list[str]]] = defaultdict(dict) +) -> List[str]: + definitions: List[str] = [] + ns_definitions: Dict[str, List[str]] = defaultdict(list) + anonymous_definitions: Dict[str, List[str]] = defaultdict(list) + registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict) newline = "\n" ns_gen = dest.RegisterDispatchKey( backend_idx, @@ -1630,15 +1640,15 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ # Used in CPUFunctions_inl.h and etc. 
def get_namespaced_declaration( *, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], dispatch_key: DispatchKey, backend_idx: BackendIndex, selector: SelectiveBuilder, rocm: bool, symint: bool, -) -> list[str]: - declarations: list[str] = [] - ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) +) -> List[str]: + declarations: List[str] = [] + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) newline = "\n" func = dest.RegisterDispatchKey( backend_idx, @@ -1682,8 +1692,8 @@ def get_native_function_schema_registrations( *, native_functions: Sequence[NativeFunction], schema_selector: SelectiveBuilder, -) -> tuple[list[str], str]: - ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) +) -> Tuple[List[str], str]: + ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) for native_function in native_functions: ns_native_functions[native_function.namespace].append(native_function) schema_registrations = "" @@ -1717,14 +1727,14 @@ def get_native_function_schema_registrations( def gen_aggregated_headers( *, native_functions: Sequence[NativeFunction], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: Sequence[NativeFunctionsGroup], - static_dispatch_idx: list[BackendIndex], + static_dispatch_idx: List[BackendIndex], selector: SelectiveBuilder, - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], cpu_fm: FileManager, cuda_fm: FileManager, - functions_keys: set[DispatchKey], + functions_keys: Set[DispatchKey], dispatch_keys: Sequence[DispatchKey], rocm: bool, ) -> None: @@ -1838,25 +1848,25 @@ def gen_aggregated_headers( def gen_per_operator_headers( *, native_functions: Sequence[NativeFunction], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], - static_dispatch_idx: list[BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + static_dispatch_idx: List[BackendIndex], selector: SelectiveBuilder, - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], cpu_fm: FileManager, cuda_fm: FileManager, ops_fm: FileManager, - functions_keys: set[DispatchKey], + functions_keys: Set[DispatchKey], dispatch_keys: Sequence[DispatchKey], rocm: bool, ) -> None: # For CMake builds, split operator declarations into separate headers in # the ATen/ops folder to split up header dependencies - functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list) + functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list) for fn in native_functions: functions_by_root_name[fn.root_name].append(fn) - grouped_functions_by_root_name: dict[ - str, list[NativeFunction | NativeFunctionsGroup] + grouped_functions_by_root_name: Dict[ + str, List[Union[NativeFunction, NativeFunctionsGroup]] ] = defaultdict(list) for group in grouped_native_functions: name = group.root_name @@ -2032,18 +2042,18 @@ def gen_per_operator_headers( def gen_headers( *, native_functions: Sequence[NativeFunction], - valid_tags: set[str], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + valid_tags: Set[str], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: 
Sequence[NativeFunctionsGroup], - static_dispatch_idx: list[BackendIndex], + static_dispatch_idx: List[BackendIndex], selector: SelectiveBuilder, - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], core_fm: FileManager, cpu_fm: FileManager, cuda_fm: FileManager, ops_fm: FileManager, dispatch_keys: Sequence[DispatchKey], - functions_keys: set[DispatchKey], + functions_keys: Set[DispatchKey], rocm: bool, per_operator_headers: bool, ) -> None: @@ -2123,8 +2133,8 @@ def gen_headers( "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions) ) - def gen_aten_interned_strings() -> dict[str, str]: - attrs: set[str] = set() # All function argument names + def gen_aten_interned_strings() -> Dict[str, str]: + attrs: Set[str] = set() # All function argument names names = set() # All ATen function names for func in native_functions: names.add(str(func.func.name.name)) @@ -2161,7 +2171,7 @@ def gen_headers( core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) - def gen_tags_enum() -> dict[str, str]: + def gen_tags_enum() -> Dict[str, str]: return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} core_fm.write("enum_tag.h", gen_tags_enum) @@ -2170,19 +2180,19 @@ def gen_headers( def gen_source_files( *, native_functions: Sequence[NativeFunction], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: Sequence[NativeFunctionsGroup], view_groups: Sequence[NativeFunctionsViewGroup], selector: SelectiveBuilder, - static_dispatch_idx: list[BackendIndex], - backend_indices: dict[DispatchKey, BackendIndex], + static_dispatch_idx: List[BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], aoti_fm: FileManager, core_fm: FileManager, cpu_fm: FileManager, cpu_vec_fm: FileManager, cuda_fm: FileManager, dispatch_keys: Sequence[DispatchKey], - functions_keys: set[DispatchKey], + functions_keys: Set[DispatchKey], rocm: bool, force_schema_registration: bool, per_operator_headers: bool, @@ -2206,7 +2216,7 @@ def gen_source_files( if per_operator_headers: - def operator_headers() -> list[str]: + def operator_headers() -> List[str]: headers = [] for g in grouped_native_functions: is_registered = False @@ -2248,7 +2258,7 @@ def gen_source_files( else: - def operator_headers() -> list[str]: + def operator_headers() -> List[str]: headers = ["#include "] if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: headers.append("#include ") @@ -2439,7 +2449,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f del fm # BackendSelect is generated specially - def gen_backend_select() -> dict[str, list[str]]: + def gen_backend_select() -> Dict[str, List[str]]: relevant_fns = [ fn for fn in native_functions if needs_backend_select(fn, selector) ] @@ -2484,7 +2494,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f ) def key_func( - fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] ) -> str: return fn.root_name @@ -2526,11 +2536,11 @@ codegen to generate the correct cpp call for this op. 
Contact AOTInductor team f ) def functionalization_env_callable( - g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, - ) -> dict[str, list[str]]: + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ) -> Dict[str, List[str]]: def gen_op_headers( - g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, - ) -> list[str]: + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ) -> List[str]: if isinstance(g, NativeFunctionsViewGroup): # view ops always get a functionalization kernel headers = [ @@ -2580,8 +2590,8 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f ), } - all_groups: list[ - NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup + all_groups: List[ + Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] ] = list(structured_native_functions) + list( view_groups # type: ignore[assignment, arg-type, operator] ) @@ -2590,11 +2600,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. # Although this could go away long-term if we add a dedicated dispatch key for decompositions. - structured_map: dict[OperatorName, NativeFunction] = { + structured_map: Dict[OperatorName, NativeFunction] = { f.func.name: f for f in concatMap(lambda g: list(g.functions()), structured_native_functions) } - view_map: dict[OperatorName, NativeFunction] = { + view_map: Dict[OperatorName, NativeFunction] = { f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) } for f in native_functions: @@ -2705,12 +2715,12 @@ def gen_declarations_yaml( ) -def get_torchgen_root() -> Path: +def get_torchgen_root() -> pathlib.Path: """ If you're depending on torchgen out-of-tree, you can use the root to figure out the path to native_functions.yaml """ - return Path(__file__).parent.resolve() + return pathlib.Path(__file__).parent.resolve() def main() -> None: @@ -2872,11 +2882,11 @@ def main() -> None: # # Invalid character escape '\c'. 
core_install_dir = f"{options.install_dir}/core" - Path(core_install_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True) ops_install_dir = f"{options.install_dir}/ops" - Path(ops_install_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) aoti_install_dir = f"{options.aoti_install_dir}" - Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) core_fm = make_file_manager(options=options, install_dir=core_install_dir) cpu_fm = make_file_manager(options=options) @@ -2906,7 +2916,7 @@ def main() -> None: if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist ] - static_dispatch_idx: list[BackendIndex] = [] + static_dispatch_idx: List[BackendIndex] = [] if options.static_dispatch_backend: static_dispatch_idx = [ backend_indices[DispatchKey.parse(key)] @@ -2963,7 +2973,7 @@ def main() -> None: gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm) if options.output_dependencies: - depfile_path = Path(options.output_dependencies).resolve() + depfile_path = pathlib.Path(options.output_dependencies).resolve() depfile_name = depfile_path.name depfile_stem = depfile_path.stem diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index ecb919bce19..d8d91fbaa16 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -1,8 +1,6 @@ -from __future__ import annotations - import textwrap from dataclasses import dataclass -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Union from torchgen.api.types import DispatcherSignature from torchgen.api.types.signatures import CppSignature, CppSignatureGroup @@ -71,7 +69,7 @@ base_type_to_callsite_expr = { # convert args to C types, names in declarations, and expressions in function bodies -def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return] +def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]: # type: ignore[return] if isinstance(typ, BaseType): if typ.name in base_type_to_c_type: return ( @@ -169,12 +167,12 @@ def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str ) -def zip_type_and_name(types: list[str], names: list[str]) -> list[str]: +def zip_type_and_name(types: List[str], names: List[str]) -> List[str]: return [typ + " " + name for typ, name in zip(types, names)] # Generate argument declarations and callsite expressions -def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]: +def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]: types = [] new_names = [] callsite_exprs = [] @@ -191,7 +189,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[s # Return values are passed out as pointer arguments because all the C shim functions # are expected to return AOTITorchError. 
# Generate returns as declarations and callsite expressions -def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]: +def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]: types = [] names = [] for idx, ret in enumerate(schema.returns): @@ -224,7 +222,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]: ret_pointer_can_be_null = True break - callsite_exprs: list[str] = [] + callsite_exprs: List[str] = [] for idx, ret in enumerate(schema.returns): tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)" assert isinstance(ret.type, BaseType) @@ -238,12 +236,12 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]: # gen.py generates header first and then src, so caching the result here to avoid duplicate work -declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {} +declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {} def gen_declaration_and_definition( schema: FunctionSchema, device: str, backend_call: str -) -> tuple[str, str]: +) -> Tuple[str, str]: func_name = schema.name.unambiguous_name() global declaration_definition_cache @@ -256,7 +254,7 @@ def gen_declaration_and_definition( args, callsite_exprs = gen_arguments( [*schema.arguments.out, *schema.arguments.flat_non_out] ) - ret_assignments: list[str] = [] + ret_assignments: List[str] = [] else: args, callsite_exprs = gen_arguments(schema.arguments.flat_all) # ignore return values for inplace ops @@ -286,7 +284,7 @@ def gen_declaration_and_definition( def gen_static_dispatch_backend_call_signature( - sig: CppSignature | DispatcherSignature, + sig: Union[CppSignature, DispatcherSignature], f: NativeFunction, ) -> CppSignature: sig = DispatcherSignature.from_schema(f.func) @@ -312,10 +310,10 @@ def gen_static_dispatch_backend_call( def get_backend_index_for_aoti( func: NativeFunction, - func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, - backend_indices: dict[DispatchKey, BackendIndex], -) -> BackendIndex | None: + backend_indices: Dict[DispatchKey, BackendIndex], +) -> Optional[BackendIndex]: backend_index = None if backend_indices[dispatch_key].has_kernel(func) or ( func.structured_delegate is not None @@ -343,10 +341,10 @@ def get_backend_index_for_aoti( def get_header_for_aoti( func: NativeFunction, - func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, - backend_indices: dict[DispatchKey, BackendIndex], -) -> str | None: + backend_indices: Dict[DispatchKey, BackendIndex], +) -> Optional[str]: backend_index = get_backend_index_for_aoti( func, func_group_mapping, dispatch_key, backend_indices ) @@ -367,11 +365,11 @@ def get_fallback_op_name(func: NativeFunction) -> str: def gen_c_shim( func: NativeFunction, - func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], header: bool, -) -> str | None: +) -> Optional[str]: backend_index = get_backend_index_for_aoti( func, func_group_mapping, dispatch_key, backend_indices ) @@ -401,16 +399,16 @@ def gen_c_shim( @dataclass(frozen=True) class ShimGenerator: - func_group_mapping: dict[OperatorName, NativeFunctionsGroup] + func_group_mapping: 
Dict[OperatorName, NativeFunctionsGroup] dispatch_key: DispatchKey - backend_indices: dict[DispatchKey, BackendIndex] + backend_indices: Dict[DispatchKey, BackendIndex] header: bool # True to generate .h and False to generate .cpp @method_with_native_function def __call__( self, func: NativeFunction, - ) -> str | None: + ) -> Optional[str]: result = gen_c_shim( func, self.func_group_mapping, @@ -423,9 +421,9 @@ class ShimGenerator: def gen_aoti_c_shim( native_functions: Sequence[NativeFunction], - func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], header: bool, includes: str = "", ) -> str: diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 54ae0755ea2..208e1534329 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import argparse import os import re from collections import Counter, defaultdict, namedtuple from pathlib import Path -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Set, Union import yaml @@ -38,10 +36,10 @@ ParsedExternalYaml = namedtuple( def parse_backend_yaml( backend_yaml_path: str, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], - backend_indices: dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], ) -> ParsedExternalYaml: - native_functions_map: dict[OperatorName, NativeFunction] = { + native_functions_map: Dict[OperatorName, NativeFunction] = { f.func.name: f for f in concatMap( lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), @@ -121,14 +119,14 @@ def parse_backend_yaml( Only the following keys are supported: {", ".join(valid_keys)}' def create_backend_index( - backend_ops: list[str], - symint_ops: set[str], + backend_ops: List[str], + symint_ops: Set[str], dispatch_key: DispatchKey, *, use_out_as_primary: bool, use_device_guard: bool, ) -> BackendIndex: - metadata: dict[OperatorName, BackendMetadata] = {} + metadata: Dict[OperatorName, BackendMetadata] = {} for op in backend_ops: op_name = OperatorName.parse(op) assert ( @@ -151,7 +149,7 @@ Only the following keys are supported: {", ".join(valid_keys)}' index=metadata, ) - backend_key: DispatchKey | None = None + backend_key: Optional[DispatchKey] = None if len(supported) > 0: with context( lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.' @@ -168,7 +166,7 @@ Only the following keys are supported: {", ".join(valid_keys)}' assert backend_key not in backend_indices backend_indices[backend_key] = backend_idx - autograd_key: DispatchKey | None = None + autograd_key: Optional[DispatchKey] = None if len(supported_autograd) > 0: with context( lambda: f'The "autograd" key was specified, which indicates that you would like to override \ @@ -247,12 +245,12 @@ autograd key. They cannot be mix and matched. 
If this is something you need, fee def error_on_missing_kernels( native_functions: Sequence[NativeFunction], - backend_indices: dict[DispatchKey, BackendIndex], + backend_indices: Dict[DispatchKey, BackendIndex], backend_key: DispatchKey, - autograd_key: DispatchKey | None, + autograd_key: Optional[DispatchKey], class_name: str, kernel_defn_file_path: str, - full_codegen: list[OperatorName] | None = None, + full_codegen: Optional[List[OperatorName]] = None, ) -> None: try: with open(kernel_defn_file_path) as f: @@ -270,7 +268,7 @@ def error_on_missing_kernels( ) # Quick mapping from each OperatorName used by the external backend # to its backend kernel name - expected_backend_op_names: dict[OperatorName, str] = dict( + expected_backend_op_names: Dict[OperatorName, str] = dict( list( concatMap( lambda index: [ @@ -280,13 +278,13 @@ def error_on_missing_kernels( ) ) ) - expected_backend_native_funcs: list[NativeFunction] = [ + expected_backend_native_funcs: List[NativeFunction] = [ f for f in native_functions if f.func.name in expected_backend_op_names.keys() and f.func.name not in full_codegen ] - expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( + expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict( list ) for native_f in expected_backend_native_funcs: @@ -358,10 +356,10 @@ def gen_dispatchkey_nativefunc_headers( fm: FileManager, class_name: str, cpp_namespace: str, - backend_indices: dict[DispatchKey, BackendIndex], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_indices: Dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], backend_dispatch_key: DispatchKey, - autograd_dispatch_key: DispatchKey | None, + autograd_dispatch_key: Optional[DispatchKey], backend_name: str = "", ) -> None: assert class_name is not None @@ -415,11 +413,11 @@ def gen_dispatcher_registrations( fm: FileManager, output_dir: str, class_name: str, - backend_indices: dict[DispatchKey, BackendIndex], - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_indices: Dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], backend_dispatch_key: DispatchKey, dispatch_key: DispatchKey, - selector: SelectiveBuilder, + selector: "SelectiveBuilder", # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends build_in_tree: bool = False, per_operator_headers: bool = False, @@ -526,7 +524,7 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() { def run( - source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None + source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py pytorch_root = Path(__file__).absolute().parent.parent diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 0e8d79cf679..436630bb664 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import argparse import os +import pathlib from collections import defaultdict from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union import yaml @@ -47,6 +45,7 @@ from torchgen.model 
import ( OperatorName, Variant, ) +from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import ( context, FileManager, @@ -56,11 +55,7 @@ from torchgen.utils import ( ) -if TYPE_CHECKING: - from torchgen.selective_build.selector import SelectiveBuilder - - -def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: +def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str: """ A wrapper function to basically get `sig.decl(include_context=True)`. For ATen kernel, the codegen has no idea about ET contextArg, so we @@ -77,9 +72,9 @@ def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: def static_dispatch( - sig: CppSignature | ExecutorchCppSignature, + sig: Union[CppSignature, ExecutorchCppSignature], f: NativeFunction, - backend_indices: list[BackendIndex], + backend_indices: List[BackendIndex], ) -> str: """ For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one @@ -118,7 +113,7 @@ TORCH_API inline {_sig_decl_wrapper(sig)} {{ # and the scaffolding to call into the dispatcher from these functions. @dataclass(frozen=True) class ComputeFunction: - static_dispatch_backend_indices: list[BackendIndex] + static_dispatch_backend_indices: List[BackendIndex] selector: SelectiveBuilder @@ -127,7 +122,7 @@ class ComputeFunction: is_custom_op: Callable[[NativeFunction], bool] @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def __call__(self, f: NativeFunction) -> Optional[str]: is_method_variant = False if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"): return None @@ -141,7 +136,7 @@ class ComputeFunction: f"Can't handle native function {f.func} with the following variant specification {f.variants}." 
) - sig: CppSignature | ExecutorchCppSignature = ( + sig: Union[CppSignature, ExecutorchCppSignature] = ( CppSignatureGroup.from_native_function( f, method=False, fallback_binding=f.manual_cpp_binding ).most_faithful_signature() @@ -184,10 +179,10 @@ class ComputeCodegenUnboxedKernels: @method_with_nested_native_function def __call__( self, - unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], + unbox_kernel_entry: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]], ) -> str: f: NativeFunction = unbox_kernel_entry[0] - kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0] + kernel_key: Union[ETKernelKey, List[ETKernelKey]] = unbox_kernel_entry[1][0] kernel_meta: BackendMetadata = unbox_kernel_entry[1][1] op_name = f"{f.namespace}::{f.func.name}" @@ -201,7 +196,7 @@ class ComputeCodegenUnboxedKernels: ) if not used_kernel_keys: return "" - sig: CppSignature | ExecutorchCppSignature + sig: Union[CppSignature, ExecutorchCppSignature] argument_type_gen: Callable[..., NamedCType] return_type_gen: Callable[..., CType] if self.use_aten_lib: @@ -295,11 +290,11 @@ def gen_unboxing( ) -> None: # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) def key_func( - item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]] + item: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]] ) -> str: return item[0].root_name + ":" + item[1][0].to_native_string() - items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [ + items: List[Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]] = [ (native_function, (kernel_key, metadata)) for native_function in native_functions for kernel_key, metadata in kernel_index.get_kernels(native_function).items() @@ -330,8 +325,8 @@ def gen_unboxing( @with_native_function_and_index # type: ignore[arg-type] def compute_native_function_declaration( - g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex -) -> list[str]: + g: Union[NativeFunctionsGroup, NativeFunction], kernel_index: ETKernelIndex +) -> List[str]: assert isinstance(g, NativeFunction) sig = ExecutorchCppSignature.from_native_function(f=g) metadata_list = kernel_index.get_kernels(g).values() @@ -357,7 +352,7 @@ def gen_functions_declarations( kernel_index: ETKernelIndex, selector: SelectiveBuilder, use_aten_lib: bool, - custom_ops_native_functions: Sequence[NativeFunction] | None = None, + custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None, ) -> str: """ Generates namespace separated C++ function API inline declaration/definitions. 
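
As an aside on the sharding logic above: `FileManager.write_sharded` distributes work items across output files using the string produced by `key_func`, so entries sharing a key are routed consistently. The following is a minimal sketch of that keying pattern, with plain strings standing in for `NativeFunction`, `ETKernelKey`, and `BackendMetadata` (all values below are illustrative, not real torchgen data):

```python
from typing import List, Tuple

# Toy stand-ins: the first element plays the role of NativeFunction.root_name,
# the nested pair plays (ETKernelKey, BackendMetadata).
def key_func(item: Tuple[str, Tuple[str, str]]) -> str:
    root_name, (kernel_key, _metadata) = item
    return root_name + ":" + kernel_key

items: List[Tuple[str, Tuple[str, str]]] = [
    ("add", ("default", "meta_a")),
    ("add", ("default", "meta_b")),
    ("mul", ("default", "meta_a")),
]
shard_keys = [key_func(item) for item in items]
assert shard_keys == ["add:default", "add:default", "mul:default"]
```
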
@@ -411,13 +406,13 @@ def get_ns_grouped_kernels( kernel_index: ETKernelIndex, native_function_decl_gen: Callable[ [ - NativeFunctionsGroup | NativeFunction, + Union[NativeFunctionsGroup, NativeFunction], ETKernelIndex, ], - list[str], + List[str], ], -) -> dict[str, list[str]]: - ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) +) -> Dict[str, List[str]]: + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) for f in native_functions: native_function_namespaces = set() op_kernels = kernel_index.get_kernels(f) @@ -600,7 +595,7 @@ def gen_custom_ops( def translate_native_yaml( tags_yaml_path: str, aten_yaml_path: str, - native_yaml_path: str | None, + native_yaml_path: Optional[str], use_aten_lib: bool, out_file: TextIO, ) -> None: @@ -651,15 +646,15 @@ def translate_native_yaml( skip_native_fns_gen=False, ) - func_to_scoped_name: dict[FunctionSchema, str] = { + func_to_scoped_name: Dict[FunctionSchema, str] = { f.func: f"{f.namespace}::{f.func.name}" for f in native_functions } - op_to_scoped_name: dict[OperatorName, str] = { + op_to_scoped_name: Dict[OperatorName, str] = { func.name: name for func, name in func_to_scoped_name.items() } schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()} - kernel_persist_dict: dict[str, dict[str, Any]] = { + kernel_persist_dict: Dict[str, Dict[str, Any]] = { op_to_scoped_name[op]: v for op, v in persisted_fields.items() } @@ -697,13 +692,13 @@ def translate_native_yaml( def parse_yaml( - path: str | None, + path: Optional[str], tags_yaml_path: str, function_filter: Callable[[NativeFunction], bool], skip_native_fns_gen: bool = False, -) -> tuple[ - list[NativeFunction], - dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex, +) -> Tuple[ + List[NativeFunction], + Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex], ]: if path and os.path.exists(path) and os.stat(path).st_size > 0: with open(path) as f: @@ -740,8 +735,8 @@ def parse_yaml( # (2) Return BackendIndices if kernel index is absent def map_index( - m: dict[OperatorName, BackendMetadata] - ) -> dict[OperatorName, BackendMetadata]: + m: Dict[OperatorName, BackendMetadata] + ) -> Dict[OperatorName, BackendMetadata]: return {op: m[op] for op in m if op in op_names} backend_indices = { @@ -756,11 +751,11 @@ def parse_yaml( def parse_yaml_files( tags_yaml_path: str, aten_yaml_path: str, - native_yaml_path: str | None, - custom_ops_yaml_path: str | None, + native_yaml_path: Optional[str], + custom_ops_yaml_path: Optional[str], selector: SelectiveBuilder, use_aten_lib: bool, -) -> tuple[ETParsedYaml, ETParsedYaml | None]: +) -> Tuple[ETParsedYaml, Optional[ETParsedYaml]]: """Parses functions.yaml and custom_ops.yaml files. 
Args: @@ -983,7 +978,7 @@ def main() -> None: ) if options.output_dependencies: - depfile_path = Path(options.output_dependencies).resolve() + depfile_path = pathlib.Path(options.output_dependencies).resolve() depfile_name = depfile_path.name depfile_stem = depfile_path.stem diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index d07a99f808a..34b5e617e49 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass -from typing import Callable, TYPE_CHECKING +from typing import Callable, List, Optional, Tuple, Union from torchgen.api import cpp, dispatcher from torchgen.api.translate import translate @@ -48,13 +46,10 @@ from torchgen.native_function_generation import ( MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) +from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import dataclass_repr -if TYPE_CHECKING: - from torchgen.selective_build.selector import SelectiveBuilder - - # Note: [Mutable Ops Not Using Functionalization] # Ops in this list currently do not work with functionalization and should be fixed. MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = ( @@ -93,7 +88,7 @@ class GenCompositeViewCopyKernel: backend_index: BackendIndex @method_with_native_function - def __call__(self, g: NativeFunctionsViewGroup) -> str | None: + def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]: if g.view_copy is None: return None elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy": @@ -165,7 +160,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) { """ -def return_str(rets: tuple[Return, ...], names: list[str]) -> str: +def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: assert len(rets) == len(names) if len(rets) == 0: return "" @@ -189,7 +184,7 @@ def wrapper_name(func: FunctionSchema) -> str: return cpp.name(func) -def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool: +def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> bool: return isinstance(a, SelfArgument) or ( isinstance(a, Argument) and a.type.is_tensor_like() ) @@ -199,7 +194,7 @@ def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool: # Some op schemas include non-owning types though (like TensorList), # and when we unwrap them we expect to get out an owning type! # We also return a lambda that tells you how to convert the non-owning type argument into the owning type. -def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]: +def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]: if t == BaseCType(tensorListT): return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()" if t == BaseCType(iTensorListRefT): @@ -214,9 +209,9 @@ def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]: # (2) a context, to be used by translate(), with all of the relevant bindings. def unwrap_tensor_args( sig: DispatcherSignature, *, is_view_op: bool -) -> tuple[str, list[Binding]]: - context: list[Binding] = [] - unwrapped_tensor_args: list[str] = [] +) -> Tuple[str, List[Binding]]: + context: List[Binding] = [] + unwrapped_tensor_args: List[str] = [] for arg in sig.arguments(): if is_tensor_like(arg.argument): # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
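
To make the unwrap step above concrete: `unwrap_tensor_args` walks the dispatcher signature, emits one line of unwrapping C++ per tensor-like argument, and collects the new bindings for `translate()`. Below is a hedged miniature of that shape, with strings standing in for `Binding` objects and an illustrative C++ helper name (the real emitted call depends on the signature and `is_view_op`):

```python
from typing import List, Tuple

def sketch_unwrap_tensor_args(arg_names: List[str]) -> Tuple[str, List[str]]:
    # One emitted C++ line per tensor-like argument; the collected names
    # stand in for the List[Binding] context handed to translate().
    unwrap_lines: List[str] = []
    context: List[str] = []
    for name in arg_names:
        unwrapped = f"{name}_"
        # "unwrap_functional_tensor" is an illustrative helper name, not
        # the exact call the real codegen emits.
        unwrap_lines.append(f"auto {unwrapped} = unwrap_functional_tensor({name});")
        context.append(unwrapped)
    return "\n".join(unwrap_lines), context

code, ctx = sketch_unwrap_tensor_args(["self", "other"])
assert ctx == ["self_", "other_"]
```
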
@@ -252,9 +247,9 @@ def unwrap_tensor_args( # converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns: # (1) a string containing all of the logic that does the conversions. # (2) a context, to be used by translate(), with all of the relevant bindings. -def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]: - context: list[Binding] = [] - unwrapped_tensor_args: list[str] = [] +def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: + context: List[Binding] = [] + unwrapped_tensor_args: List[str] = [] for arg in sig.arguments(): if is_tensor_like(arg.argument): # for tensor inputs, we want to unwrap them before passing them into the redispatch calls. @@ -322,7 +317,7 @@ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str: # Detects whether any of the SymInt arguments are, in fact, symbolic values. # This is used in the constructor of ViewMeta. -def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]: +def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]: name = "has_symbolic_inputs" statements = [ f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});" @@ -527,7 +522,7 @@ def maybe_create_output(f: NativeFunction, var_name: str) -> str: # - the names of returns corresponding to the (immutable) outputs of the inner redispatched function def get_mutable_redispatch_return_names( f: NativeFunction, inner_return_var: str -) -> tuple[list[str], list[str]]: +) -> Tuple[List[str], List[str]]: aliased_returns = [] non_aliased_returns = [] for i, name in enumerate(f.func.aliased_return_names()): @@ -756,11 +751,11 @@ def emit_inplace_functionalization_body( # See Note [Functionalization Pass: View Inverses]. def gen_functionalization_view_inverse_declaration( selector: SelectiveBuilder, g: NativeFunctionsViewGroup -) -> str | None: +) -> Optional[str]: # For every (non-composite) view op, we need a corresponding "inverse view" function. # This generates the declarations so we get a good compiler error when someone adds a new view. @with_native_function - def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None: + def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]: if g.view.has_composite_implicit_autograd_kernel: return None view_inverse_sig = ViewInverseSignature(g) @@ -771,9 +766,9 @@ def gen_functionalization_view_inverse_declaration( def gen_functionalization_registration( selector: SelectiveBuilder, - g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], composite_implicit_autograd_index: BackendIndex, -) -> list[str]: +) -> List[str]: @with_native_function def emit_registration_helper(f: NativeFunction) -> str: assert not f.has_composite_implicit_autograd_kernel @@ -837,8 +832,8 @@ def gen_functionalization_definition( # (and instead only need to operate on grouped NativeFunctions). # The only reason currently is because we need to emit direct dispatch registrations # For CompositeImplicitAutograd operators, which are potentially ungrouped. 
- g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, -) -> list[str]: + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], +) -> List[str]: # Don't generate kernels in mobile build if not selector.include_all_operators: return [] diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 461f530034e..52b42094356 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -1,10 +1,19 @@ -from __future__ import annotations - import argparse import os from collections import namedtuple from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, Sequence +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) import yaml @@ -93,8 +102,8 @@ ParsedExternalYaml = namedtuple( def parse_native_functions_keys( backend_yaml_path: str, - grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], -) -> tuple[list[OperatorName], list[Any], list[OperatorName]]: + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], +) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]: with open(backend_yaml_path) as f: yaml_values = yaml.load(f, Loader=YamlLoader) assert isinstance(yaml_values, dict) @@ -111,7 +120,7 @@ def parse_native_functions_keys( def validate_shape_inference_header( - shape_inference_hdr: str, expected_shape_infr_decls: list[str] + shape_inference_hdr: str, expected_shape_infr_decls: List[str] ) -> None: try: with open(shape_inference_hdr) as f: @@ -171,12 +180,12 @@ std::vector to_meta(at::ITensorListRef t_list) { class default_args: node_base: str = "Node" - node_base_hdr: str | None = None + node_base_hdr: Optional[str] = None shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h" tensor_class: str = "torch::lazy::LazyTensor" tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h" - lazy_ir_generator: type[GenLazyIR] = GenLazyIR - native_func_definition_generator: type[ + lazy_ir_generator: Type[GenLazyIR] = GenLazyIR + native_func_definition_generator: Type[ GenLazyNativeFuncDefinition ] = GenLazyNativeFuncDefinition backend_name: str = "TorchScript" @@ -254,10 +263,10 @@ def main() -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py torch_root = Path(__file__).absolute().parents[2] aten_path = str(torch_root / "aten" / "src" / "ATen") - lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator + lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator if options.gen_ts_lowerings: lazy_ir_generator = GenTSLazyIR - native_func_definition_generator: type[ + native_func_definition_generator: Type[ GenLazyNativeFuncDefinition ] = default_args.native_func_definition_generator @@ -283,14 +292,14 @@ def run_gen_lazy_tensor( source_yaml: str, output_dir: str, dry_run: bool, - impl_path: str | None, + impl_path: Optional[str], node_base: str = default_args.node_base, - node_base_hdr: str | None = default_args.node_base_hdr, + node_base_hdr: Optional[str] = default_args.node_base_hdr, tensor_class: str = default_args.tensor_class, tensor_class_hdr: str = default_args.tensor_class_hdr, shape_inference_hdr: str = default_args.shape_inference_hdr, - lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator, - native_func_definition_generator: type[ + lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator, + native_func_definition_generator: Type[ GenLazyNativeFuncDefinition ] = 
default_args.native_func_definition_generator, # build_in_tree is true for TS backend and affects include paths @@ -338,7 +347,7 @@ def run_gen_lazy_tensor( ) grouped_native_functions = get_grouped_native_functions(native_functions) - def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str: + def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: """ We sort the native function because of the note in concat_map_codegen. TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly. @@ -368,8 +377,8 @@ def run_gen_lazy_tensor( def concat_map_codegen( func: Callable[[NativeFunction], Sequence[str]], - xs: Iterable[NativeFunctionsGroup | NativeFunction], - ops_list: list[OperatorName] = full_codegen, + xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]], + ops_list: List[OperatorName] = full_codegen, ) -> Iterator[str]: """ We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we diff --git a/torchgen/gen_vmap_plumbing.py b/torchgen/gen_vmap_plumbing.py index 7a87121613c..ac7ac283dd3 100644 --- a/torchgen/gen_vmap_plumbing.py +++ b/torchgen/gen_vmap_plumbing.py @@ -1,8 +1,6 @@ -from __future__ import annotations - import textwrap from dataclasses import dataclass -from typing import Sequence +from typing import List, Optional, Sequence, Tuple from torchgen.api.translate import translate from torchgen.api.types import DispatcherSignature @@ -34,7 +32,7 @@ def is_tensor_list(typ: Type) -> bool: return isinstance(typ, ListType) and is_tensor(typ.elem) -def unwrap_tensor(name: str, cur_level_var: str) -> list[str]: +def unwrap_tensor(name: str, cur_level_var: str) -> List[str]: result = f"""\ Tensor {name}_value; optional {name}_bdim; @@ -42,7 +40,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> list[str]: return textwrap.dedent(result).split("\n") -def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]: +def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]: result = f"""\ optional {name}_value; optional {name}_bdim; @@ -54,7 +52,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]: def gen_unwraps( flat_arguments: Sequence[Argument], cur_level_var: str -) -> tuple[str, list[str]]: +) -> Tuple[str, List[str]]: arg_names = [a.name for a in flat_arguments] arg_types = [a.type for a in flat_arguments] @@ -101,7 +99,7 @@ if ({' && '.join(conditions)}) {{ def gen_returns( - returns: tuple[Return, ...], cur_level_var: str, results_var: str + returns: Tuple[Return, ...], cur_level_var: str, results_var: str ) -> str: idx = 0 wrapped_returns = [] @@ -134,7 +132,7 @@ def is_mutated_arg(argument: Argument) -> bool: return argument.annotation is not None and argument.annotation.is_write -def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None: +def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]: # Assumptions: # - only one argument is being modified in-place # - the argument that is being modified in-place is the first argument @@ -199,7 +197,7 @@ template }}""" -def gen_vmap_plumbing(native_function: NativeFunction) -> str | None: +def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]: schema = native_function.func sig = DispatcherSignature.from_schema(schema) returns = schema.returns @@ -246,7 +244,7 @@ template @dataclass(frozen=True) class ComputeBatchRulePlumbing: @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: + def 
__call__(self, f: NativeFunction) -> Optional[str]: result = gen_vmap_plumbing(f) return result diff --git a/torchgen/local.py b/torchgen/local.py index 7c687c3a799..09532c7bfc6 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -1,8 +1,6 @@ -from __future__ import annotations - import threading from contextlib import contextmanager -from typing import Iterator +from typing import Iterator, Optional # Simple dynamic scoping implementation. The name "parametrize" comes @@ -19,8 +17,8 @@ from typing import Iterator class Locals(threading.local): - use_const_ref_for_mutable_tensors: bool | None = None - use_ilistref_for_tensor_lists: bool | None = None + use_const_ref_for_mutable_tensors: Optional[bool] = None + use_ilistref_for_tensor_lists: Optional[bool] = None _locals = Locals() diff --git a/torchgen/model.py b/torchgen/model.py index 33e5e1427f5..e150cb7cf6f 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import dataclasses import itertools import re from dataclasses import dataclass from enum import auto, Enum -from typing import Callable, Iterator, Sequence +from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union from torchgen.utils import assert_never, NamespaceHelper, OrderedSet @@ -231,7 +229,7 @@ class DispatchKey(Enum): return str(self).lower() @staticmethod - def parse(value: str) -> DispatchKey: + def parse(value: str) -> "DispatchKey": for k, v in DispatchKey.__members__.items(): if k == value: return v @@ -352,20 +350,20 @@ class ScalarType(Enum): return self.name @staticmethod - def maybe_parse(value: str) -> ScalarType | None: + def maybe_parse(value: str) -> Optional["ScalarType"]: for k, v in ScalarType.__members__.items(): if k == value: return v return None @staticmethod - def parse(value: str) -> ScalarType: + def parse(value: str) -> "ScalarType": mb_r = ScalarType.maybe_parse(value) assert mb_r is not None, f"unknown dtype {value}" return mb_r @staticmethod - def parse_set(values: str) -> OrderedSet[ScalarType]: + def parse_set(values: str) -> OrderedSet["ScalarType"]: dtypes: OrderedSet[ScalarType] = OrderedSet() for value in values.split(", "): if value in DTYPE_CLASSES: @@ -375,7 +373,7 @@ class ScalarType(Enum): return dtypes -DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {} +DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {} # NB: Integral doesn't include boolean DTYPE_CLASSES["Integral"] = OrderedSet( [ @@ -421,7 +419,7 @@ class UfuncKey(Enum): return self.name @staticmethod - def parse(value: str) -> UfuncKey: + def parse(value: str) -> "UfuncKey": for k, v in UfuncKey.__members__.items(): if k == value: return v @@ -464,7 +462,7 @@ class NativeFunction: # (This type is quoted as we are forward referencing a type # defined later in the file. I opted for this ordering of the # classes for expository clarity.) - func: FunctionSchema + func: "FunctionSchema" # Whether or not to generate mutable tensor arguments like regular # ones @@ -477,14 +475,14 @@ class NativeFunction: device_check: DeviceCheckType # What python module to put the function in - python_module: str | None + python_module: Optional[str] # TODO: figure out what this does - category_override: str | None + category_override: Optional[str] # If no variants are specified in native_functions.yaml, this is # assumed to be {'function'}. - variants: set[Variant] + variants: Set[Variant] # Whether or not we should skip generating registrations for # this kernel. 
This is a bit of a double-edged sword, as manual @@ -499,7 +497,7 @@ class NativeFunction: # The location in the YAML file where this native function entry was # defined. This is for conveniently reporting error messages! - loc: Location + loc: "Location" # A list of operators that are expected to be auto-generated for this NativeFunction. # Note: This list isn't actually directly used by the codegen to generate anything. @@ -507,11 +505,11 @@ class NativeFunction: # function schema, and uses the autogen declarations to error check. # We expect every NativeFunction that gets auto-generated to be explicitly called out # in native_functions.yaml - autogen: list[OperatorName] + autogen: List["OperatorName"] # If non-empty, this kernel is subject to ufunc codegen. # Sorted by ufunc_key - ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop] + ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"] # Whether or not this out function is a "structured kernel". Structured # kernels are defined a little differently from normal kernels; in @@ -524,13 +522,13 @@ class NativeFunction: # Whether or not this non-out function is a structured kernel, defined # in terms of the out kernel referenced by the string here. - structured_delegate: OperatorName | None + structured_delegate: Optional["OperatorName"] # Only valid for structured kernels. Specifies an alternative of what # to inherit from when defining the meta class for the structured # operator. This will usually be TensorIteratorBase. This also # changes the semantics of set_output to call the parent class. - structured_inherits: str | None + structured_inherits: Optional[str] # Structured kernels can declare elements as "precomputed". These elements # are returned by the meta function in one struct and passed to the impl @@ -538,11 +536,11 @@ class NativeFunction: # elements supersede. Information about the names and types of these # precomputed elements and how they correspond to kernel arguments is stored # in this member, if applicable. - precomputed: Precompute | None + precomputed: Optional["Precompute"] # Argument names whose default should be excluded from the C++ interface. # Intended for resolving overload ambiguities between signatures. - cpp_no_default_args: set[str] + cpp_no_default_args: Set[str] # Note [Abstract ATen methods] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -562,7 +560,7 @@ class NativeFunction: # Tags are used to describe semantic information about (groups of) operators, # that aren't easily inferrable directly from the operator's schema. - tags: set[str] + tags: Set[str] # NB: The benefit of defining a dataclass is that we automatically get # a constructor defined for all the fields we specify. No need # We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex.
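
A note on the quoting pattern visible throughout these hunks (`loc: "Location"`, `autogen: List["OperatorName"]`, `structured_delegate: Optional["OperatorName"]`): with `from __future__ import annotations` removed, annotations are evaluated eagerly, so any name not yet bound at definition time must become a string forward reference. A minimal, self-contained illustration using a toy enum (not a torchgen class):

```python
from enum import Enum
from typing import Optional

class Color(Enum):
    RED = 1
    GREEN = 2

    @staticmethod
    def maybe_parse(value: str) -> Optional["Color"]:
        # Annotations on this def are evaluated while the class body is
        # still executing, and "Color" is not bound at that point; the
        # quoted forward reference keeps this legal without the
        # __future__ import.
        return Color.__members__.get(value)

assert Color.maybe_parse("RED") is Color.RED
assert Color.maybe_parse("BLUE") is None
```
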
@staticmethod def from_yaml( - ei: dict[str, object], - loc: Location, - valid_tags: set[str], - ignore_keys: set[DispatchKey] | None = None, - ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]: + ei: Dict[str, object], + loc: "Location", + valid_tags: Set[str], + ignore_keys: Optional[Set[DispatchKey]] = None, + ) -> Tuple[ + "NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]] + ]: """ Parse a NativeFunction from a dictionary as directly parsed from native_functions.yaml @@ -602,7 +602,7 @@ class NativeFunction: variants_s = e.pop("variants", "function") assert isinstance(variants_s, str) - variants: set[Variant] = set() + variants: Set[Variant] = set() for v in variants_s.split(", "): if v == "function": variants.add(Variant.function) @@ -646,7 +646,7 @@ class NativeFunction: "namespace is not supported in structured delegate," " using the same namespace as the native function" ) - structured_delegate: OperatorName | None = None + structured_delegate: Optional[OperatorName] = None if structured_delegate_s is not None: structured_delegate = OperatorName.parse(structured_delegate_s) @@ -685,7 +685,7 @@ class NativeFunction: if namespace == "aten" and "pt2_compliant_tag" in valid_tags: tags_inp.append("pt2_compliant_tag") - tags: set[str] = set() + tags: Set[str] = set() for t in tags_inp: assert len(valid_tags) > 0 # TODO: verify that the tag is valid and has an entry in tags.yaml @@ -698,7 +698,7 @@ class NativeFunction: raw_dispatch = e.pop("dispatch", None) assert raw_dispatch is None or isinstance(raw_dispatch, dict), e - dispatch: dict[DispatchKey, BackendMetadata] = {} + dispatch: Dict[DispatchKey, BackendMetadata] = {} num_dispatch_keys: int = 0 if raw_dispatch is not None: assert not manual_kernel_registration, ( @@ -1081,8 +1081,8 @@ class SchemaKind(Enum): @dataclass(frozen=True) class NativeFunctionsGroup: functional: NativeFunction - inplace: NativeFunction | None - mutable: NativeFunction | None + inplace: Optional[NativeFunction] + mutable: Optional[NativeFunction] out: NativeFunction @property @@ -1136,7 +1136,7 @@ class NativeFunctionsGroup: [str(f.func.name) for f in self.functions() if "generated" in f.tags] ) generated_fns_str = ", ".join(str(x) for x in generated_fns) - expected_generated_fns: set[str] = set() + expected_generated_fns: Set[str] = set() for f in self.functions(): expected_generated_fns.update(str(op) for op in f.autogen) expected_generated_fns_str = ", ".join( @@ -1155,7 +1155,7 @@ class NativeFunctionsGroup: f" Instead, it found 'autogen: {expected_generated_fns_str}'" ) - def signature(self) -> FunctionSchema: + def signature(self) -> "FunctionSchema": return self.out.func.signature() def functions(self) -> Iterator[NativeFunction]: @@ -1171,7 +1171,9 @@ class NativeFunctionsGroup: return self.functional.root_name @staticmethod - def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None: + def from_dict( + d: Dict[SchemaKind, NativeFunction] + ) -> Optional["NativeFunctionsGroup"]: assert d if len(d) == 1: return None @@ -1227,7 +1229,7 @@ class UfuncInnerLoop: ufunc_key: UfuncKey @staticmethod - def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop: + def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop": name, supported_dtypes_str = value.split(" ", 1) assert supported_dtypes_str[0] == "(" assert supported_dtypes_str[-1] == ")" @@ -1259,12 +1261,12 @@ class BackendIndex: # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA) external: bool # 
Other backend-specific information that is on a per-operator basis - index: dict[OperatorName, BackendMetadata] + index: Dict["OperatorName", BackendMetadata] @staticmethod def grow_index( - parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]], - child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]], + parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], + child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], ) -> None: for k, v in child_index.items(): for op_name, metadata in v.items(): @@ -1279,13 +1281,13 @@ class BackendIndex: else: return g.functional - def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool: + def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool: m = self.get_kernel(g) return m is not None def get_kernel( - self, g: NativeFunction | NativeFunctionsGroup - ) -> BackendMetadata | None: + self, g: Union[NativeFunction, NativeFunctionsGroup] + ) -> Optional[BackendMetadata]: if isinstance(g, NativeFunction): f = g elif isinstance(g, NativeFunctionsGroup): @@ -1296,7 +1298,7 @@ class BackendIndex: return None return self.index[f.func.name] - def native_function_class_name(self) -> str | None: + def native_function_class_name(self) -> Optional[str]: if self.external: return f"{str(self.dispatch_key)}NativeFunctions" else: @@ -1362,16 +1364,16 @@ class BackendIndex: @dataclass(frozen=True) class FunctionSchema: # The name of the operator this function schema describes. - name: OperatorName + name: "OperatorName" - arguments: Arguments + arguments: "Arguments" # TODO: Need to handle collisions with argument names at some point - returns: tuple[Return, ...] + returns: Tuple["Return", ...] @property def is_mutable(self) -> bool: - def is_write(arg: Argument) -> bool: + def is_write(arg: "Argument") -> bool: if arg.annotation is None: return False return arg.annotation.is_write @@ -1380,7 +1382,7 @@ class FunctionSchema: # See aten/src/ATen/core/function_schema.h (keep these in sync) return any(is_write(a) for a in self.arguments.flat_all) - def schema_order_arguments(self) -> Iterator[Argument]: + def schema_order_arguments(self) -> Iterator["Argument"]: return itertools.chain( self.arguments.flat_positional, self.arguments.flat_kwarg_only, @@ -1390,7 +1392,7 @@ class FunctionSchema: decl_re = re.compile(r"(?P[^\(]+)\((?P.*)\) -> (?P.*)") @staticmethod - def parse(func: str) -> FunctionSchema: + def parse(func: str) -> "FunctionSchema": # We should probably get a proper parser here decls = FunctionSchema.decl_re.findall(func) assert len(decls) == 1, f"Invalid function schema: {func}" @@ -1585,8 +1587,8 @@ class FunctionSchema: # - If the return aliases an input, we return the input name # - Otherwise, we return None. # If return names were enforced to be consistent with aliasing information, then we wouldn't need this. - def aliased_return_names(self) -> list[str | None]: - outs: list[str | None] = [] + def aliased_return_names(self) -> List[Optional[str]]: + outs: List[Optional[str]] = [] for r in self.returns: aliased_args = [ a @@ -1610,7 +1612,7 @@ class FunctionSchema: strip_default: bool = False, strip_view_copy_name: bool = False, keep_return_names: bool = False, - ) -> FunctionSchema: + ) -> "FunctionSchema": """ Certain schemas are 'related', in that they are simply inplace/out/functional versions of the same function. 
This method @@ -1707,10 +1709,10 @@ class FunctionSchema: returns=returns, ) - def view_signature(self) -> FunctionSchema: + def view_signature(self) -> "FunctionSchema": return self.signature(strip_view_copy_name=True) - def with_name(self, name: OperatorName) -> FunctionSchema: + def with_name(self, name: "OperatorName") -> "FunctionSchema": return FunctionSchema( name=name, arguments=self.arguments, @@ -1745,12 +1747,12 @@ class FunctionSchema: class Annotation: # Typically only has one element. Not actually a set so # we can conveniently assume it is canonically ordered - alias_set: tuple[str, ...] + alias_set: Tuple[str, ...] is_write: bool - alias_set_after: tuple[str, ...] + alias_set_after: Tuple[str, ...] @staticmethod - def parse(ann: str) -> Annotation: + def parse(ann: str) -> "Annotation": # TODO: implement a proper parser if this gets more ugly # Regex Explanation: # Example: "a! -> a|b" @@ -1803,13 +1805,13 @@ class Annotation: @dataclass(frozen=True) class Type: @staticmethod - def parse(t: str) -> Type: + def parse(t: str) -> "Type": r = Type._parse(t) assert str(r) == t, f"{r} != {t}" return r @staticmethod - def _parse(t: str) -> Type: + def _parse(t: str) -> "Type": m = re.match(r"^(.+)\?$", t) if m is not None: return OptionalType(Type.parse(m.group(1))) @@ -1835,7 +1837,7 @@ class Type: # so we can conveniently generate legacy Declarations.yaml but # really we should probably just remove these at some point - def is_base_ty_like(self, base_ty: BaseTy) -> bool: + def is_base_ty_like(self, base_ty: "BaseTy") -> bool: raise NotImplementedError def is_tensor_like(self) -> bool: @@ -1850,7 +1852,7 @@ class Type: def is_nullable(self) -> bool: raise NotImplementedError - def is_list_like(self) -> ListType | None: + def is_list_like(self) -> Optional["ListType"]: raise NotImplementedError @@ -1890,7 +1892,7 @@ class BaseType(Type): def is_nullable(self) -> bool: return False - def is_list_like(self) -> ListType | None: + def is_list_like(self) -> Optional["ListType"]: return None def is_symint_like(self) -> bool: @@ -1914,7 +1916,7 @@ class OptionalType(Type): def is_nullable(self) -> bool: return True - def is_list_like(self) -> ListType | None: + def is_list_like(self) -> Optional["ListType"]: return self.elem.is_list_like() @@ -1941,7 +1943,7 @@ class CustomClassType(Type): """ return False - def is_list_like(self) -> ListType | None: + def is_list_like(self) -> Optional["ListType"]: return None @@ -1955,7 +1957,7 @@ class CustomClassType(Type): @dataclass(frozen=True) class ListType(Type): elem: Type - size: int | None + size: Optional[int] def __str__(self) -> str: size = f"{self.size}" if self.size else "" @@ -1970,7 +1972,7 @@ class ListType(Type): def is_nullable(self) -> bool: return self.elem.is_nullable() - def is_list_like(self) -> ListType | None: + def is_list_like(self) -> Optional["ListType"]: return self @@ -1981,7 +1983,7 @@ class Argument: name: str type: Type - default: str | None + default: Optional[str] # The semantics of the annotation field are a little strange. # @@ -2002,16 +2004,16 @@ class Argument: # structure of annotated types is very simple. So we just hard # code it here. But if we ever do get anything more complex, this # model will have to change! 
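
The `Type._parse` hunks above peel one layer per regex match and recurse on the inner text: `^(.+)\?$` strips an optional marker, and a bracketed `int[3]`-style suffix strips a list layer (that branch sits outside the context shown here and is reconstructed below only for illustration). A hedged toy version of the recursion, returning strings instead of the real `OptionalType`/`ListType`/`BaseType` dataclasses:

```python
import re
from typing import Optional

def parse_type(t: str) -> str:
    # "Tensor?" -> peel one optional layer (same regex as the hunk above).
    m = re.match(r"^(.+)\?$", t)
    if m is not None:
        return f"Optional({parse_type(m.group(1))})"
    # "int[3]" / "Tensor[]" -> peel one list layer, with an optional
    # static size; an illustrative stand-in for the real list branch.
    m = re.match(r"^(.+)\[(\d+)?\]$", t)
    if m is not None:
        size: Optional[str] = m.group(2)
        return f"List({parse_type(m.group(1))}, size={size})"
    return f"Base({t})"

assert parse_type("Tensor?") == "Optional(Base(Tensor))"
assert parse_type("int[3]") == "List(Base(int), size=3)"
assert parse_type("Tensor?[]") == "List(Optional(Base(Tensor)), size=None)"
```
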
- annotation: Annotation | None + annotation: Optional[Annotation] @property - def alias_info(self) -> Annotation | None: + def alias_info(self) -> Optional[Annotation]: return self.annotation @staticmethod - def parse(arg: str) -> Argument: + def parse(arg: str) -> "Argument": name: str - default: str | None + default: Optional[str] assert " " in arg, f"illegal argument '{arg}'" type_and_annot, name_and_default = arg.rsplit(" ", 1) if "=" in name_and_default: @@ -2024,7 +2026,7 @@ class Argument: default = None # TODO: deduplicate annotation matching with Return match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) - annotation: Annotation | None + annotation: Optional[Annotation] if match: # If you update this, make sure the __str__ still works too assert match.group(2) in [ @@ -2067,24 +2069,24 @@ class Argument: @dataclass(frozen=True) class Return: - name: str | None + name: Optional[str] type: Type - annotation: Annotation | None + annotation: Optional[Annotation] @property - def alias_info(self) -> Annotation | None: + def alias_info(self) -> Optional[Annotation]: return self.annotation @staticmethod - def parse(arg: str) -> Return: - name: str | None + def parse(arg: str) -> "Return": + name: Optional[str] if " " in arg: type_and_annot, name = arg.rsplit(" ", 1) else: type_and_annot = arg name = None match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) - annotation: Annotation | None + annotation: Optional[Annotation] if match: # If you update this, make sure the __str__ still works too assert match.group(2) in [ @@ -2146,34 +2148,34 @@ class Arguments: # pre_self_positional is usually empty, but is notably non-empty # for where.self, where the condition argument comes before the # self argument - pre_self_positional: tuple[Argument, ...] - self_arg: SelfArgument | None - post_self_positional: tuple[Argument, ...] + pre_self_positional: Tuple[Argument, ...] + self_arg: Optional[SelfArgument] + post_self_positional: Tuple[Argument, ...] - pre_tensor_options_kwarg_only: tuple[Argument, ...] - tensor_options: TensorOptionsArguments | None + pre_tensor_options_kwarg_only: Tuple[Argument, ...] + tensor_options: Optional[TensorOptionsArguments] # post_tensor_options is typically memory format, which should be # part of tensor options but isn't right now, and is usually # placed after the tensor options arguments - post_tensor_options_kwarg_only: tuple[Argument, ...] + post_tensor_options_kwarg_only: Tuple[Argument, ...] # Unlike in the previous codegen, we have factored out 'out' arguments # in the canonical representation, removing them from kwarg # arguments. This choice is justified by numerous downstream # transformations which treat out arguments specially; additionally, # you can see that canonicity is not violated! - out: tuple[Argument, ...] # these are also kwarg-only + out: Tuple[Argument, ...] 
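`Argument.parse` above works from the right: `rsplit(" ", 1)` peels the name (plus any `=default`) off the type-and-annotation prefix, and the `Tensor\((.+)\)(.*)` match pulls out an alias annotation such as `(a!)`. The same splitting, condensed into a function that returns plain strings instead of an `Argument`:

```python
import re
from typing import Optional, Tuple

def split_argument(arg: str) -> Tuple[str, Optional[str], str, Optional[str]]:
    """Split 'Tensor(a!) self' or 'int dim=0' into
    (type, annotation, name, default)."""
    type_and_annot, name_and_default = arg.rsplit(" ", 1)
    if "=" in name_and_default:
        name, default = name_and_default.split("=")
    else:
        name, default = name_and_default, None
    match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
    if match:
        annotation: Optional[str] = match.group(1)  # e.g. 'a!'
        type_str = "Tensor" + match.group(2)
    else:
        annotation = None
        type_str = type_and_annot
    return type_str, annotation, name, default

assert split_argument("Tensor(a!) self") == ("Tensor", "a!", "self", None)
assert split_argument("int dim=0") == ("int", None, "dim", "0")
```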
# these are also kwarg-only @property def flat_non_out(self) -> Sequence[Argument]: - ret: list[Argument] = [] + ret: List[Argument] = [] ret.extend(self.flat_positional) ret.extend(self.flat_kwarg_only) return ret @property def flat_positional(self) -> Sequence[Argument]: - ret: list[Argument] = [] + ret: List[Argument] = [] ret.extend(self.pre_self_positional) if self.self_arg is not None: ret.append(self.self_arg.argument) @@ -2187,7 +2189,7 @@ class Arguments: # NB: doesn't contain out arguments @property def flat_kwarg_only(self) -> Sequence[Argument]: - ret: list[Argument] = [] + ret: List[Argument] = [] ret.extend(self.pre_tensor_options_kwarg_only) if self.tensor_options is not None: ret.extend(self.tensor_options.all()) @@ -2196,7 +2198,7 @@ class Arguments: @property def flat_all(self) -> Sequence[Argument]: - ret: list[Argument] = [] + ret: List[Argument] = [] ret.extend(self.flat_positional) ret.extend(self.flat_kwarg_only) ret.extend(self.out) @@ -2205,15 +2207,15 @@ class Arguments: @property def non_out( self, - ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: - ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] + ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: + ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] ret.extend(self.positional) ret.extend(self.kwarg_only) return ret @property - def positional(self) -> Sequence[Argument | SelfArgument]: - ret: list[Argument | SelfArgument] = [] + def positional(self) -> Sequence[Union[Argument, SelfArgument]]: + ret: List[Union[Argument, SelfArgument]] = [] ret.extend(self.pre_self_positional) if self.self_arg is not None: ret.append(self.self_arg) @@ -2221,8 +2223,8 @@ class Arguments: return ret @property - def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]: - ret: list[Argument | TensorOptionsArguments] = [] + def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: + ret: List[Union[Argument, TensorOptionsArguments]] = [] ret.extend(self.pre_tensor_options_kwarg_only) if self.tensor_options is not None: ret.append(self.tensor_options) @@ -2230,14 +2232,14 @@ class Arguments: return ret @property - def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: - ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] + def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: + ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] ret.extend(self.positional) ret.extend(self.kwarg_only) ret.extend(self.out) return ret - def mutable_arg_names(self) -> list[str]: + def mutable_arg_names(self) -> List[str]: return [ a.name for a in self.flat_all @@ -2253,7 +2255,7 @@ class Arguments: def has_generator_arg(self) -> bool: return any(a.type.is_generator_like() for a in self.flat_non_out) - def signature(self, *, strip_default: bool = False) -> Arguments: + def signature(self, *, strip_default: bool = False) -> "Arguments": # dataclasses.replace could be used here, but it is less # type safe so for now I've opted to type everything out def strip_arg_annotation(a: Argument) -> Argument: @@ -2288,7 +2290,7 @@ class Arguments: out=(), ) - def remove_self_annotation(self) -> Arguments: + def remove_self_annotation(self) -> "Arguments": assert self.self_arg is not None return dataclasses.replace( self, @@ -2297,7 +2299,7 @@ class Arguments: ), ) - def with_out_args(self, outs: list[Argument]) -> Arguments: + def with_out_args(self, outs: List[Argument]) -> "Arguments": 
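Each `flat_*` property above builds an ordered `List[Argument]` by concatenating structured sub-groups, so callers can walk arguments in declaration order without knowing how they were bucketed. A toy illustration of that layering, with strings standing in for `Argument` values and the kwarg-only/tensor-options split collapsed into one group:

```python
from typing import List, Optional, Tuple

class Args:
    def __init__(
        self,
        pre_self: Tuple[str, ...],
        self_arg: Optional[str],
        post_self: Tuple[str, ...],
        kwarg_only: Tuple[str, ...],
        out: Tuple[str, ...],
    ) -> None:
        self.pre_self = pre_self
        self.self_arg = self_arg
        self.post_self = post_self
        self.kwarg_only = kwarg_only
        self.out = out

    @property
    def flat_positional(self) -> List[str]:
        ret: List[str] = []
        ret.extend(self.pre_self)  # e.g. where.self's condition
        if self.self_arg is not None:
            ret.append(self.self_arg)
        ret.extend(self.post_self)
        return ret

    @property
    def flat_all(self) -> List[str]:
        # positional first, then kwarg-only, with out= arguments last
        return self.flat_positional + list(self.kwarg_only) + list(self.out)

a = Args(("condition",), "self", ("other",), ("alpha",), ("out",))
assert a.flat_all == ["condition", "self", "other", "alpha", "out"]
```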
assert len(self.out) == 0 return dataclasses.replace( self, @@ -2305,10 +2307,10 @@ class Arguments: ) @staticmethod - def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]: - positional: list[Argument] = [] - kwarg_only: list[Argument] = [] - out: list[Argument] = [] + def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: + positional: List[Argument] = [] + kwarg_only: List[Argument] = [] + out: List[Argument] = [] arguments_acc = positional # TODO: Use a real parser here; this will get bamboozled @@ -2341,7 +2343,7 @@ class Arguments: return positional, kwarg_only, out @staticmethod - def parse(args: str) -> Arguments: + def parse(args: str) -> "Arguments": """ Input: 'int x, int y, int z' """ @@ -2359,9 +2361,9 @@ class Arguments: if a.name == "self": self_ix = i break - pre_self_positional: list[Argument] - self_arg: SelfArgument | None - post_self_positional: list[Argument] + pre_self_positional: List[Argument] + self_arg: Optional[SelfArgument] + post_self_positional: List[Argument] if self_ix is not None: pre_self_positional = positional[:self_ix] self_arg = SelfArgument(positional[self_ix]) @@ -2372,9 +2374,9 @@ class Arguments: post_self_positional = positional # Group tensor options arguments - pre_tensor_options_kwarg_only: list[Argument] = [] - tensor_options: TensorOptionsArguments | None = None - post_tensor_options_kwarg_only: list[Argument] = [] + pre_tensor_options_kwarg_only: List[Argument] = [] + tensor_options: Optional[TensorOptionsArguments] = None + post_tensor_options_kwarg_only: List[Argument] = [] kwarg_only_acc = pre_tensor_options_kwarg_only def pred(name: str, ty: Type) -> Callable[[Argument], bool]: @@ -2421,7 +2423,7 @@ class Arguments: ) def __str__(self) -> str: - all_arguments: list[str] = [] + all_arguments: List[str] = [] all_arguments.extend(map(str, self.flat_positional)) if self.flat_kwarg_only or self.out: all_arguments.append("*") @@ -2500,7 +2502,7 @@ class BaseOperatorName: functional_overload: bool = False @staticmethod - def parse(op: str) -> BaseOperatorName: + def parse(op: str) -> "BaseOperatorName": assert op != "" assert not op.endswith("_out"), ( "_out suffix is reserved and not permitted for operator names; " @@ -2572,7 +2574,7 @@ class OperatorName: overload_name: str @staticmethod - def parse(op_name: str) -> OperatorName: + def parse(op_name: str) -> "OperatorName": if "." in op_name: name, overload_name = op_name.split(".", 1) else: @@ -2599,7 +2601,7 @@ class OperatorName: else: return f"{self.name}" - def remove_inplace(self) -> OperatorName: + def remove_inplace(self) -> "OperatorName": return OperatorName( name=BaseOperatorName( base=self.name.base, @@ -2609,7 +2611,7 @@ class OperatorName: overload_name=self.overload_name, ) - def with_overload(self, overload: str) -> OperatorName: + def with_overload(self, overload: str) -> "OperatorName": return OperatorName( name=BaseOperatorName( base=self.name.base, @@ -2647,9 +2649,9 @@ class NativeFunctionsViewGroup: # Note: the {view}_copy operator is optional because we currently don't generate copy variants # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views # (we already get them "for free" through decomposition) - view_copy: NativeFunction | None + view_copy: Optional[NativeFunction] # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. 
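Back in `Arguments._preparse` above, the comma-separated argument string is bucketed into positional, kwarg-only, and out arguments by switching which accumulator list gets appended to when the bare `*` marker appears. A loose sketch of that accumulator trick; flagging out arguments by a literal `!` is a deliberate simplification here (torchgen checks the parsed annotation's `is_write` instead):

```python
from typing import List, Tuple

def preparse(args: str) -> Tuple[List[str], List[str], List[str]]:
    positional: List[str] = []
    kwarg_only: List[str] = []
    out: List[str] = []
    acc = positional  # start in the positional bucket
    for arg in args.split(", "):
        if arg == "*":
            acc = kwarg_only  # everything after '*' is kwarg-only
            continue
        # Simplified out-detection: a '!' write annotation on a
        # kwarg-only argument marks it as an out= argument.
        if acc is kwarg_only and "!" in arg:
            out.append(arg)
        else:
            acc.append(arg)
    return positional, kwarg_only, out

pos, kw, out = preparse("Tensor self, *, Scalar alpha, Tensor(a!) out")
assert pos == ["Tensor self"]
assert kw == ["Scalar alpha"]
assert out == ["Tensor(a!) out"]
```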
- view_inplace: NativeFunction | None + view_inplace: Optional[NativeFunction] def __post_init__(self) -> None: assert self.view.is_view_op @@ -2729,7 +2731,7 @@ def gets_generated_view_copy(f: NativeFunction) -> bool: # Given a NativeFunction that corresponds to a view op, # returns the OperatorName of the corresponding "copy" variant of the op. -def get_view_copy_name(f: NativeFunction) -> OperatorName: +def get_view_copy_name(f: NativeFunction) -> "OperatorName": # Right now, when asking for a view op's corresponding "view_copy" name # we assert for sanity that the op is allowed to have a generated view_copy variant. # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). @@ -2753,7 +2755,7 @@ def get_view_copy_name(f: NativeFunction) -> OperatorName: # Helper functions for parsing argument lists (both inputs and returns) -def parse_returns(return_decl: str) -> tuple[Return, ...]: +def parse_returns(return_decl: str) -> Tuple[Return, ...]: """ Input: '()' Output: [] @@ -2772,12 +2774,12 @@ def parse_returns(return_decl: str) -> tuple[Return, ...]: class Precompute: # A map from kernel argument name -> a list of precomputed # elements that replaces/supersedes it. - replace: dict[str, list[Argument]] + replace: Dict[str, List[Argument]] # List of precomputed args added without replacement - add: list[Argument] + add: List[Argument] @staticmethod - def parse(src: object) -> Precompute: + def parse(src: object) -> "Precompute": assert isinstance(src, list) # src is a list of strings of the format: @@ -2822,7 +2824,7 @@ class Precompute: for a in args: assert a.name.upper() != a.name - def to_list(self) -> list[str]: + def to_list(self) -> List[str]: replace_list = [] for kernel_param, replacement_params in self.replace.items(): replacements = ", ".join(str(param) for param in replacement_params) diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index a44efab6842..3705944309d 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from collections import defaultdict -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Union import torchgen.api.dispatcher as dispatcher from torchgen.api.translate import translate @@ -103,9 +101,9 @@ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ # But have differing SchemaKinds. def pre_group_native_functions( native_functions: Sequence[NativeFunction], -) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]: - pre_grouped_native_functions: dict[ - FunctionSchema, dict[SchemaKind, NativeFunction] +) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]: + pre_grouped_native_functions: Dict[ + FunctionSchema, Dict[SchemaKind, NativeFunction] ] = defaultdict(dict) for f in native_functions: d = pre_grouped_native_functions[f.func.signature()] @@ -115,7 +113,7 @@ def pre_group_native_functions( # Returns the out variant overload name given a base function overload name -def get_expected_out_variant_overload_name(overload_name: str | None) -> str: +def get_expected_out_variant_overload_name(overload_name: Optional[str]) -> str: return "out" if not overload_name else f"{overload_name}_out" @@ -180,7 +178,7 @@ def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema: # Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations. 
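A few hunks up, `parse_returns` normalizes the return declaration: `'()'` means no returns, a parenthesized list splits on commas, and a bare type is a single return. A simplified version with the same shape, producing plain strings rather than `Return` objects:

```python
from typing import Tuple

def parse_returns(return_decl: str) -> Tuple[str, ...]:
    # '()' -> no returns at all
    if return_decl == "()":
        return ()
    # strip the parens off a tuple of returns
    if return_decl.startswith("(") and return_decl.endswith(")"):
        return_decl = return_decl[1:-1]
    return tuple(ret.strip() for ret in return_decl.split(","))

assert parse_returns("()") == ()
assert parse_returns("(Tensor, Tensor)") == ("Tensor", "Tensor")
assert parse_returns("Tensor") == ("Tensor",)
```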
def generate_out_args_from_schema( func: FunctionSchema, -) -> tuple[list[Return], list[Argument]]: +) -> Tuple[List[Return], List[Argument]]: # More of a sanity check - our existing restrictions on schemas should enforce that # mutable schema kinds never return their mutable arguments. assert not any( @@ -200,11 +198,11 @@ def generate_out_args_from_schema( all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) - new_out_args: list[Argument] = [] + new_out_args: List[Argument] = [] # The end result of new_returns is that: # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added. # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any). - new_returns: list[Return] = [] + new_returns: List[Return] = [] for i, r in enumerate(func.returns): if r.type.is_tensor_like(): new_out = Argument( @@ -268,7 +266,7 @@ def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema: # Details are in the function, but we only generate composite kernels (in some cases) today. def generate_function( f: NativeFunction, k: SchemaKind -) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]: +) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]: from torchgen.api import cpp if k == SchemaKind.functional: @@ -377,8 +375,8 @@ def generate_function( # Note: this function *mutates* its two inputs, # adding the new NativeFunctions / BackendMetadata to them def add_generated_native_functions( - rs: list[NativeFunction], - indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], + rs: List[NativeFunction], + indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ) -> None: # The main code for generating new NativeFunctions # First we group of NativeFunctions by schema kind, @@ -499,7 +497,7 @@ out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THA rs.append(fn) -def return_str(rets: tuple[Return, ...], names: list[str]) -> str: +def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: assert len(rets) == len(names) if len(rets) == 0: return "" @@ -511,7 +509,7 @@ def return_str(rets: tuple[Return, ...], names: list[str]) -> str: # Given a function, and the name of a variable corresponding to the output of that function, # gather up all of the individual returns that are not aliased -def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]: +def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]: aliased_rets = func.aliased_return_names() non_aliased_names = [] is_out_var_a_tuple = len(func.returns) > 1 @@ -526,7 +524,7 @@ def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str # Generates functional kernels in terms of their inplace.mutable counterparts. 
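Earlier in this hunk, `generate_out_args_from_schema` manufactures `out=` arguments from a functional schema's returns: return *i* gets the alias letter `chr(ord('a') + i)`, and when every return is a plain tensor the returns are re-annotated to alias the new out arguments. A string-level sketch of that transformation; the `out`/`out{i}` naming follows the apparent convention but should be treated as illustrative:

```python
from typing import List, Tuple

def generate_out_args(returns: List[str]) -> Tuple[List[str], List[str]]:
    new_out_args: List[str] = []
    new_returns: List[str] = []
    all_tensors = all(r == "Tensor" for r in returns)
    for i, r in enumerate(returns):
        if r == "Tensor":
            letter = chr(ord("a") + i)  # alias letter for return i
            name = "out" if len(returns) == 1 else f"out{i}"
            new_out_args.append(f"Tensor({letter}!) {name}")
            if all_tensors:
                # plain-tensor returns alias their new out= arguments
                new_returns.append(f"Tensor({letter}!)")
        else:
            # non-tensor returns survive unchanged (tensor returns are
            # dropped from the return list in the mixed case)
            new_returns.append(r)
    return new_returns, new_out_args

rets, outs = generate_out_args(["Tensor", "Tensor"])
assert outs == ["Tensor(a!) out0", "Tensor(b!) out1"]
assert rets == ["Tensor(a!)", "Tensor(b!)"]
```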
# We only do this for "generated" NativeFunctions @with_native_function -def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None: +def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]: # We should only be generating these for code-generated NativeFunctions if "generated" not in g.functional.tags: return None @@ -543,7 +541,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None: sig = DispatcherSignature(g.functional.func) target_sig = DispatcherSignature(target_f.func) - context: list[Binding | Expr] = [] + context: List[Union[Binding, Expr]] = [] clone_mutable_inputs = [] cloned_return_names = [] # We can't just directly pass all of the arguments from the functional op into the mutating op. @@ -589,7 +587,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None: # Generates out= kernels in terms of their functional counterparts. # We only do this for "generated" NativeFunctions @with_native_function -def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None: +def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]: # We should only be generating these for code-generated NativeFunctions if "generated" not in g.out.tags: return None diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 362ce427d50..18b2952c9ea 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -1,12 +1,9 @@ #!/usr/bin/env python3 - -from __future__ import annotations - import os from enum import Enum from operator import itemgetter from pathlib import Path -from typing import Any +from typing import Any, Dict, List import torch from torch.jit.generate_bytecode import generate_upgraders_bytecode @@ -188,7 +185,7 @@ PER_OPERATOR_UPGRADER_LIST = CodeTemplate( ) -def construct_instruction(instruction_list_from_yaml: list[Any]) -> str: +def construct_instruction(instruction_list_from_yaml: List[Any]) -> str: instruction_list_part = [] for instruction in instruction_list_from_yaml: instruction_list_part.append( @@ -203,7 +200,7 @@ def construct_instruction(instruction_list_from_yaml: list[Any]) -> str: ) -def construct_constants(constants_list_from_yaml: list[Any]) -> str: +def construct_constants(constants_list_from_yaml: List[Any]) -> str: constants_list_part = [] for constant_from_yaml in constants_list_from_yaml: convert_constant = None @@ -229,7 +226,7 @@ def construct_constants(constants_list_from_yaml: list[Any]) -> str: ) -def construct_operators(operator_list_from_yaml: list[Any]) -> str: +def construct_operators(operator_list_from_yaml: List[Any]) -> str: operator_list_part = [] for operator in operator_list_from_yaml: operator_list_part.append( @@ -244,7 +241,7 @@ def construct_operators(operator_list_from_yaml: list[Any]) -> str: ) -def construct_types(types_tr_list_from_yaml: list[Any]) -> str: +def construct_types(types_tr_list_from_yaml: List[Any]) -> str: types_tr_list_part = [] for types_tr in types_tr_list_from_yaml: types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr)) @@ -263,7 +260,7 @@ def construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( - upgrader_bytecode_function_to_index_map: dict[str, Any] + upgrader_bytecode_function_to_index_map: Dict[str, Any] ) -> str: version_map = torch._C._get_operator_version_map() sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: 
ignore[no-any-return] @@ -305,8 +302,8 @@ def construct_version_maps( def get_upgrader_bytecode_function_to_index_map( - upgrader_dict: list[dict[str, Any]] -) -> dict[str, Any]: + upgrader_dict: List[Dict[str, Any]] +) -> Dict[str, Any]: upgrader_bytecode_function_to_index_map = {} index = 0 for upgrader_bytecode in upgrader_dict: @@ -318,7 +315,7 @@ def get_upgrader_bytecode_function_to_index_map( return upgrader_bytecode_function_to_index_map -def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None: +def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None: body_parts = [] upgrader_bytecode_function_to_index_map = ( get_upgrader_bytecode_function_to_index_map(upgrader_dict) @@ -373,7 +370,7 @@ def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None: out_file.write(upgrader_file_content.encode("utf-8")) -def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: +def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: sorted_upgrader_list = sorted( upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader)) ) diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index 0cb92dfc09e..939d97ff94c 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from dataclasses import dataclass +from typing import Dict, Optional, Tuple # This class holds information about a single operator used to determine @@ -47,12 +46,12 @@ class SelectiveBuildOperator: include_all_overloads: bool # Debug Information at the operator level - _debug_info: tuple[str, ...] | None + _debug_info: Optional[Tuple[str, ...]] @staticmethod def from_yaml_dict( - op_name: str, op_info: dict[str, object] - ) -> SelectiveBuildOperator: + op_name: str, op_info: Dict[str, object] + ) -> "SelectiveBuildOperator": allowed_keys = { "name", "is_root_operator", @@ -80,7 +79,7 @@ class SelectiveBuildOperator: include_all_overloads = op_info.get("include_all_overloads", True) assert isinstance(include_all_overloads, bool) - debug_info: tuple[str, ...] | None = None + debug_info: Optional[Tuple[str, ...]] = None if "debug_info" in op_info: di_list = op_info["debug_info"] assert isinstance(di_list, list) @@ -97,7 +96,7 @@ class SelectiveBuildOperator: @staticmethod def from_legacy_operator_name_without_overload( name: str, - ) -> SelectiveBuildOperator: + ) -> "SelectiveBuildOperator": return SelectiveBuildOperator( name=name, is_root_operator=True, @@ -106,8 +105,8 @@ class SelectiveBuildOperator: _debug_info=None, ) - def to_dict(self) -> dict[str, object]: - ret: dict[str, object] = { + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { "is_root_operator": self.is_root_operator, "is_used_for_training": self.is_used_for_training, "include_all_overloads": self.include_all_overloads, @@ -119,9 +118,9 @@ class SelectiveBuildOperator: def merge_debug_info( - lhs: tuple[str, ...] | None, - rhs: tuple[str, ...] | None, -) -> tuple[str, ...] | None: + lhs: Optional[Tuple[str, ...]], + rhs: Optional[Tuple[str, ...]], +) -> Optional[Tuple[str, ...]]: # Ensure that when merging, each entry shows up just once. 
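The comment just above states `merge_debug_info`'s contract, and its body continues below with the `None`/`None` short-circuit. For reference, a standalone sketch of the likely complete helper, assuming the merge is a plain set union:

```python
from typing import Optional, Tuple

def merge_debug_info(
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    if lhs is None and rhs is None:
        return None
    # Set union so each entry shows up just once in the merged tuple.
    return tuple(set((lhs or ()) + (rhs or ())))

merged = merge_debug_info(("model_a",), ("model_a", "model_b"))
assert merged is not None and sorted(merged) == ["model_a", "model_b"]
```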
if lhs is None and rhs is None: return None @@ -130,8 +129,8 @@ def merge_debug_info( def combine_operators( - lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator -) -> SelectiveBuildOperator: + lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator" +) -> "SelectiveBuildOperator": if str(lhs.name) != str(rhs.name): raise Exception( # noqa: TRY002 f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" @@ -153,10 +152,10 @@ def combine_operators( def merge_operator_dicts( - lhs: dict[str, SelectiveBuildOperator], - rhs: dict[str, SelectiveBuildOperator], -) -> dict[str, SelectiveBuildOperator]: - operators: dict[str, SelectiveBuildOperator] = {} + lhs: Dict[str, SelectiveBuildOperator], + rhs: Dict[str, SelectiveBuildOperator], +) -> Dict[str, SelectiveBuildOperator]: + operators: Dict[str, SelectiveBuildOperator] = {} for op_name, op in list(lhs.items()) + list(rhs.items()): new_op = op if op_name in operators: diff --git a/torchgen/selective_build/selector.py b/torchgen/selective_build/selector.py index 04acc354203..aa60349966a 100644 --- a/torchgen/selective_build/selector.py +++ b/torchgen/selective_build/selector.py @@ -1,12 +1,11 @@ -from __future__ import annotations - from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Dict, List, Optional, Set, Tuple import yaml +from torchgen.model import NativeFunction from torchgen.selective_build.operator import ( merge_debug_info, merge_operator_dicts, @@ -15,10 +14,6 @@ from torchgen.selective_build.operator import ( ) -if TYPE_CHECKING: - from torchgen.model import NativeFunction - - # A SelectiveBuilder holds information extracted from the selective build # YAML specification. # @@ -33,10 +28,10 @@ class SelectiveBuilder: include_all_operators: bool # Debug Information at the selective/custom build level. - _debug_info: tuple[str, ...] | None + _debug_info: Optional[Tuple[str, ...]] # A dictionary of operator -> operator metadata. - operators: dict[str, SelectiveBuildOperator] + operators: Dict[str, SelectiveBuildOperator] # A dictionary of selected kernel tags and dtypes. Typically a # PyTorch Operator Kernel (function) may have many code paths @@ -44,22 +39,22 @@ class SelectiveBuilder: # one per kernel function, but there could be many per kernel # function. The tag isn't a kernel function name, but some fragment # of the kernel function implementation itself. - kernel_metadata: dict[str, list[str]] + kernel_metadata: Dict[str, List[str]] # ExecuTorch only. A dictionary of kernel tag -> list of (list of input # dtypes for tensor-like input args). # This is from selective.yaml - et_kernel_metadata: dict[str, list[str]] + et_kernel_metadata: Dict[str, List[str]] # A set of all the custom torch bind classes used by the selected models # Stored as a set internally to remove duplicates proactively, but written # as a list to yamls - custom_classes: set[str] + custom_classes: Set[str] # A set of all the build features used by the selected models # Stored as a set internally to remove duplicates proactively, but written # as a list to yamls - build_features: set[str] + build_features: Set[str] # If true, then fragments for all dtypes for all kernel functions # are included as well as all custom classes. 
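`combine_operators` above merges two records for the same operator; the sensible semantics for its boolean fields is a logical OR, so a merged build never drops a capability either side needed, and the same union flavor recurs in `merge_operator_dicts` here and in the kernel-metadata merges further down. A reduced sketch over plain flag dicts (the real function also merges `_debug_info`):

```python
from typing import Dict

def merge_flags(lhs: Dict[str, bool], rhs: Dict[str, bool]) -> Dict[str, bool]:
    # A capability is on in the merged operator if either side set it.
    assert lhs.keys() == rhs.keys()
    return {k: lhs[k] or rhs[k] for k in lhs}

a = {"is_root_operator": True, "is_used_for_training": False}
b = {"is_root_operator": False, "is_used_for_training": True}
assert merge_flags(a, b) == {
    "is_root_operator": True,
    "is_used_for_training": True,
}
```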
This is typically set when any one of the @@ -68,11 +63,11 @@ class SelectiveBuilder: include_all_non_op_selectives: bool @staticmethod - def get_nop_selector() -> SelectiveBuilder: + def get_nop_selector() -> "SelectiveBuilder": return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) @staticmethod - def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder: + def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder": valid_top_level_keys = { "include_all_non_op_selectives", "include_all_operators", @@ -140,20 +135,20 @@ class SelectiveBuilder: ) @staticmethod - def from_yaml_str(config_contents: str) -> SelectiveBuilder: + def from_yaml_str(config_contents: str) -> "SelectiveBuilder": contents = yaml.safe_load(config_contents) return SelectiveBuilder.from_yaml_dict(contents) @staticmethod - def from_yaml_path(config_path: str) -> SelectiveBuilder: + def from_yaml_path(config_path: str) -> "SelectiveBuilder": with open(config_path) as f: contents = yaml.safe_load(f) return SelectiveBuilder.from_yaml_dict(contents) @staticmethod def from_legacy_op_registration_allow_list( - allow_list: set[str], is_root_operator: bool, is_used_for_training: bool - ) -> SelectiveBuilder: + allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool + ) -> "SelectiveBuilder": operators = {} for op in allow_list: operators[op] = { @@ -236,7 +231,7 @@ class SelectiveBuilder: and dtype in self.kernel_metadata[kernel_tag] ) - def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]: + def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]: """ Return a list of kernel keys that cover the used ops """ @@ -266,8 +261,8 @@ class SelectiveBuilder: return list(result_set) - def to_dict(self) -> dict[str, object]: - ret: dict[str, object] = { + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { "include_all_non_op_selectives": self.include_all_non_op_selectives, "include_all_operators": self.include_all_operators, } @@ -293,10 +288,10 @@ class SelectiveBuilder: def merge_kernel_metadata( - lhs: dict[str, list[str]], - rhs: dict[str, list[str]], -) -> dict[str, list[str]]: - kernel_metadata: dict[str, list[str]] = {} + lhs: Dict[str, List[str]], + rhs: Dict[str, List[str]], +) -> Dict[str, List[str]]: + kernel_metadata: Dict[str, List[str]] = {} for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): dtypes_copy = set(dtypes) if tag_name in kernel_metadata: @@ -308,10 +303,10 @@ def merge_kernel_metadata( def merge_et_kernel_metadata( - lhs: dict[str, list[str]], - rhs: dict[str, list[str]], -) -> dict[str, list[str]]: - merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set) + lhs: Dict[str, List[str]], + rhs: Dict[str, List[str]], +) -> Dict[str, List[str]]: + merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set) for op in list(lhs.keys()) + list(rhs.keys()): merge_et_kernel_metadata[op].update(lhs.get(op, [])) merge_et_kernel_metadata[op].update(rhs.get(op, [])) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 56a3d8bf0dd..bdfd5c75b28 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import importlib.util import os import sys -from importlib.util import module_from_spec, spec_from_file_location from itertools import chain from pathlib import Path @@ -18,9 +18,9 @@ you are in the root 
directory of the Pytorch git repo""" if not file_path.exists(): raise Exception(err_msg) # noqa: TRY002 -spec = spec_from_file_location(module_name, file_path) +spec = importlib.util.spec_from_file_location(module_name, file_path) assert spec is not None -module = module_from_spec(spec) +module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module assert spec.loader is not None assert module is not None diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index 1e7b541fa2c..da6e2a21c2a 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -1,9 +1,9 @@ -from __future__ import annotations +from typing import Dict, Union from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup -def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str: +def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str: if isinstance(g, NativeFunctionsGroup): return str(g.functional.func.name.name.base) else: @@ -55,12 +55,12 @@ is_hand_written_ops_ = frozenset( ) -def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: +def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: name_base = func_name_base_str(g) return name_base in is_hand_written_ops_ -def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None: +def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None: assert index == 0 or index == 1 if op_name == "addr": if index == 0: diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index 9f735717374..93a4436fd22 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import argparse import itertools import os @@ -30,7 +28,7 @@ def group_functions_by_op_name( return [] groups = [] - def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: with native_function_manager(g): return generator.is_supported(g) diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index 7bbb7f64d86..7960679660b 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import json import logging import math -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Union import torchgen.api.cpp as cpp from torchgen.context import native_function_manager @@ -27,7 +25,7 @@ logger: logging.Logger = logging.getLogger() def has_alias( - arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments], + arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]] ) -> bool: for arg in arguments: annotation = getattr(arg, "annotation", None) @@ -239,7 +237,7 @@ BLOCKED_OPS = frozenset( ) -def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: +def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: base_op_name = "" func = None if isinstance(g, NativeFunctionsViewGroup): @@ -300,8 +298,8 @@ def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: def ivalue_type_conversion_method( - arg_type: BaseType | OptionalType | Type, -) -> tuple[bool, str] | None: + arg_type: 
Union[BaseType, OptionalType, Type] +) -> Optional[Tuple[bool, str]]: """ Return the method call expression of `c10::ivalue' to convert its contained value to the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, @@ -396,7 +394,7 @@ def test_tensor_dim(op_name: str) -> int: test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' -test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string) +test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string) def test_tensor_shape(op_name: str) -> str: @@ -407,7 +405,7 @@ def test_tensor_shape(op_name: str) -> str: def test_value_expression( - arg_type: BaseType | OptionalType | Type, index: int, op_name: str + arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str ) -> str: tensor_size_ex = test_tensor_shape(op_name) if tensor_size_ex == "": @@ -477,8 +475,8 @@ generate_test_ir_arguments_base_ty_to_type_str_ = { def generate_test_ir_arguments( schema: FunctionSchema, -) -> list[tuple[str, str | None]]: - def ir_argument(arg: Argument) -> tuple[str, str | None]: +) -> List[Tuple[str, Optional[str]]]: + def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]: t = arg.type add_optional = False if isinstance(t, OptionalType): diff --git a/torchgen/utils.py b/torchgen/utils.py index abb79900a83..16fd0022f8d 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import contextlib import functools import hashlib @@ -7,29 +5,31 @@ import os import re import sys import textwrap +from argparse import Namespace from dataclasses import fields, is_dataclass from enum import auto, Enum from typing import ( Any, Callable, + Dict, Generic, Iterable, Iterator, + List, Literal, NoReturn, + Optional, Sequence, - TYPE_CHECKING, + Set, + Tuple, TypeVar, + Union, ) from typing_extensions import Self from torchgen.code_template import CodeTemplate -if TYPE_CHECKING: - from argparse import Namespace - - # Many of these functions share logic for defining both the definition # and declaration (for example, the function signature is the same), so # we organize them into one function that takes a Target to say which @@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)" # TODO: Use a real parser here; this will get bamboozled -def split_name_params(schema: str) -> tuple[str, list[str]]: +def split_name_params(schema: str) -> Tuple[str, List[str]]: m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema) if m is None: raise RuntimeError(f"Unsupported function schema: {schema}") @@ -73,7 +73,7 @@ S = TypeVar("S") # Map over function that may return None; omit Nones from output sequence -def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]: +def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]: for x in xs: r = func(x) if r is not None: @@ -127,7 +127,7 @@ class FileManager: install_dir: str template_dir: str dry_run: bool - filenames: set[str] + filenames: Set[str] def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: self.install_dir = install_dir @@ -136,7 +136,7 @@ class FileManager: self.dry_run = dry_run def _write_if_changed(self, filename: str, contents: str) -> None: - old_contents: str | None + old_contents: Optional[str] try: with open(filename) as f: old_contents = f.read() @@ -150,7 +150,7 @@ class FileManager: # Read from template file and replace pattern with callable (type could be dict or str). 
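`mapMaybe` above is the usual map-then-drop-`None`s combinator. A standalone copy with a quick usage check:

```python
from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")
S = TypeVar("S")

def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r

# Keep only the even squares; odd inputs map to None and are dropped.
assert list(mapMaybe(lambda n: n * n if n % 2 == 0 else None, range(5))) == [0, 4, 16]
```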
def substitute_with_template( - self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]] + self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]] ) -> str: template_path = os.path.join(self.template_dir, template_fn) env = env_callable() @@ -171,7 +171,7 @@ class FileManager: self, filename: str, template_fn: str, - env_callable: Callable[[], str | dict[str, Any]], + env_callable: Callable[[], Union[str, Dict[str, Any]]], ) -> None: filename = f"{self.install_dir}/{filename}" assert filename not in self.filenames, "duplicate file write {filename}" @@ -186,7 +186,7 @@ class FileManager: def write( self, filename: str, - env_callable: Callable[[], str | dict[str, Any]], + env_callable: Callable[[], Union[str, Dict[str, Any]]], ) -> None: self.write_with_template(filename, filename, env_callable) @@ -196,13 +196,13 @@ class FileManager: items: Iterable[T], *, key_fn: Callable[[T], str], - env_callable: Callable[[T], dict[str, list[str]]], + env_callable: Callable[[T], Dict[str, List[str]]], num_shards: int, - base_env: dict[str, Any] | None = None, - sharded_keys: set[str], + base_env: Optional[Dict[str, Any]] = None, + sharded_keys: Set[str], ) -> None: - everything: dict[str, Any] = {"shard_id": "Everything"} - shards: list[dict[str, Any]] = [ + everything: Dict[str, Any] = {"shard_id": "Everything"} + shards: List[Dict[str, Any]] = [ {"shard_id": f"_{i}"} for i in range(num_shards) ] all_shards = [everything] + shards @@ -221,7 +221,7 @@ class FileManager: else: shard[key] = [] - def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None: + def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None: for k, v in from_.items(): assert k in sharded_keys, f"undeclared sharded key {k}" into[k] += v @@ -275,7 +275,7 @@ class FileManager: # Helper function to generate file manager def make_file_manager( - options: Namespace, install_dir: str | None = None + options: Namespace, install_dir: Optional[str] = None ) -> FileManager: template_dir = os.path.join(options.source_path, "templates") install_dir = install_dir if install_dir else options.install_dir @@ -335,7 +335,7 @@ def _pformat( def _format_dict( - attr: dict[Any, Any], + attr: Dict[Any, Any], indent: int, width: int, curr_indent: int, @@ -355,7 +355,7 @@ def _format_dict( def _format_list( - attr: list[Any] | set[Any] | tuple[Any, ...], + attr: Union[List[Any], Set[Any], Tuple[Any, ...]], indent: int, width: int, curr_indent: int, @@ -370,7 +370,7 @@ def _format_list( def _format( - fields_str: list[str], + fields_str: List[str], indent: int, width: int, curr_indent: int, @@ -402,9 +402,7 @@ class NamespaceHelper: } // namespace torch """ - def __init__( - self, namespace_str: str, entity_name: str = "", max_level: int = 2 - ) -> None: + def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2): # cpp_namespace can be a colon joined string such as torch::lazy cpp_namespaces = namespace_str.split("::") assert ( @@ -421,7 +419,7 @@ class NamespaceHelper: @staticmethod def from_namespaced_entity( namespaced_entity: str, max_level: int = 2 - ) -> NamespaceHelper: + ) -> "NamespaceHelper": """ Generate helper from nested namespaces as long as class/function name. 
E.g.: "torch::lazy::add" """ @@ -454,9 +452,9 @@ class NamespaceHelper: class OrderedSet(Generic[T]): - storage: dict[T, Literal[None]] + storage: Dict[T, Literal[None]] - def __init__(self, iterable: Iterable[T] | None = None) -> None: + def __init__(self, iterable: Optional[Iterable[T]] = None): if iterable is None: self.storage = {} else: @@ -468,28 +466,28 @@ class OrderedSet(Generic[T]): def __iter__(self) -> Iterator[T]: return iter(self.storage.keys()) - def update(self, items: OrderedSet[T]) -> None: + def update(self, items: "OrderedSet[T]") -> None: self.storage.update(items.storage) def add(self, item: T) -> None: self.storage[item] = None - def copy(self) -> OrderedSet[T]: + def copy(self) -> "OrderedSet[T]": ret: OrderedSet[T] = OrderedSet() ret.storage = self.storage.copy() return ret @staticmethod - def union(*args: OrderedSet[T]) -> OrderedSet[T]: + def union(*args: "OrderedSet[T]") -> "OrderedSet[T]": ret = args[0].copy() for s in args[1:]: ret.update(s) return ret - def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: + def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]": return OrderedSet.union(self, other) - def __ior__(self, other: OrderedSet[T]) -> Self: + def __ior__(self, other: "OrderedSet[T]") -> Self: self.update(other) return self
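Back in `torchgen/utils.py`, `NamespaceHelper` turns a `'::'`-joined namespace string such as `torch::lazy` into matching C++ opening and closing lines, as its docstring example shows. A minimal sketch of that prologue/epilogue construction:

```python
from typing import Tuple

def namespace_wrappers(namespace_str: str, max_level: int = 2) -> Tuple[str, str]:
    parts = namespace_str.split("::") if namespace_str else []
    assert len(parts) <= max_level, f"namespace too deeply nested: {namespace_str}"
    # One 'namespace X {' per level, closed in reverse order.
    prologue = "\n".join(f"namespace {p} {{" for p in parts)
    epilogue = "\n".join(f"}} // namespace {p}" for p in reversed(parts))
    return prologue, epilogue

pro, epi = namespace_wrappers("torch::lazy")
assert pro == "namespace torch {\nnamespace lazy {"
assert epi == "} // namespace lazy\n} // namespace torch"
```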
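Finally, `OrderedSet` rides on the fact that Python dicts preserve insertion order (guaranteed since 3.7), storing members as keys mapped to `None`. A quick usage check, assuming the `torchgen` package that ships with PyTorch is importable:

```python
from torchgen.utils import OrderedSet

s: "OrderedSet[str]" = OrderedSet(["b", "a", "b"])  # duplicate "b" collapses
s.add("c")
t: "OrderedSet[str]" = OrderedSet(["c", "d"])
# Union via __or__ keeps first-insertion order across both sets.
assert list(s | t) == ["b", "a", "c", "d"]
```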