from tools.codegen.model import *

from tools.codegen.api.types import *
from tools.codegen.api import cpp
from tools.codegen import local

from typing import Union, Sequence, List, Optional

# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels.  The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# When a function is not use_c10_dispatcher: full, the dispatcher API actually
# coincides with the native:: API (i.e., we do as dumb a pass-through as
# possible).


def name(func: FunctionSchema) -> str:
    """Return the native:: kernel name for ``func``.

    The schema's base operator name is suffixed with ``_out`` for out=
    variants, and then with ``_<overload_name>`` when the schema carries a
    non-empty overload name.
    """
    kernel_name = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        kernel_name += '_out'
    if func.name.overload_name:
        kernel_name += f'_{func.name.overload_name}'
    return kernel_name


def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Map a JIT schema type to the C++ type used in native:: signatures.

    Only ``Tensor?`` and ``Tensor?[]`` are handled specially here; every other
    type is delegated to ``cpp.argumenttype_type``.

    Args:
        t: the schema type to translate.
        mutable: whether the argument is written to; selects a mutable
            reference instead of a const reference for ``Tensor?``.
        binds: the argument name recorded on the produced CType.
    """
    if str(t) == 'Tensor?':
        tensor_type: CType = BaseCType('Tensor', binds)
        # Under the legacy hacky wrapper, optional tensors are still passed as
        # a plain (possibly undefined) Tensor; otherwise they are wrapped in
        # an optional type.
        if local.use_c10_dispatcher() is not UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
            tensor_type = OptionalCType(tensor_type)
        if mutable:
            return MutRefCType(tensor_type)
        else:
            return ConstRefCType(tensor_type)
    elif str(t) == 'Tensor?[]':
        # A list of optional tensors; the element type must be spelled out in
        # full or the generated C++ does not compile.
        return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)


def returns_type(rs: Sequence[Return]) -> str:
    """Return the C++ return type for ``rs``; identical to the C++ API's."""
    return cpp.returns_type(rs)


def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Return the native:: C++ type for ``a``, mutable iff ``a.is_write``."""
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)


def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]:
    """Translate one schema argument into its native:: bindings.

    * ``Argument`` yields a single binding, with a C++ default expression
      when defaulting is enabled and the schema supplies a default.
    * ``SelfArgument`` is unwrapped and translated like its inner argument.
    * ``TensorOptionsArguments`` yields a single ``options`` binding under
      the legacy hacky wrapper, or four scattered bindings (dtype, layout,
      device, pin_memory) under ``use_c10_dispatcher: full``.

    Args:
        a: the argument to translate.
        is_out: whether the owning function is an out= variant.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default).  Functions that are not use_c10_dispatcher: full still get
    # defaults on their out variants as well.
    should_default = not is_out or local.use_c10_dispatcher() is not UseC10Dispatcher.full
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            default=default,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        if local.use_c10_dispatcher() == UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
            # TODO: expunge this logic entirely
            default = None
            if should_default:
                if all(x.default == "None" for x in a.all()):
                    default = '{}'
                elif a.dtype.default == "long":
                    default = 'at::kLong'  # TODO: this is wrong
            return [Binding(
                ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
                name='options',
                default=default,
                argument=a,
            )]
        else:
            assert local.use_c10_dispatcher() == UseC10Dispatcher.full
            default = None
            if should_default:
                default = '{}'
            # TODO: Not sure why the arguments assigned here are for
            # TensorOptionsArguments and not the constituent pieces.  It seems
            # to matter
            return [
                Binding(
                    ctype=OptionalCType(BaseCType('ScalarType', 'dtype')),
                    name='dtype',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Layout', 'layout')),
                    name='layout',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Device', 'device')),
                    name='device',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('bool', 'pin_memory')),
                    name='pin_memory',
                    default=default,
                    argument=a,
                )]
    else:
        assert_never(a)


def arguments(func: FunctionSchema) -> List[Binding]:
    """Return the flattened native:: bindings for all arguments of ``func``.

    Under ``use_c10_dispatcher: full`` the non-out arguments come first and
    the out arguments last; otherwise the out arguments are placed first.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        args.extend(func.arguments.non_out)
        args.extend(func.arguments.out)
    else:
        args.extend(func.arguments.out)
        args.extend(func.arguments.non_out)
    return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())]