Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
During export, we stub out most CIA (CompositeImplicitAutograd) ops to return NotImplemented, so that they are not decomposed during tracing. To recover the existing shape-propagation behavior, we also register these CIA decompositions directly as FakeTensorMode rules. We have to do this because, when we return NotImplemented, FakeTensor falls back to running these CIAs with the Meta backend, which makes device-branching CIA ops fail (the device is now Meta; sdpa is one example). If we register a kernel directly with FakeTensorMode, we never fall back to the Meta backend.

Differential Revision: [D65716260](https://our.internmc.facebook.com/intern/diff/D65716260/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140465
Approved by: https://github.com/bdhirsh
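For context, the module below defines the `register_op_impl` decorator, and a rule registered through it is handled by FakeTensorMode itself rather than reached via the Meta-backend fallback. A minimal sketch of such a registration (the op and body here are hypothetical, shown only to illustrate the shape of the API):

    @register_op_impl(aten.my_device_branching_op.default)  # hypothetical overload
    def my_device_branching_op_fake(fake_mode, func, *args, **kwargs):
        # Shape/dtype propagation only; the real kernel is never dispatched to the
        # Meta backend, so its device-dependent branching is not exercised.
        return args[0].new_empty(args[0].shape)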
# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label


pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    return device.type != "hpu"


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.to.device,
    aten.to.prim_Device,
    aten.is_pinned.default,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
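# Illustrative note (not in the original source): `run_impl_check` may be a single
# OpOverload, a list/tuple of overloads, or a predicate over OpOverloads, e.g.
#
#   @register_op_impl(aten.nonzero.default)                      # one overload
#   @register_op_impl([aten.copy.default, aten.copy_.default])   # several overloads
#   @register_op_impl(lambda func: "fft" in func.name())         # predicate
#   def my_rule(fake_mode, func, *args, **kwargs): ...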


def _is_op_registered_to_fake_rule(op):
    return op in op_implementations_dict


def _deregister_op_impl(op):
    if op in op_implementations_dict:
        del op_implementations_dict[op]
    for check, impl in op_implementations_checks:
        if check is op:
            op_implementations_checks.remove((check, impl))
            break


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)


@register_op_impl(aten.is_pinned.default)
def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    # we'll ignore device argument because it is deprecated and not
    # actually used by is_pinned.
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp)
    return r


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent in only some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


def _unique(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Do not use a memo for unique_dim
    if dim is not None or (nnz := arg.unique_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        if dim is None:
            arg.unique_memo = nnz

    if dim is None:
        ret = [arg.new_empty((nnz,))]
    else:
        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]

    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
    if return_inverse or return_if_dim_and_cpu:
        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
    else:
        inverse = arg.new_empty(0)
    ret.append(inverse)

    if return_counts or return_if_dim_and_cpu:
        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
    else:
        counts = arg.new_empty(0)
    ret.append(counts)

    return tuple(ret)


@register_op_impl(aten._unique2.default)
def unique2(
    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)


@register_op_impl(aten.unique_dim.default)
def unique_dim(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else dim % max(arg.ndim, 1),
        sorted,
        return_inverse,
        return_counts,
    )


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
    # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten.item.default)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if (r := arg.item_memo) is not None:
        return r
    if fake_mode.shape_env is None or (
        not fake_mode.shape_env.allow_scalar_outputs
        and not fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
    arg.item_memo = r
    return r
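# Illustrative note (not in the original source): under a FakeTensorMode whose
# ShapeEnv allows scalar outputs, calling .item() on a fake tensor yields a fresh
# unbacked SymInt/SymFloat/SymBool (memoized on the tensor via item_memo) rather
# than a concrete Python number.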


@register_op_impl(torch.ops.aten.nonzero_numpy.default)
def nonzero_numpy(fake_mode, func, arg):
    return torch.ops.aten.nonzero.default(arg).unbind(1)


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )
        from torch.utils._sympy.numbers import IntInfinity
        from torch.utils._sympy.value_ranges import bound_sympy

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())
            else:
                prod_node = math.prod(arg.shape).node
                prod_range = bound_sympy(
                    prod_node.expr, prod_node.shape_env.var_to_range
                )
                if isinstance(prod_range.upper, IntInfinity):
                    maxval = sys.maxsize - 1
                else:
                    maxval = prod_range.upper

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz

    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
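# Illustrative note (not in the original source): for a fake tensor of shape
# (s0, 3), nonzero returns a fake tensor of shape (u0, 2), where u0 is a fresh
# unbacked SymInt whose range is capped by the known upper bound of 3 * s0.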


@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None):
    # only one jagged dim is supported for now
    assert len(offsets) == 1

    if not total_L:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            # Without symints/symfloats, cannot handle this
            raise DynamicOutputShapeException(func)

        total_L = fake_mode.shape_env.create_unbacked_symint()

        maxval = sys.maxsize - 1

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(padded.numel()):
            maxval = int(padded.numel())

        _constrain_range_for_size(total_L, min=0, max=maxval)

    output_shape = (total_L, *padded.shape[2:])
    return padded.new_empty(output_shape)


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )
    from torch.utils._sympy.numbers import IntInfinity
    from torch.utils._sympy.value_ranges import bound_sympy

    # If num elements is expressed symbolically, calculate
    # the concrete value based on upper bounds. Otherwise,
    # we can set max val directly.
    if not has_free_symbols(self.numel()):
        num_elements = int(self.numel())
    else:
        prod_node = math.prod(self.shape).node
        prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
        if isinstance(prod_range.upper, IntInfinity):
            num_elements = sys.maxsize - 1
        else:
            num_elements = prod_range.upper
    if num_elements > 2:
        maxval = num_elements

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError as not_implemented_error:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to default device handling,
# since the op can take in non-zero sized cpu
# index tensors with cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input unsqueezing is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)

    if not batch_first:
        # Inputs should have shape (batch_size, seq_len, *)
        inputs = inputs.transpose(0, 1)

    res_size = inputs.shape[1:]
    packed_data = inputs.new_empty(res_size)
    batch_size = inputs.new_empty((new_batch_size,))
    return (packed_data, batch_size)


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB. This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known. However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i})",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
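# Worked example (not in the original source):
#   infer_size((2, 1, 4), (3, 1)) -> (2, 3, 4)
# Trailing dims are right-aligned, and a size of 1 broadcasts against the other size.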


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
                if final_shape is None:
                    final_shape = shape
                # TODO: Minor optimization: track if the shapes
                # were equal so you can skip the equality check
                # below if unnecessary
                final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and guard_size_oblivious(sym_eq(op.shape, final_shape))
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors on same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


# disable the python dispatcher to avoid decomposing detach() further
# (proxy_mode should still decompose detach() though)
def fast_detach(fake_mode, x):
    with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
        out = torch.ops.aten.detach.default(x)
    return FakeTensor(fake_mode, out, x.device)


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
    return FAST_OP_IMPLEMENTATIONS