from typing import List, Union, Tuple, Optional from torchgen.model import ( Type, BaseTy, BaseType, OptionalType, ListType, OperatorName, FunctionSchema, Return, TensorOptionsArguments, Argument, ) from torchgen.api.types import ( CType, BaseCppType, BaseCType, OptionalCType, NamedCType, deviceT, layoutT, VectorCType, boolT, longT, doubleT, ListCType, stringT, scalarT, scalarTypeT, memoryFormatT, SymIntT, ) _valueT = None def getValueT() -> BaseCppType: global _valueT if not _valueT: raise NotImplementedError( "The value type needs to be set with setValueT() in run_gen_lazy_tensor()" ) return _valueT def setValueT(val: BaseCppType) -> None: global _valueT _valueT = val # this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object, # making it easier to represent special properties of an arg. tensorListValueT = BaseCppType("torch::lazy", "Value") def process_ir_type( typ: Type, ) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]: """ This function takes a type from NativeFunctions and converts it for use with lazy tensor codegen. Type conversion for lazy currently consists of (1) changing at::Tensors into lazy::Values (2) wrapping everything in a BaseCType (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef) (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.) There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like' This is incomplete- there are assertions in places that it's expected to need to add more types as the codegen is used with more operators. """ if isinstance(typ, BaseType): if typ.name == BaseTy.Tensor: return BaseCType(getValueT()) elif typ.name == BaseTy.Scalar: # at::scalar has special handling, # and is wrapped in an lazy::Value just like at::tensor return BaseCType(getValueT()) elif typ.name == BaseTy.ScalarType: return BaseCType(scalarTypeT) elif typ.name == BaseTy.int: return BaseCType(longT) elif typ.name == BaseTy.SymInt: return BaseCType(getValueT()) elif typ.name == BaseTy.bool: return BaseCType(boolT) elif typ.name == BaseTy.float: return BaseCType(doubleT) elif typ.name == BaseTy.str: return BaseCType(stringT) elif typ.name == BaseTy.Device: return BaseCType(deviceT) elif typ.name == BaseTy.Layout: return BaseCType(layoutT) elif typ.name == BaseTy.MemoryFormat: return BaseCType(memoryFormatT) else: raise AssertionError(f"TODO add support for type {repr(typ)}") elif isinstance(typ, OptionalType): return OptionalCType(process_ir_type(typ.elem)) elif isinstance(typ, ListType): if str(typ.elem) == "Tensor?": # TODO(whc) is this actually correct? or should it use a Vector like above return ListCType(OptionalCType(BaseCType(getValueT()))) elif str(typ.elem) == "Tensor": # this is a TensorList which comes in from GetTensorList as a Value return BaseCType(tensorListValueT) else: return VectorCType(process_ir_type(typ.elem)) else: raise AssertionError(f"unrecognized type {repr(typ)}") def isValueType(typ: CType) -> bool: """ Given a type, determine if it is a Value-like type. This is equivalent to being Tensor-like, but assumes the type has already been transformed. """ if isinstance(typ, BaseCType): # I am regretting my naming conventions, but now we are wrapping at::scalar in # lazy value, while preserving other 'scalar' types as scalars in the IR return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): return isValueType(typ.elem) return False def isSymIntType(typ: Type) -> bool: return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt def isWrappedScalarType(typ: Type) -> bool: """ Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value. Since we literally change the type from scalarT to valueT, information is lost. This function helps build a list of wrapped scalars to save that information """ if isinstance(typ, BaseType): # I am regretting my naming conventions, but now we are wrapping at::scalar in # lazy value, while preserving other 'scalar' types as scalars in the IR return typ.name == BaseTy.Scalar elif isinstance(typ, (OptionalType, ListType)): return isWrappedScalarType(typ.elem) return False def isGeneratorType(typ: Type) -> bool: if isinstance(typ, BaseType): return typ.name == BaseTy.Generator elif isinstance(typ, (OptionalType)): return isGeneratorType(typ.elem) return False class LazyArgument: name: str orig_type: Type lazy_type_: Optional[CType] is_wrapped_scalar: bool is_generator: bool is_symint_or_list: bool # true if this argument is or contains a lazy IR value is_lazy_value: bool def __init__(self, arg: Argument): self.name = arg.name self.orig_type = arg.type self.is_optional = isinstance(arg.type, OptionalType) self.is_generator = isGeneratorType(arg.type) if self.is_generator: assert ( self.is_optional ), "We expect all generators are optional since currently they are" # there is no handling for generators in TorchScript IR (or XLA) # so we fall back to eager if the (optional)generator has value, and otherwise # its null and safe to exclude from lazy IR self.lazy_type_ = None else: self.lazy_type_ = process_ir_type(arg.type) self.is_wrapped_scalar = isWrappedScalarType(arg.type) self.is_symint_or_list = isSymIntType(arg.type) self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type) @property def lazy_type(self) -> CType: assert ( self.lazy_type_ is not None ), f"Attempted to access lazy_type for invalid argument {self.name}" return self.lazy_type_ # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node. # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), # but carries type information from a native FunctionSchema modified for use with IR nodes, # and preserving original argument names. class LazyIrSchema: # The name of the operator this function schema describes. name: "OperatorName" positional_args: Tuple[LazyArgument, ...] keyword_args: Tuple[LazyArgument, ...] # TODO: Need to handle collisions with argument names at some point returns: Tuple["Return", ...] # if this schema has a Generator arg, list its orig ctype/name but don't # build a LazyArgument since lazy IR doesn't support it generator_arg: Optional[NamedCType] = None def __init__(self, func: FunctionSchema): positional_args = [] for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: if arg_field == "self_arg" and func.arguments.self_arg is not None: arg = getattr(func.arguments, "self_arg").argument positional_args.append(LazyArgument(arg)) elif getattr(func.arguments, arg_field) is not None: positional_args.extend( [LazyArgument(arg) for arg in getattr(func.arguments, arg_field)] ) self.positional_args = tuple(positional_args) keyword_args = [] for arg_field in [ "pre_tensor_options_kwarg_only", "tensor_options", "post_tensor_options_kwarg_only", "out", ]: curr_args = getattr(func.arguments, arg_field) if curr_args is not None: if isinstance(curr_args, TensorOptionsArguments): curr_args = curr_args.all() for arg in curr_args: if isGeneratorType(arg.type): assert ( self.generator_arg is None ), "We expect there is only one generator arg" self.generator_arg = NamedCType(arg.name, arg.type) keyword_args.extend([LazyArgument(arg) for arg in curr_args]) self.keyword_args = tuple(keyword_args) self.name = func.name self.returns = func.returns @property def node_name(self) -> str: """ Return camel-case version of op in node. Note: This function also appends any `overload_name` in the operation. For example, if the op is `bitwise_and.Tensor`, the returned name will be `BitwiseAndTensor`. """ op_name = f"{self.name.name}_{self.name.overload_name}".lower() return "".join(word.capitalize() or "" for word in op_name.split("_")) @property def aten_name(self) -> str: return f"{self.name.name}" @property def base_name(self) -> str: return f"{self.name.name.base}" def filtered_args( self, positional: bool = True, keyword: bool = True, values: bool = True, scalars: bool = True, generator: bool = False, ) -> List[LazyArgument]: # This function maintains the sorted order of arguments but provides different filtered views. # Some parts of the code care about kwargs vs args (TS lowerings), # other parts care about whether they need to wrap the arg in a lazy value or leave it alone. # Generators are special cased, as they are needed for fallback/shape-inference but not supported # in TS lowerings and therefore also omitted from lazy IR. args: List[LazyArgument] = [] if positional: args.extend(self.positional_args) if keyword: args.extend(self.keyword_args) if values and scalars and generator: return args elif values and scalars: return [a for a in args if not a.is_generator] elif values: return [a for a in args if a.is_lazy_value] elif scalars: return [ a for a in args if not a.is_lazy_value and (generator or not a.is_generator) ] return [] @property def positional_values(self) -> List[LazyArgument]: return self.filtered_args( positional=True, keyword=False, values=True, scalars=False ) @property def positional_scalars(self) -> List[LazyArgument]: return self.filtered_args( positional=True, keyword=False, values=False, scalars=True ) @property def keyword_values(self) -> List[LazyArgument]: return self.filtered_args( positional=False, keyword=True, values=True, scalars=False ) @property def keyword_scalars(self) -> List[LazyArgument]: return self.filtered_args( positional=False, keyword=True, values=False, scalars=True )