mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Previously, we introduced new SymInt overloads for every function we wanted. This led to a lot of boilerplate, and also a lot of confusion about how the overloads needed to be implemented. This PR takes a simpler but more risky approach: just take the original function and change its ints to SymInts.

This is BC-breaking in the following ways:

* The C++ API for registering implementations for aten operators will change from int64_t to SymInt whenever you make this change. Code-generated registrations in PyTorch do not change, as codegen handles the translation automatically, but manual registrations will need to follow the change. Typically, if you now accept a SymInt where you previously only took int64_t, you have to convert it back manually. This will definitely break XLA; see companion PR https://github.com/pytorch/xla/pull/3914

  Note that not all dispatch keys get the automatic translation: all the composite keys and Meta keys are modified to take SymInt directly (because they should handle them directly), and so there are adjustments for this.

This is not BC-breaking in the following ways:

* The user-facing C++ API remains compatible. Even if a function changes from int to SymInt, the default C++ binding still takes only ints (e.g., at::empty(IntArrayRef, ...)). To call with SymInts, you must call at::empty_symint instead. This involved adding two more signatures to CppSignatureGroup; in many cases I refactored code to iterate over all signatures in the group instead of hard-coding the two that previously existed.
* This is TorchScript compatible; internally we treat SymInts as ints, so there is no change to what happens at runtime in TorchScript. In particular, it's OK to reference an empty schema by its old type (using int types): as long as you're not doing string equality (which you shouldn't be), these parse to the same underlying type.
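The `name()` helper in the file below implements the naming rule described above (`at::empty` vs `at::empty_symint`). Its `_out`/`_outf` suffixes go together with the argument ordering chosen in `arguments()`.

What follows is a minimal standalone sketch of both rules. Some assumptions:

* `ToySchema`, `cpp_name`, `cpp_arg_order` and `render` are hypothetical names made up for illustration. They are not torchgen APIs.
* Argument lists are reduced to bare names.
* The sketch uses one `faithful` flag for both the name and the argument order, on the assumption that a faithful C++ signature uses both together.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical, simplified stand-in for torchgen's FunctionSchema: only the
# pieces that name() and arguments() look at are modeled here.
@dataclass
class ToySchema:
    base_name: str
    non_out: List[str]
    out: List[str]

    def is_out_fn(self) -> bool:
        return len(self.out) > 0


def cpp_name(schema: ToySchema, *, faithful: bool = False, symint: bool = False) -> str:
    # Mirrors the suffix logic of name(): "_symint" first, then
    # "_outf" (faithful) or "_out" (default) for out= overloads.
    name = schema.base_name
    if symint:
        name += "_symint"
    if schema.is_out_fn():
        name += "_outf" if faithful else "_out"
    return name


def cpp_arg_order(schema: ToySchema, *, faithful: bool = False) -> List[str]:
    # Mirrors arguments(): the faithful signature keeps schema order
    # (inputs, then outputs); the default C++ signature puts outputs first.
    if faithful:
        return schema.non_out + schema.out
    return schema.out + schema.non_out


def render(schema: ToySchema, *, faithful: bool = False, symint: bool = False) -> str:
    # Assumption of this sketch: one `faithful` flag drives both name and order.
    args = ", ".join(cpp_arg_order(schema, faithful=faithful))
    return f"at::{cpp_name(schema, faithful=faithful, symint=symint)}({args})"


empty = ToySchema("empty", ["size", "options"], [])
add_out = ToySchema("add", ["self", "other", "alpha"], ["out"])
print(render(empty))                   # at::empty(size, options)
print(render(empty, symint=True))      # at::empty_symint(size, options)
print(render(add_out))                 # at::add_out(out, self, other, alpha)
print(render(add_out, faithful=True))  # at::add_outf(self, other, alpha, out)
```

With this design the `_symint` variant is an additional binding, not a replacement. Callers that pass plain ints keep using the unsuffixed name.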
Structure of the PR:

* The general strategy of this PR is that, even when you write `SymInt` inside `native_functions.yaml`, sometimes we will treat it *as if* it were an `int`. This idea pervades the codegen changes, where we have a translation from SymInt to c10::SymInt or int64_t. The translation is controlled by a symint kwarg, which I added, and I then audited all call sites to decide which translation I wanted at each one. Here are some of the major places where we pick one or the other:
  * The C++ FunctionSchema representation represents `SymInt` as `int`. There are a few places where we do need to know that we actually have a SymInt, and we consult `real_type()` to get the real type in this case. In particular:
    * When we do schema validation of C++ operator registration, we must compare against the true schema (the C++ API will provide `c10::SymInt`, and this will only be accepted if the schema is `SymInt`). This is handled with cloneWithRealTypes before we check for schema differences.
    * In `toIValue` argument parsing, we parse against the true schema value. For backwards compatibility reasons, I do still accept ints in many places where Layout/SymInt/etc. were expected. (Well, accepting int where SymInt is expected is not BC, it's just the right logic!)
  * In particular, because SymInt never shows up as type() in FunctionSchema, we no longer need a dedicated Tag::SymInt. This is good, because SymInts never show up in mobile anyway.
* Changes to functorch/aten are mostly about tracking changes to the C++ API registration convention. Additionally, since SymInt overloads no longer exist, registrations for SymInt implementations are deleted. In many cases, the old implementations did not properly support SymInts; I did not add any new functionality with this PR, but I did try to annotate with TODOs where there is work to do. Finally, because the signature of the `native::` API changed from int to SymInt, I needed to find alternative APIs for people who were directly calling these functions. Typically, I insert a new dispatch call when perf doesn't matter, or use the `at::compositeexplicitautograd` namespace to handle other cases.
* The change to `make_boxed_from_unboxed_functor.h` is so that we accept a plain IntList IValue anywhere a SymIntList is expected; these are read-only arguments, so covariant typing is OK.
* I change how unboxing logic works slightly. Previously, we interpreted the C++ type for Layout/etc. directly as the IntType JIT type, which works well because the incoming IValue is tagged as an integer. Now, we interpret the C++ type for Layout as its true type, e.g., LayoutType (change to `jit_type.h`), but then we accept an int IValue for it anyway. This makes it symmetric with SymInt, where we interpret the C++ type as SymIntType and then accept both SymInt and int IValues for it.
* I renamed the `empty.names` overload to `empty_names` to make it less confusing (I kept mixing it up with the real empty overload).
* I deleted the `empty.SymInt` overload, which ended up killing a pile of functions. (This was originally a separate PR, but the profiler expect test was giving me grief, so I folded it in.)
* I deleted the LazyDynamicOpsTest tests. These were failing after these changes, and I couldn't figure out why they used to be passing: they make use of `narrow_copy`, which didn't actually support SymInts; they were immediately converted to ints.
* I bashed LTC into working. The patches made here are not the end of the story. The big problem is that SymInt translates into Value, but what if you have a list of SymInt? This cannot be conveniently represented in the IR today, since variadic Values are not supported. To work around this, I translate SymInt[] into plain int[] (this is fine for tests because LTC dynamic shapes never actually worked), but this will need to be fixed for proper LTC SymInt support. The LTC codegen also looked somewhat questionable; I added comments based on my code reading.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83628
Approved by: https://github.com/albanD, https://github.com/bdhirsh
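The `symint` kwarg described above decides whether a schema `SymInt` is spelled `c10::SymInt` or `int64_t` in generated C++. In the file below, this happens in the `SymInt` branches of `valuetype_type()` and `argumenttype_type()`.

The sketch below is a rough string-level model of that decision table. It uses plain strings instead of torchgen's `CType` objects. Two things are assumptions made for illustration:

* `cpp_type_for` is a hypothetical name, not a torchgen function.
* The exact C++ spellings are inferred from torchgen's type names, not generated by torchgen.

```python
from typing import Optional


def cpp_type_for(schema_type: str, *, symint: bool, owning: bool = False) -> Optional[str]:
    # Hypothetical helper: models the SymInt-related branches of
    # valuetype_type() / argumenttype_type(). `owning` plays the role of
    # remove_non_owning_ref_types.
    if schema_type == "SymInt":
        # A scalar SymInt is a value type, so `owning` makes no difference.
        return "c10::SymInt" if symint else "int64_t"
    if schema_type == "SymInt[]":
        if owning:
            # Owning container; the element type still follows `symint`.
            return "::std::vector<c10::SymInt>" if symint else "::std::vector<int64_t>"
        # Default: a non-owning array reference.
        return "c10::SymIntArrayRef" if symint else "at::IntArrayRef"
    if schema_type == "int[]":
        # Plain int lists ignore the symint flag entirely.
        return "::std::vector<int64_t>" if owning else "at::IntArrayRef"
    return None  # every other type is outside the scope of this sketch


print(cpp_type_for("SymInt", symint=True))     # c10::SymInt
print(cpp_type_for("SymInt", symint=False))    # int64_t
print(cpp_type_for("SymInt[]", symint=False))  # at::IntArrayRef
```

With `symint=False`, `SymInt[]` and `int[]` land on the same `at::IntArrayRef`. That is how the default C++ binding stays source-compatible for callers that pass plain ints.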
449 lines
15 KiB
Python
from typing import List, Optional, Sequence, Set, Union

from torchgen import local
from torchgen.api.types import (
    ArgName,
    ArrayCType,
    ArrayRefCType,
    BaseCType,
    BaseTypeToCppMapping,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    dimnameListT,
    intArrayRefT,
    ListCType,
    longT,
    MutRefCType,
    NamedCType,
    OptionalCType,
    optionalIntArrayRefT,
    scalarT,
    SpecialArgName,
    symIntArrayRefT,
    SymIntT,
    tensorListT,
    tensorOptionsT,
    tensorT,
    TupleCType,
    VectorCType,
    voidT,
)
from torchgen.model import (
    Argument,
    Arguments,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never

# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
#   - dtype, layout, device and pin_memory are collected into
#     a single C++ type TensorOptions (the native functions API
#     also has this, but tensor options is really most relevant
#     for the C++ API; it makes calling kwarg factory functions
#     pleasant)
#
#   - defaulting lives here (in fact, the dispatcher is completely
#     oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide


def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    name = str(func.name.name)
    if symint_overload:
        name += "_symint"
    if func.is_out_fn():
        if faithful_name_for_out_overloads:
            name += "_outf"
        else:
            name += "_out"

    return name


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> Optional[NamedCType]:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        elif str(t) == "SymInt":
            if symint:
                return NamedCType(binds, BaseCType(SymIntT))
            else:
                return NamedCType(binds, BaseCType(longT))
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError(
                    "string ref->value conversion: not implemented yet"
                )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds, symint=symint)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "SymInt":
            if remove_non_owning_ref_types:
                if symint:
                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
                else:
                    return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                if symint:
                    return NamedCType(binds, BaseCType(symIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)


# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__", symint=symint)
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False, symint=symint)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")


# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    return returntype_type(r.type, mutable=r.is_write, symint=symint)


# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    if len(rs) == 0:
        return BaseCType(voidT)
    elif len(rs) == 1:
        return return_type(rs[0], symint=symint)
    else:
        return TupleCType([return_type(r, symint=symint) for r in rs])


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    returns: List[str] = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = "self"
        # If this is an out function, the name is the name of the
        # corresponding out argument (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            name = f.func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            name_conflict = any(
                r.name == a.name for a in f.func.schema_order_arguments()
            )
            if name_conflict and not f.func.is_out_fn():
                name = f"{r.name}_return"
            else:
                name = r.name
        # If there is no explicit name, we name the output fallback_name
        # ("result" by default), unless it's a multi-return, in which case
        # it's result0, result1, etc. (zero-indexed)
        else:
            name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
        returns.append(name)
    return returns


JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "c10::nullopt",  # UGH this one is type directed
    "Mean": "at::Reduction::Mean",
    "[]": "{}",
    "contiguous_format": "MemoryFormat::Contiguous",
    "long": "at::kLong",
}

# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ""
            i = 1
            while i + 1 < len(d):
                if d[i] != "\\":
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "c10::nullopt"

        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)


# Convert an argument into its C++ API form


def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> List[Binding]:
    def sub_argument(
        a: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return (
                sub_argument(a.dtype)
                + sub_argument(a.layout)
                + sub_argument(a.device)
                + sub_argument(a.pin_memory)
            )
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert "options" not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = "{}"
            elif a.dtype.default == "long":
                default = "at::kLong"  # TODO: this is wrong
            return [
                Binding(
                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                    name="options",
                    default=default,
                    argument=a,
                )
            ]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)


def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: Set[str],
) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if faithful:
        args.extend(arguments.non_out)
        args.extend(arguments.out)
    else:
        args.extend(arguments.out)
        args.extend(arguments.non_out)
    return [
        r.no_default() if faithful else r
        for a in args
        for r in argument(
            a,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args,
        )
    ]
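`default_expr()` above has to rewrite single-quoted JIT schema string defaults into double-quoted C++ string literals. The sketch below copies just that branch so it can run without torchgen installed.

Some assumptions and simplifications:

* `jit_str_default_to_cpp` is a hypothetical name.
* The `Tensor?`, optional, list and `JIT_TO_CPP_DEFAULT` cases of `default_expr()` are deliberately left out.
* Returning non-quoted input unchanged is a simplification of this sketch. In `default_expr()`, such input falls through to the later rules instead.

```python
def jit_str_default_to_cpp(d: str) -> str:
    # Hypothetical standalone copy of the `str` branch of default_expr().
    if not (len(d) >= 2 and d[0] == "'" and d[-1] == "'"):
        # Simplification: not a single-quoted string, return it untouched.
        return d
    s = ""
    i = 1
    # Walk the characters between the surrounding single quotes.
    while i + 1 < len(d):
        if d[i] != "\\":
            # A bare double quote must be escaped inside a C++ "..." literal.
            s += '\\"' if d[i] == '"' else d[i]
            i += 1
        else:
            # \' was only needed because of the single quoting; C++ wants a
            # plain '. Any other escape pair (\\, \n, ...) is copied as-is.
            s += "'" if d[i + 1] == "'" else d[i : i + 2]
            i += 2
    return f'"{s}"'


print(jit_str_default_to_cpp("'mean'"))        # "mean"
print(jit_str_default_to_cpp("'it\\'s'"))      # "it's"
print(jit_str_default_to_cpp("'say \"hi\"'"))  # "say \"hi\""
```

The translation works on escape pairs, not single characters. That is why an escaped backslash (`\\`) survives intact and is not misread as the start of `\'`.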