from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.local as local
from typing import Optional, Sequence, Union, List, Set

# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
#   - dtype, layout, device and pin_memory are collected into
#     a single C++ type TensorOptions (the native functions API
#     also has this, but tensor options is really most relevant
#     for the C++ API; it makes calling kwarg factory functions
#     pleasant)
#
#   - for 'use_c10_dispatcher: full' functions, optional tensors are
#     represented explicitly using c10::optional
#
#   - defaulting lives here (in fact, the dispatcher is completely
#     oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide

def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Return the C++ API function name for ``func``.

    The base name is ``str(func.name.name)``. Functions that are not out
    overloads (``func.is_out_fn()`` is false) use the base name unchanged.
    Out overloads get a suffix: ``_outf`` when
    ``faithful_name_for_out_overloads`` is set, ``_out`` otherwise.
    """
    base_name = str(func.name.name)
    if not func.is_out_fn():
        return base_name
    suffix = '_outf' if faithful_name_for_out_overloads else '_out'
    return base_name + suffix

# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[CType]:
    """Translate a JIT value type into its C++ API type.

    Args:
        t: the JIT schema type to translate.
        binds: the argument name recorded on the resulting CType.

    Returns:
        The C++ type, or ``None`` when ``t`` is not a value type: a plain
        ``Tensor``, an optional whose element is not a value type, or a list
        whose element type is anything other than ``bool``.

    Raises:
        AssertionError: for an unsupported base type, an unrecognized kind of
            type, or a ``bool`` list without a fixed size.
    """
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return None
        elif t.name == BaseTy.int:
            return BaseCType('int64_t', binds)
        elif t.name == BaseTy.float:
            return BaseCType('double', binds)
        elif t.name == BaseTy.str:
            return BaseCType('std::string', binds)
        elif t.name in (BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar,
                        BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage,
                        BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat,
                        BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr):
            # These C++ names line up with their schema names
            return BaseCType(t.name.name, binds)
        else:
            raise AssertionError(f"unsupported type: {t}")
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return OptionalCType(elem)
    elif isinstance(t, ListType):
        if str(t.elem) == 'bool':
            # Fixed-size bool lists become std::array, whose C++ type must
            # spell out both the element type and the size.
            assert t.size is not None
            return BaseCType(f"std::array<bool,{t.size}>", binds)
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

# Translation of types occurring in JIT arguments to a C++ argument type.
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate a JIT argument type into the C++ type used for that argument.

    Value types are delegated to ``valuetype_type``. The remaining cases are
    Tensor types (passed by const or mutable reference), optional Tensors,
    other optionals, and lists.

    Args:
        t: the JIT schema type of the argument.
        mutable: whether the argument is written to; selects a mutable
            reference for Tensor arguments.
        binds: the argument name recorded on the resulting CType.

    Raises:
        AssertionError: if ``t`` is a non-Tensor base type that was not
            handled as a value type, or an unrecognized kind of type.
    """
    # If it's a value type, do the value type translation
    r = valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))
            else:
                return ConstRefCType(BaseCType('Tensor', binds))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))  # TODO: fix this discrepancy
            else:
                # The optional wrapper is only spelled out when the dispatcher
                # reports the new style; otherwise a plain const Tensor& is used.
                if local.use_c10_dispatcher().dispatcher_uses_new_style():
                    return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
                else:
                    return ConstRefCType(BaseCType('Tensor', binds))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return OptionalCType(elem)
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, need to add back
        if str(t.elem) == 'int':
            return BaseCType("IntArrayRef", binds)
        elif str(t.elem) == 'Tensor':
            return BaseCType("TensorList", binds)
        elif str(t.elem) == 'Dimname':
            return BaseCType("DimnameList", binds)
        elif str(t.elem) == 'Tensor?':
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                # The element type must be spelled out in full; a bare
                # "c10::List>" is not a valid C++ type.
                return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
            else:
                return BaseCType("TensorList", binds)
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Return the C++ type of argument ``a``; written-to arguments are mutable."""
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)

# Translation of a (non-multi) return type from JIT to C++
# NB: if need translations on return types, make this return CType too.  Need to
# take care; ArgName is misnomer now, and inputs are permitted to conflict with outputs
# so need to make sure you don't have trouble
def returntype_type(t: Type, *, mutable: bool) -> str:
    """Return the C++ type string for a single JIT return type.

    Value types use their value-type spelling. Tensors become ``Tensor`` or,
    when ``mutable``, ``Tensor &``. Unsized lists become ``std::vector<...>``
    of the translated element type.

    Raises:
        AssertionError: for fixed-size list returns and any other
            unrecognized return type.
    """
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.cpp_type()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return 'Tensor &'
            else:
                return 'Tensor'
    elif isinstance(t, ListType):
        elem = returntype_type(t.elem, mutable=mutable)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return f"std::vector<{elem}>"

    raise AssertionError(f"unrecognized return type {t}")

# Translation of a single return to its C++ type
def return_type(r: Return) -> str:
    """Return the C++ type string for return ``r``; written-to returns are mutable."""
    return returntype_type(r.type, mutable=r.is_write)

# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> str:
    """Return ``void`` for no returns, the single type for one return, and
    ``std::tuple<...>`` of the individual types for several returns."""
    if len(rs) == 0:
        return 'void'
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        args = ','.join(map(return_type, rs))
        return f'std::tuple<{args}>'

def return_names(f: NativeFunction) -> Sequence[str]:
    """Compute the C++ variable name used for each return of ``f``.

    Names are chosen by the first matching rule:

    * In-place functions return ``self``; they may have only one return.
    * Out functions use the name of the corresponding out argument.
    * Explicitly named returns keep their name, or get a ``_return`` suffix
      when the name collides with one of the function's arguments.
    * Otherwise the name is ``result``, or ``result0``, ``result1``, ...
      for multi-returns.
    """
    returns: List[str] = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            ret_name = 'self'
        # If we are out function, the name is the name of the
        # corresponding output function (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            ret_name = f.func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            # Out functions were handled above, so only the argument-name
            # collision matters here.
            name_conflict = any(r.name == a.name for a in f.func.schema_order_arguments())
            if name_conflict:
                ret_name = f'{r.name}_return'
            else:
                ret_name = r.name
        # If there is no explicit name, we just name the output result,
        # unless it's a multi-return, in which case it's result0,
        # result1, etc (zero-indexed)
        else:
            ret_name = 'result' if len(f.func.returns) == 1 else f'result{i}'
        returns.append(ret_name)
    return returns

# Textual JIT default values and the C++ expressions they map to;
# unlisted defaults are passed through unchanged by default_expr.
JIT_TO_CPP_DEFAULT = {
    'False': 'false',
    'True': 'true',
    'None': 'c10::nullopt',  # UGH this one is type directed
    'Mean': 'at::Reduction::Mean',
    '[]': '{}',
    'contiguous_format': 'MemoryFormat::Contiguous',
    'long': 'at::kLong',
}

# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    """Convert the JIT default ``d`` of an argument of type ``t`` into C++.

    Raises:
        ValueError: if an unsized list type has a default that is not
            written as ``[...]``.
    """
    if d == 'None' and str(t) == 'Tensor?':
        return '{}'
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ''
            i = 1
            while i + 1 < len(d):
                if d[i] != '\\':
                    # Bare double quotes must be escaped inside a C++ "..." literal.
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    # An escaped single quote needs no escape in C++; any
                    # other escape sequence is copied through verbatim.
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i:i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == 'None':
            return 'c10::nullopt'

        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if (d.startswith('[') and d.endswith(']')):
            return '{' + d[1:-1] + '}'
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)

# Convert an argument into its C++ API form
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument], *,
    cpp_no_default_args: Set[str], method: bool, faithful: bool,
    has_tensor_options: bool
) -> List[Binding]:
    """Translate one schema argument into its C++ API bindings.

    * A plain ``Argument`` yields one binding. Its default is converted with
      ``default_expr`` unless the name is in ``cpp_no_default_args``.
    * ``TensorOptionsArguments`` yields the four separate dtype/layout/
      device/pin_memory bindings when ``faithful``, otherwise a single
      ``const TensorOptions& options`` binding.
    * ``SelfArgument`` yields nothing for methods, whose ``self`` is the
      implicit ``this``, and the underlying argument otherwise.
    """
    def sub_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
        return argument(
            a, cpp_no_default_args=cpp_no_default_args, method=method, faithful=faithful,
            has_tensor_options=has_tensor_options)

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [Binding(
            ctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return sub_argument(a.dtype) + sub_argument(a.layout) + \
                sub_argument(a.device) + sub_argument(a.pin_memory)
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert 'options' not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = '{}'
            elif a.dtype.default == "long":
                default = 'at::kLong'  # TODO: this is wrong
            return [Binding(
                ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
                name='options',
                default=default,
                argument=a,
            )]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)

def arguments(
    arguments: Arguments,
    *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
    """Translate a full argument list into C++ API bindings.

    When ``faithful``, non-out arguments come first, followed by out
    arguments, and every binding has its default stripped. Otherwise out
    arguments come first and defaults are kept.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if faithful:
        args.extend(arguments.non_out)
        args.extend(arguments.out)
    else:
        args.extend(arguments.out)
        args.extend(arguments.non_out)
    return [
        r.no_default() if faithful else r for a in args
        for r in argument(
            a, faithful=faithful, method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args)
    ]