pytorch/torchgen/api/lazy.py
Edward Z. Yang 8fae7027b3 Don't introduce new overload for SymInt (#83628)
Previously, we introduced new SymInt overloads for every function we wanted.  This led to a lot of boilerplate, and also a lot of confusion about how the overloads needed to be implemented.

This PR takes a simpler but more risky approach: just take the original function and change its ints to SymInts.

This is BC-breaking in the following ways:

* The C++ API for registering implementations for aten operators will change from int64_t to SymInt whenever you make this change. Code generated registrations in PyTorch do not change as codegen handles the translation automatically, but manual registrations will need to follow the change.  Typically, if you now accept a SymInt where you previously only took int64_t, you have to convert it back manually.  This will definitely break XLA, see companion PR https://github.com/pytorch/xla/pull/3914. Note that not all dispatch keys get the automatic translation; all the composite keys and Meta keys are modified to take SymInt directly (because they should handle them directly), and so there are adjustments for this.

This is not BC-breaking in the following ways:

* The user facing C++ API remains compatible.  Even if a function changes from int to SymInt, the default C++ binding still takes only ints (e.g., at::empty(IntArrayRef, ...)).  To call with SymInts, you must call at::empty_symint instead. This involved adding two more signatures to CppSignatureGroup; in many cases I refactored code to iterate over all signatures in the group instead of hard-coding the two that previously existed.
* This is TorchScript compatible; internally we treat SymInts as ints so there is no change to what happens at runtime in TorchScript. In particular, it's OK to reference an empty schema by its old type (using int types); as long as you're not doing string equality (which you shouldn't be), these parse to the same underlying type.

Structure of the PR:

* The general strategy of this PR is that, even when you write `SymInt` inside `native_functions.yaml`, sometimes, we will treat it *as if* it were an `int`. This idea pervades the codegen changes, where we have a translation from SymInt to c10::SymInt or int64_t, and this is controlled by a symint kwarg which I added and then audited all call sites to decide which I wanted. Here are some of the major places where we pick one or the other:
  * The C++ FunctionSchema representation represents `SymInt` as `int`. There are a few places we do need to know that we actually have a SymInt and we consult `real_type()` to get the real type in this case. In particular:
    * When we do schema validation of C++ operator registration, we must compare against the true schema (as the C++ API will provide `c10::SymInt`, and this will only be accepted if the schema is `SymInt`). This is handled with cloneWithRealTypes before we check for schema differences.
    * In `toIValue` argument parsing, we parse against the true schema value. For backwards compatibility reasons, I do still accept ints in many places where Layout/SymInt/etc were expected. (Well, accepting int where SymInt is expected is not BC, it's just the right logic!)
  * In particular, because SymInt never shows up as type() in FunctionSchema, this means that we no longer need a dedicated Tag::SymInt. This is good, because SymInts never show up in mobile anyway.
* Changes to functorch/aten are mostly about tracking changes to the C++ API registration convention. Additionally, since SymInt overloads no longer exist, registrations for SymInt implementations are deleted. In many cases, the old implementations did not properly support SymInts; I did not add any new functionality with this PR, but I did try to annotate with TODOs where there is work to do. Finally, because the signature of `native::` API changed from int to SymInt, I needed to find alternative APIs for people who were directly calling these functions to call. Typically, I insert a new dispatch call when perf doesn't matter, or use `at::compositeexplicitautograd` namespace to handle other cases.
* The change to `make_boxed_from_unboxed_functor.h` is so that we accept a plain IntList IValue anywhere a SymIntList is expected; these are read-only arguments so covariant typing is OK.
* I change how unboxing logic works slightly. Previously, we interpret the C++ type for Layout/etc directly as IntType JIT type, which works well because the incoming IValue is tagged as an integer. Now, we interpret the C++ type for Layout as its true type, e.g., LayoutType (change to `jit_type.h`), but then we accept an int IValue for it anyway. This makes it symmetric with SymInt, where we interpret the C++ type as SymIntType, and then accept SymInt and int IValues for it.
* I renamed the `empty.names` overload to `empty_names` to make it less confusing (I kept mixing it up with the real empty overload)
* I deleted the `empty.SymInt` overload, which ended up killing a pile of functions. (This was originally a separate PR but the profiler expect test was giving me grief so I folded it in.)
* I deleted the LazyDynamicOpsTest tests. These were failing after these changes, and I couldn't figure out why they used to be passing: they make use of `narrow_copy` which didn't actually support SymInts; they were immediately converted to ints.
* I bashed LTC into working. The patches made here are not the end of the story. The big problem is that SymInt translates into Value, but what if you have a list of SymInt? This cannot be conveniently represented in the IR today, since variadic Values are not supported. To work around this, I translate SymInt[] into plain int[] (this is fine for tests because LTC dynamic shapes never actually worked); but this will need to be fixed for proper LTC SymInt support. The LTC codegen also looked somewhat questionable; I added comments based on my code reading.
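
The "treat `SymInt` *as if* it were an `int`" strategy described above can be illustrated with a minimal sketch. This is not the actual torchgen code: `cpp_type_for` and its lookup table are hypothetical names invented for illustration, and only two schema types are covered. It shows the general shape of a `symint` keyword argument that selects between the symbolic and the plain-integer C++ type.

```python
# Hypothetical illustration of a `symint` switch; not the real torchgen API.
_SYMINT_AWARE = {
    # schema type: (C++ type when symint=True, C++ type when symint=False)
    "SymInt": ("c10::SymInt", "int64_t"),
    "SymInt[]": ("c10::SymIntArrayRef", "at::IntArrayRef"),
}

_PLAIN = {
    "int": "int64_t",
    "int[]": "at::IntArrayRef",
}


def cpp_type_for(schema_type: str, *, symint: bool) -> str:
    """Map a schema type string to a C++ type name.

    When symint is False, SymInt-typed arguments are rendered as if they
    had been declared with plain int types.
    """
    if schema_type in _SYMINT_AWARE:
        symbolic, plain = _SYMINT_AWARE[schema_type]
        return symbolic if symint else plain
    if schema_type in _PLAIN:
        return _PLAIN[schema_type]
    raise ValueError(f"unsupported schema type: {schema_type}")
```

Under this sketch, a call site that must keep accepting only ints would pass `symint=False`, while a call site that handles symbolic sizes directly would pass `symint=True`.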

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83628
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2022-08-23 22:04:07 +00:00


from typing import Any, Dict, List, Optional, Tuple, Union

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)

_valueT = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )

    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val


# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
     (1) changing at::Tensors into lazy::Values
     (2) wrapping everything in a BaseCType
     (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
    There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'

    This is incomplete- there are assertions in places that it's expected to need to add
    more types as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))
        else:
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False


def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, (OptionalType)):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: Optional[CType]
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is lies, it is false for symint list
    is_symint_or_list: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument, properties: "LazyIrProperties"):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generators are optional since currently they are"
            # there is no handling for generators in TorchScript IR (or XLA)
            # so we fall back to eager if the (optional) generator has a value, and otherwise
            # it's null and safe to exclude from lazy IR
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type, properties)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = (
            isSymIntType(arg.type)
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = not self.is_generator and isValueType(
            self.lazy_type, properties
        )

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        properties: Dict[Tuple[str, ...], Optional[str]] = {
            p: None for p in LazyIrProperties.Properties
        }
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")
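
The mutual exclusivity described in the `LazyIrProperties` docstring comes from storing one selected name per group rather than one boolean per property. A minimal standalone sketch of that technique follows; `ExclusiveFlags` and `GROUPS` are hypothetical names used only for illustration, and only two of the groups are reproduced so the example runs without torchgen installed.

```python
from typing import Any, Dict, Optional, Tuple


class ExclusiveFlags:
    """Toy version of the one-selection-per-group attribute pattern."""

    GROUPS: Tuple[Tuple[str, ...], ...] = (
        ("ShapePrecompute", "ShapeCompute", "ShapeCache"),
        ("Lower", "LowerDeclOnly"),
    )

    def __init__(self, *defaults: str) -> None:
        selected: Dict[Tuple[str, ...], Optional[str]] = {
            group: None for group in self.GROUPS
        }
        # Write through __dict__ to bypass the custom __setattr__ below.
        self.__dict__["selected"] = selected
        for name in defaults:
            setattr(self, name, True)

    def __getattr__(self, key: str) -> Any:
        # Only invoked when normal attribute lookup fails, i.e. for flag names.
        for group in self.GROUPS:
            if key in group:
                return self.__dict__["selected"][group] == key
        raise AttributeError(key)

    def __setattr__(self, key: str, value: Any) -> None:
        for group in self.GROUPS:
            if key in group:
                # Selecting a flag replaces whatever was selected in its group.
                self.__dict__["selected"][group] = key if value else None
                return
        raise KeyError(f"Invalid property: {key}")


flags = ExclusiveFlags("ShapePrecompute", "Lower")
flags.ShapeCache = True  # implicitly deselects ShapePrecompute
```

Because each group holds a single selected name, enabling one flag cannot leave a sibling flag enabled, and flags in other groups are unaffected.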


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: Optional[str] = None

    def __init__(
        self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None
    ):
        if properties:
            self.properties = properties

        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = getattr(func.arguments, "self_arg").argument
                positional_args.append(LazyArgument(arg, self.properties))
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(arg.name, arg.type)
                keyword_args.extend(
                    LazyArgument(arg, self.properties) for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return camel-case version of op in node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() or "" for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
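
The value/scalar/generator branching in `filtered_args` can be tried in isolation with a small stand-in. The sketch below is an assumption-laden toy: `Arg` and `filter_args` are hypothetical names, `Arg` carries only the two flags the filter reads, and the positional/keyword split is left out. The branch order mirrors the listing above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Arg:
    """Hypothetical stand-in for LazyArgument with only the filtered flags."""

    name: str
    is_lazy_value: bool = False
    is_generator: bool = False


def filter_args(
    args: List[Arg],
    values: bool = True,
    scalars: bool = True,
    generator: bool = False,
) -> List[Arg]:
    # Argument order is preserved in every view.
    if values and scalars and generator:
        return list(args)
    elif values and scalars:
        return [a for a in args if not a.is_generator]
    elif values:
        return [a for a in args if a.is_lazy_value]
    elif scalars:
        return [
            a
            for a in args
            if not a.is_lazy_value and (generator or not a.is_generator)
        ]
    return []


example_args = [
    Arg("self", is_lazy_value=True),
    Arg("alpha"),
    Arg("generator", is_generator=True),
]
values_only = [a.name for a in filter_args(example_args, scalars=False)]
scalars_only = [a.name for a in filter_args(example_args, values=False)]
```

Here `values_only` keeps just the argument wrapped in a lazy value, while `scalars_only` keeps the plain argument and drops the generator unless `generator=True` is also passed.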