Fixes for PyTorch/XLA functionalization integration (#94537)

Fixes for PyTorch/XLA functionalization integration

---
Some notable changes include:
- More asserts in `FunctionalTensorWrapper`, so bugs show up more cleanly in cases where we e.g. forget to wrap an output
- Make the *_scatter ops `CompositeExplicitAutogradNonFunctional`, so we get a better error message and XLA doesn't accidentally try to us them
- Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns
- Better erroring: Allow XLA to use the CPU fallback from core in a way so that it always errors on view ops, which XLA should no longer see.
- Update MetaConverter to exclude XLA tensors in raising NotImplemented…
- Add `_propagate_xla_data` op
- Add meta tensor support for some ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94537
Approved by: https://github.com/bdhirsh
This commit is contained in:
Wonjoo Lee 2023-03-02 23:02:30 +00:00 committed by PyTorch MergeBot
parent f397d1700f
commit 3095c95828
14 changed files with 130 additions and 20 deletions

View File

@ -1 +1 @@
503401a24e532a9019ef140199319221294045ee
f9963f6c2d34b9662f93e5518adb15949be05f65

View File

@ -13,6 +13,7 @@
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_to_copy.h>
#endif
@ -51,6 +52,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
),
value_(value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
}
@ -130,6 +133,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
),
value_(view_value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
// Copy the original tensor's ViewMeta vector and push the current one.
if (!base->view_metas_.empty()) {
@ -168,7 +173,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard;
value_ = meta.forward_fn(value_, meta.out_index);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
// Note [Functionalization: Mutation Removal]
@ -200,15 +207,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
// TODO: going to need to change this if we want nested functionalize() transforms.
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
// We need to propagate that metadata mutation to the wrapper (new size).
set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
auto sizes_ = value_.sym_sizes();
auto strides_ = value_.sym_strides();
auto storage_offset_ = value_.sym_storage_offset();
set_sizes_and_strides(sizes_, strides_, storage_offset_);
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
// .to() should not re-entrantly go through functionalization.
at::AutoDispatchSkipFunctionalize guard;
// and we want _to_copy() to show up in the graph, not the composite .to() operator
// (this can happen if autograd has already run by the time we enter this code)
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
}
@ -243,6 +255,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Then it's safe to throw out the old storage and replace it with the new, larger one.
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
generation_ = 0;
// And update the metadata on the wrapper to reflect the new sizes and strides
set_sizes_and_strides(value_.sizes(), value_.strides());
@ -484,6 +497,24 @@ void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
}
}
void propagate_xla_data(const Tensor& functional_tensor, const Tensor& other) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
if (functional_tensor.key_set().has(c10::DispatchKey::XLA)) {
at::_propagate_xla_data(at::functionalization::impl::unsafeGetFunctionalWrapper(functional_tensor)
->value(), other);
}
}
void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
auto functional_tensor_it = functional_tensor.begin();
auto other_it = other.begin();
for (const auto i : c10::irange(functional_tensor.size())) {
(void)i; // Suppress unused variable warning
propagate_xla_data(*functional_tensor_it++, *other_it++);
}
}
void commit_update(const Tensor& functional_tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
unsafeGetFunctionalWrapper(functional_tensor)->commit_update();

View File

@ -224,6 +224,15 @@ TORCH_API void replace_(
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);
// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
const Tensor& functional_tensor,
const Tensor& other);
TORCH_API void propagate_xla_data(
const ITensorListRef functional_tensor,
ITensorListRef other);
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,

View File

@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
}
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
auto& schema_args = op.schema().arguments();
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
@ -176,9 +176,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
} else {
dev_str << "<none>";
}
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
if (error_on_views) {
TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
} else {
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
}
}
// Case (2): copy case. Copy the cpu output tensor to the original device.

View File

@ -11,7 +11,7 @@ namespace at { namespace native {
// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.

View File

@ -320,6 +320,10 @@ void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
}
void _propagate_xla_data(const Tensor& input, const Tensor& output) {
TORCH_INTERNAL_ASSERT(input.device().type() == kXLA, "This op should only be called by XLA")
}
DEFINE_DISPATCH(copy_stub);
} // namespace native

View File

@ -5091,7 +5091,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: slice_scatter
CompositeExplicitAutogradNonFunctional: slice_scatter
autogen: slice_scatter.out
tags: core
@ -5100,7 +5100,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: select_scatter_symint
CompositeExplicitAutogradNonFunctional: select_scatter_symint
autogen: select_scatter.out
- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
@ -5108,7 +5108,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: diagonal_scatter
CompositeExplicitAutogradNonFunctional: diagonal_scatter
autogen: diagonal_scatter.out
- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
@ -5116,7 +5116,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: as_strided_scatter_symint
CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
autogen: as_strided_scatter.out
- func: smm(Tensor self, Tensor mat2) -> Tensor
@ -14724,4 +14724,7 @@
python_module: nn
dispatch:
CompositeExplicitAutograd: wait_tensor
# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
variants: function

View File

@ -1221,6 +1221,11 @@ class TestMeta(TestCase):
r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
self.assertEqual(r.device.type, 'meta')
def test_nan_to_num(self):
t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
r = t.nan_to_num()
self.assertEqual(r.device.type, 'meta')
@onlyCPU
def test_meta_autograd_no_error(self):
lib = torch.library.Library("meta_test", "DEF")

View File

@ -2676,6 +2676,27 @@ def meta_upsample_bilinear2d_aa(
)
# From aten/src/ATen/native/cuda/AmpKernels.cu
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
check(
found_inf.dtype.is_floating_point, lambda: "found_inf must be a float tensor."
)
check(
inv_scale.dtype.is_floating_point, lambda: "inv_scale must be a float tensor."
)
# From aten/src/ATen/native/UnaryOps.cpp
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
@out_wrapper()
def nan_to_num(self, nan=None, posinf=None, neginf=None):
result_size = list(self.size())
return self.new_empty(result_size)
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs

View File

@ -463,7 +463,7 @@ class MetaConverter:
or (ignore_subclass and isinstance(t, torch.Tensor))
or isinstance(t, FakeTensor)
):
if any(
if t.device.type != "xla" and any(
[
t.is_sparse_csr,
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],

View File

@ -51,6 +51,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::select_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
self_meta, src_meta, dim, index);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
self_meta, src_meta, offset, dim1, dim2);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
@ -1355,8 +1356,9 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
auto out_meta =
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
@ -1380,8 +1382,9 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
auto out_meta =
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

View File

@ -432,6 +432,7 @@ Tensor internal_new_from_data(
// to dispatch to it.
// TODO: arguably it should have an autograd implementation that noops
at::AutoDispatchBelowADInplaceOrView guard;
return at::lift_fresh(tensor);
}

View File

@ -323,10 +323,21 @@ class RegisterDispatchKey:
for i, ret_name in enumerate(return_names)
)
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
elif len(return_names) == 1:
ret_name = return_names[0]
updates = f"{copy_op}({func_res}, {ret_name});"
returns = ret_name
else:
assert len(f.func.arguments.out) == 1
returns = ""
out_arg = f.func.arguments.out[0]
if out_arg.type.is_list_like():
updates = f"""\
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
{copy_op}({func_res}[i], {out_arg.name}[i]);
}}"""
else:
updates = f"{copy_op}({func_res}, {out_arg.name});"
functional_sig = self.wrapper_kernel_sig(g.functional)
wrapper_name = sig.name()

View File

@ -499,6 +499,7 @@ def wrap_propagate_mutations_and_return(
):
updates.append(
f"""\
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});
at::functionalization::impl::sync({outer_arg});"""
@ -541,6 +542,11 @@ def emit_inplace_functionalization_body(
for a in f.func.arguments.flat_all
if a.type.is_tensor_like() and a.annotation is None
]
non_mutated_tensor_names = [
a.name
for a in f.func.arguments.flat_all
if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
]
# all mutable inputs must be functional tensors in order to participate in functionalization
check_all_mutated_args_are_functional = " && ".join(
["true"]
@ -556,6 +562,14 @@ def emit_inplace_functionalization_body(
for a in non_mutated_names
]
)
check_any_non_mutated_tensors_are_xla = " || ".join(
["false"]
+ [
f"{a}.device().type() == c10::DeviceType::XLA"
for a in non_mutated_tensor_names
]
)
# These are used in the cases where we don't functionalize and redispatch to the inplace op
# case 1: we hit an inplace op that doesn't have an out-of-place equivalent
# case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
@ -619,7 +633,9 @@ def emit_inplace_functionalization_body(
}}
{unwrap_tensor_args_str}
if (!({check_all_mutated_args_are_functional})) {{
if (({check_any_non_mutated_args_are_functional})) {{
// We want to disable this check if there are any XLA tensors.
// cpu_tensor.copy_(xla_tensor) is valid code.
if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
// case 1: trying to mutate a non functional tensor with a functional tensor is an error
TORCH_INTERNAL_ASSERT(false,
"mutating a non-functional tensor with a functional tensor is not allowed.",