pytorch/torchgen/api/cpp.py
Brian Hirsh 960758b0b7 fix overload ambiguity with functional ops; fix _foreach op grouping (#80556)
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today if you try to call an `OpOverloadPacket` from Python with some arguments, we will use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.

Today this affects just one op, `_fused_moving_avg_obs_fq_helper`, although it could potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 outputs, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 outputs, mutates none of its inputs)

# We pick the wrong one - there's no way to know, just from the call site,
# that we should pick the functional overload.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# Raises an error - the default overload was picked, and it only has 2 returns.
return outs[5]
```

Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we have to strip the overload name and let TS infer the right overload at runtime later) - and TS picks the wrong one.

The situation is pretty similar to inplace ops: `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous (the alias annotations are different, but that isn't enough to disambiguate).

In this PR, I fix the situation in much the same way that we handle `inplace` in the data model: an inplace op gets its own `BaseOperatorName`, with the inplaceness represented as a flag on it, and functional variants now get the same treatment.
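Roughly, the shape of the change in the data model (a sketch; see `torchgen/model.py` in this PR for the real definition):

```
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseOperatorName:
    base: str            # e.g. "add" or "_fused_moving_avg_obs_fq_helper"
    inplace: bool        # True for names with a trailing underscore: add_
    dunder_method: bool  # True for __add__-style magic methods
    # New in this PR: analogous to `inplace`. A name like
    # "_fused_moving_avg_obs_fq_helper_functional" parses into
    # base="_fused_moving_avg_obs_fq_helper" with this flag set.
    functional_overload: bool = False
```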

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only needed for operators that **also** have a `SchemaKind.mutable` variant, and `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops, and added a bunch of assertions to enforce this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input argument names differed (`self` vs. `tensors`). Since they're private APIs, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating separate `_foreach_add` and `_foreach_add.functional` variants in a bunch of cases that really did the same thing (but happened to use a different name for the first argument).
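To illustrate (schemas paraphrased from `native_functions.yaml`; the exact signatures may differ):

```
# Before: the first arguments had different names, so the two variants
# weren't recognized as a pair, and codegen emitted an extra
# _foreach_add.functional overload:
#   _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
#   _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
# After renaming `tensors` to `self`, they pair up as variants of one op:
#   _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
#   _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
```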

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-06 12:45:11 +00:00


from typing import List, Optional, Sequence, Set, Union

from torchgen import local
from torchgen.api.types import (
    ArgName,
    ArrayCType,
    ArrayRefCType,
    BaseCType,
    BaseTypeToCppMapping,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    dimnameListT,
    intArrayRefT,
    ListCType,
    longT,
    MutRefCType,
    NamedCType,
    OptionalCType,
    optionalIntArrayRefT,
    scalarT,
    SpecialArgName,
    symIntArrayRefT,
    SymIntT,
    tensorListT,
    tensorOptionsT,
    tensorT,
    TupleCType,
    VectorCType,
    voidT,
)
from torchgen.model import (
    Argument,
    Arguments,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never


# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
# a single C++ type TensorOptions (the native functions API
# also has this, but tensor options is really most relevant
# for the C++ API; it makes calling kwarg factory functions
# pleasant)
#
# - defaulting lives here (in fact, the dispatcher is completely
# oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide


def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
name = str(func.name.name)
if func.is_symint_fn():
name += "_symint"
if func.is_out_fn():
if faithful_name_for_out_overloads:
name += "_outf"
else:
name += "_out"
return name
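
# For example (all derivable from the logic above): "add.out" is named
# "add_out" by default and "add_outf" when faithful_name_for_out_overloads
# is set, while a symint function gets a "_symint" suffix.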


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> Optional[NamedCType]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
"string ref->value conversion: not implemented yet"
)
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if str(t.elem) == "bool":
assert t.size is not None
return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
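
# For example: `int` maps directly through BaseTypeToCppMapping, `int?`
# wraps that in an OptionalCType, and `bool[3]` becomes a fixed-size
# ArrayCType; `Tensor` and `Scalar` are not value types, so they return None.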


# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputted CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == "Tensor":
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(
binds, MutRefCType(BaseCType(tensorT))
) # TODO: fix this discrepancy
else:
return NamedCType(
binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
)
elif str(t.elem) == "Scalar":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: remove these special cases, ArrayRef fallthrough works fine
if str(t.elem) == "int":
if remove_non_owning_ref_types:
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "SymInt":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
            else:
                return NamedCType(binds, BaseCType(symIntArrayRefT))
        elif str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
elif str(t.elem) == "Dimname":
return NamedCType(binds, BaseCType(dimnameListT))
elif str(t.elem) == "Tensor?":
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
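
# For example: a mutable `Tensor` binds as `Tensor &` (or `const Tensor &`
# under use_const_ref_for_mutable_tensors), `int[]` binds as IntArrayRef
# (or std::vector<int64_t> when remove_non_owning_ref_types is set), and
# `Tensor?[]` binds as a const ref to a List of optional Tensors.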


# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)


# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool) -> CType:
# placeholder is ignored
r = valuetype_type(t, binds="__placeholder__")
if r is not None:
return r.type
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
# Note [Tensor Copy Returns]
# Currently, we use "Argument.is_write" to determine
# whether or not Tensor return types should be copies or references.
# If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
raise AssertionError(f"unrecognized return type {t}")


# Translation of a single return to its C++ type
def return_type(r: Return) -> CType:
return returntype_type(r.type, mutable=r.is_write)


# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0])
else:
return TupleCType([return_type(r) for r in rs])
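
# For example: zero returns map to `void`, a single `Tensor` return maps to
# `Tensor`, and `(Tensor, Tensor)` maps to `std::tuple<Tensor, Tensor>`.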


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: List[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
# TODO: Consider incorporating this into the data model
if f.func.name.name.inplace:
assert i == 0, "illegal inplace function with multiple returns"
name = "self"
        # If this is an out function, the name is the name of the
        # corresponding out argument (r.name will get recorded
        # in field_name later.)
elif f.func.is_out_fn():
name = f.func.arguments.out[i].name
# If the return argument is explicitly named...
elif r.name:
name_conflict = any(
r.name == a.name for a in f.func.schema_order_arguments()
)
if name_conflict and not f.func.is_out_fn():
name = f"{r.name}_return"
else:
name = r.name
        # If there is no explicit name, we just use the fallback name
        # ("result" by default), unless it's a multi-return, in which case
        # it's result0, result1, etc. (zero-indexed)
else:
name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
returns.append(name)
return returns
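
# For example: an inplace op returns the single name "self", an out overload
# reuses the names of its out arguments, and unnamed multi-returns fall back
# to "result0", "result1", etc.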


JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "c10::nullopt", # UGH this one is type directed
"Mean": "at::Reduction::Mean",
"[]": "{}",
"contiguous_format": "MemoryFormat::Contiguous",
"long": "at::kLong",
}


# Convert a JIT default into a C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
if d == "None" and str(t) == "Tensor?":
return "{}"
if isinstance(t, BaseType) and t.name is BaseTy.str:
# Schema allows single quotes but C++ needs double
if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
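            # Copy everything between the quotes, re-escaping for C++:
            # bare double quotes must be escaped, escaped single quotes
            # ("\'") become plain ones, and other escapes pass through.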
s = ""
i = 1
while i + 1 < len(d):
if d[i] != "\\":
if d[i] == '"':
s += '\\"'
else:
s += d[i]
i += 1
else:
if d[i + 1] == "'":
s += "'"
else:
s += d[i : i + 2]
i += 2
return f'"{s}"'
if isinstance(t, OptionalType):
if d == "None":
return "c10::nullopt"
return default_expr(d, t.elem)
if isinstance(t, ListType):
if d.startswith("[") and d.endswith("]"):
return "{" + d[1:-1] + "}"
elif t.size is None:
# NOTE: Sized lists can have scalar defaults
raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
return JIT_TO_CPP_DEFAULT.get(d, d)
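
# For example: default_expr("None", <Tensor?>) yields "{}",
# default_expr("[0, 1]", <int[]>) yields "{0, 1}", and a schema string
# default like 'hello' becomes the C++ literal "hello".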


# Convert an argument into its C++ API form
def argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument],
*,
cpp_no_default_args: Set[str],
method: bool,
faithful: bool,
has_tensor_options: bool,
) -> List[Binding]:
def sub_argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
has_tensor_options=has_tensor_options,
)

    if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: Optional[str] = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [
Binding(
nctype=argument_type(a, binds=binds),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, TensorOptionsArguments):
if faithful:
return (
sub_argument(a.dtype)
+ sub_argument(a.layout)
+ sub_argument(a.device)
+ sub_argument(a.pin_memory)
)
else:
default = None
# Enforced by NativeFunction.__post_init__
assert "options" not in cpp_no_default_args
if all(x.default == "None" for x in a.all()):
default = "{}"
elif a.dtype.default == "long":
default = "at::kLong" # TODO: this is wrong
return [
Binding(
nctype=NamedCType("options", BaseCType(tensorOptionsT)),
name="options",
default=default,
argument=a,
)
]
elif isinstance(a, SelfArgument):
if method:
# Caller is responsible for installing implicit this in context!
return []
else:
return sub_argument(a.argument)
else:
assert_never(a)
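
# For example: in the default (non-faithful) API, scattered
# dtype/layout/device/pin_memory arguments are gathered into a single
# `TensorOptions options = {}` binding; the faithful API keeps all four
# as separate bindings.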


def arguments(
arguments: Arguments, *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
else:
args.extend(arguments.out)
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r
for a in args
for r in argument(
a,
faithful=faithful,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,
)
]
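
# For example: for an out overload like add.out, the default C++ API places
# the out argument(s) first - add_out(out, self, other, alpha) - while the
# faithful API places them last and strips defaults: add_outf(self, other,
# alpha, out).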