mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes for PyTorch/XLA functionalization integration (#94537)
Fixes for PyTorch/XLA functionalization integration --- Some notable changes include: - More asserts in `FunctionalTensorWrapper`, so bugs show up more cleanly in cases where we e.g. forget to wrap an output - Make the *_scatter ops `CompositeExplicitAutogradNonFunctional`, so we get a better error message and XLA doesn't accidentally try to us them - Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns - Better erroring: Allow XLA to use the CPU fallback from core in a way so that it always errors on view ops, which XLA should no longer see. - Update MetaConverter to exclude XLA tensors in raising NotImplemented… - Add `_propagate_xla_data` op - Add meta tensor support for some ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/94537 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
f397d1700f
commit
3095c95828
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
503401a24e532a9019ef140199319221294045ee
|
||||
f9963f6c2d34b9662f93e5518adb15949be05f65
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/_propagate_xla_data.h>
|
||||
#include <ATen/ops/_to_copy.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -51,6 +52,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
|
|||
),
|
||||
value_(value)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
set_constructor_metadata();
|
||||
}
|
||||
|
||||
|
|
@ -130,6 +133,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
|
|||
),
|
||||
value_(view_value)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
set_constructor_metadata();
|
||||
// Copy the original tensor's ViewMeta vector and push the current one.
|
||||
if (!base->view_metas_.empty()) {
|
||||
|
|
@ -168,7 +173,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
|
|||
// So, these ops are special - they're mutation AND view ops. They get special codegen.
|
||||
// An example is transpose_, e.g. `a.transpose_()`
|
||||
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
|
||||
at::AutoDispatchSkipFunctionalize guard;
|
||||
value_ = meta.forward_fn(value_, meta.out_index);
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
}
|
||||
|
||||
// Note [Functionalization: Mutation Removal]
|
||||
|
|
@ -200,15 +207,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
|
|||
// TODO: going to need to change this if we want nested functionalize() transforms.
|
||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
|
||||
value_ = other;
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
|
||||
// We need to propagate that metadata mutation to the wrapper (new size).
|
||||
set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
|
||||
auto sizes_ = value_.sym_sizes();
|
||||
auto strides_ = value_.sym_strides();
|
||||
auto storage_offset_ = value_.sym_storage_offset();
|
||||
set_sizes_and_strides(sizes_, strides_, storage_offset_);
|
||||
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
|
||||
// .to() should not re-entrantly go through functionalization.
|
||||
at::AutoDispatchSkipFunctionalize guard;
|
||||
// and we want _to_copy() to show up in the graph, not the composite .to() operator
|
||||
// (this can happen if autograd has already run by the time we enter this code)
|
||||
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,6 +255,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
|
|||
// Then it's safe to throw out the old storage and replace it with the new, larger one.
|
||||
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
|
||||
value_ = other;
|
||||
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
|
||||
generation_ = 0;
|
||||
// And update the metadata on the wrapper to reflect the new sizes and strides
|
||||
set_sizes_and_strides(value_.sizes(), value_.strides());
|
||||
|
|
@ -484,6 +497,24 @@ void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
|
|||
}
|
||||
}
|
||||
|
||||
void propagate_xla_data(const Tensor& functional_tensor, const Tensor& other) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
|
||||
if (functional_tensor.key_set().has(c10::DispatchKey::XLA)) {
|
||||
at::_propagate_xla_data(at::functionalization::impl::unsafeGetFunctionalWrapper(functional_tensor)
|
||||
->value(), other);
|
||||
}
|
||||
}
|
||||
|
||||
void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
|
||||
auto functional_tensor_it = functional_tensor.begin();
|
||||
auto other_it = other.begin();
|
||||
for (const auto i : c10::irange(functional_tensor.size())) {
|
||||
(void)i; // Suppress unused variable warning
|
||||
propagate_xla_data(*functional_tensor_it++, *other_it++);
|
||||
}
|
||||
}
|
||||
|
||||
void commit_update(const Tensor& functional_tensor) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
|
||||
unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
|
||||
|
|
|
|||
|
|
@ -224,6 +224,15 @@ TORCH_API void replace_(
|
|||
TORCH_API void commit_update(const Tensor& functional_tensor);
|
||||
TORCH_API void commit_update(ITensorListRef functional_tensor);
|
||||
|
||||
// These two methods are XLA-specific logic and are no-ops
|
||||
// for the normal functionalization flow.
|
||||
TORCH_API void propagate_xla_data(
|
||||
const Tensor& functional_tensor,
|
||||
const Tensor& other);
|
||||
TORCH_API void propagate_xla_data(
|
||||
const ITensorListRef functional_tensor,
|
||||
ITensorListRef other);
|
||||
|
||||
Tensor create_functional_tensor_with_view_meta(
|
||||
const Tensor& view_to_wrap,
|
||||
const Tensor& base,
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
|
|||
}
|
||||
|
||||
|
||||
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
|
||||
auto& schema_args = op.schema().arguments();
|
||||
const auto num_arguments = schema_args.size();
|
||||
auto arguments = torch::jit::last(stack, num_arguments);
|
||||
|
|
@ -176,9 +176,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
|||
} else {
|
||||
dev_str << "<none>";
|
||||
}
|
||||
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
|
||||
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
|
||||
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
|
||||
if (error_on_views) {
|
||||
TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
|
||||
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
|
||||
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
|
||||
} else {
|
||||
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
|
||||
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
|
||||
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
|
||||
}
|
||||
}
|
||||
// Case (2): copy case. Copy the cpu output tensor to the original device.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace at { namespace native {
|
|||
|
||||
// This function implements a boxed fallback to CPU.
|
||||
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
|
||||
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
||||
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
|
||||
|
||||
// This is a helper function that backends can use to directly call their boxed CPU fallback
|
||||
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
|
||||
|
|
|
|||
|
|
@ -320,6 +320,10 @@ void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
|
|||
copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
|
||||
}
|
||||
|
||||
void _propagate_xla_data(const Tensor& input, const Tensor& output) {
|
||||
TORCH_INTERNAL_ASSERT(input.device().type() == kXLA, "This op should only be called by XLA")
|
||||
}
|
||||
|
||||
DEFINE_DISPATCH(copy_stub);
|
||||
|
||||
} // namespace native
|
||||
|
|
|
|||
|
|
@ -5091,7 +5091,7 @@
|
|||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: slice_scatter
|
||||
CompositeExplicitAutogradNonFunctional: slice_scatter
|
||||
autogen: slice_scatter.out
|
||||
tags: core
|
||||
|
||||
|
|
@ -5100,7 +5100,7 @@
|
|||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: select_scatter_symint
|
||||
CompositeExplicitAutogradNonFunctional: select_scatter_symint
|
||||
autogen: select_scatter.out
|
||||
|
||||
- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
|
||||
|
|
@ -5108,7 +5108,7 @@
|
|||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: diagonal_scatter
|
||||
CompositeExplicitAutogradNonFunctional: diagonal_scatter
|
||||
autogen: diagonal_scatter.out
|
||||
|
||||
- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
|
||||
|
|
@ -5116,7 +5116,7 @@
|
|||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: as_strided_scatter_symint
|
||||
CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
|
||||
autogen: as_strided_scatter.out
|
||||
|
||||
- func: smm(Tensor self, Tensor mat2) -> Tensor
|
||||
|
|
@ -14724,4 +14724,7 @@
|
|||
python_module: nn
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: wait_tensor
|
||||
|
||||
# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
|
||||
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -1221,6 +1221,11 @@ class TestMeta(TestCase):
|
|||
r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
|
||||
self.assertEqual(r.device.type, 'meta')
|
||||
|
||||
def test_nan_to_num(self):
|
||||
t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
|
||||
r = t.nan_to_num()
|
||||
self.assertEqual(r.device.type, 'meta')
|
||||
|
||||
@onlyCPU
|
||||
def test_meta_autograd_no_error(self):
|
||||
lib = torch.library.Library("meta_test", "DEF")
|
||||
|
|
|
|||
|
|
@ -2676,6 +2676,27 @@ def meta_upsample_bilinear2d_aa(
|
|||
)
|
||||
|
||||
|
||||
# From aten/src/ATen/native/cuda/AmpKernels.cu
|
||||
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
|
||||
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
|
||||
check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
|
||||
check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
|
||||
check(
|
||||
found_inf.dtype.is_floating_point, lambda: "found_inf must be a float tensor."
|
||||
)
|
||||
check(
|
||||
inv_scale.dtype.is_floating_point, lambda: "inv_scale must be a float tensor."
|
||||
)
|
||||
|
||||
|
||||
# From aten/src/ATen/native/UnaryOps.cpp
|
||||
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
|
||||
@out_wrapper()
|
||||
def nan_to_num(self, nan=None, posinf=None, neginf=None):
|
||||
result_size = list(self.size())
|
||||
return self.new_empty(result_size)
|
||||
|
||||
|
||||
# We must also trigger meta registrations from PrimTorch ref
|
||||
# decompositions
|
||||
import torch._refs
|
||||
|
|
|
|||
|
|
@ -463,7 +463,7 @@ class MetaConverter:
|
|||
or (ignore_subclass and isinstance(t, torch.Tensor))
|
||||
or isinstance(t, FakeTensor)
|
||||
):
|
||||
if any(
|
||||
if t.device.type != "xla" and any(
|
||||
[
|
||||
t.is_sparse_csr,
|
||||
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@
|
|||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/Functions.h>
|
||||
|
|
@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
|
|||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::select_scatter(
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
|
||||
self_meta, src_meta, dim, index);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
|
@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
|
|||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
|
||||
self_meta, src_meta, offset, dim1, dim2);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
|
@ -1355,8 +1356,9 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
|
|||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
|
||||
self_meta, src_meta, dim, start, end, step);
|
||||
auto out_meta =
|
||||
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
|
||||
self_meta, src_meta, dim, start, end, step);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
|
|
@ -1380,8 +1382,9 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
|
|||
/*layout=*/c10::make_optional(src.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
|
||||
self_meta, src_meta, size, stride, storage_offset);
|
||||
auto out_meta =
|
||||
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
|
||||
self_meta, src_meta, size, stride, storage_offset);
|
||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -432,6 +432,7 @@ Tensor internal_new_from_data(
|
|||
// to dispatch to it.
|
||||
// TODO: arguably it should have an autograd implementation that noops
|
||||
at::AutoDispatchBelowADInplaceOrView guard;
|
||||
|
||||
return at::lift_fresh(tensor);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -323,10 +323,21 @@ class RegisterDispatchKey:
|
|||
for i, ret_name in enumerate(return_names)
|
||||
)
|
||||
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
||||
else:
|
||||
elif len(return_names) == 1:
|
||||
ret_name = return_names[0]
|
||||
updates = f"{copy_op}({func_res}, {ret_name});"
|
||||
returns = ret_name
|
||||
else:
|
||||
assert len(f.func.arguments.out) == 1
|
||||
returns = ""
|
||||
out_arg = f.func.arguments.out[0]
|
||||
if out_arg.type.is_list_like():
|
||||
updates = f"""\
|
||||
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
|
||||
{copy_op}({func_res}[i], {out_arg.name}[i]);
|
||||
}}"""
|
||||
else:
|
||||
updates = f"{copy_op}({func_res}, {out_arg.name});"
|
||||
|
||||
functional_sig = self.wrapper_kernel_sig(g.functional)
|
||||
wrapper_name = sig.name()
|
||||
|
|
|
|||
|
|
@ -499,6 +499,7 @@ def wrap_propagate_mutations_and_return(
|
|||
):
|
||||
updates.append(
|
||||
f"""\
|
||||
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
|
||||
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
|
||||
at::functionalization::impl::commit_update({outer_arg});
|
||||
at::functionalization::impl::sync({outer_arg});"""
|
||||
|
|
@ -541,6 +542,11 @@ def emit_inplace_functionalization_body(
|
|||
for a in f.func.arguments.flat_all
|
||||
if a.type.is_tensor_like() and a.annotation is None
|
||||
]
|
||||
non_mutated_tensor_names = [
|
||||
a.name
|
||||
for a in f.func.arguments.flat_all
|
||||
if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
|
||||
]
|
||||
# all mutable inputs must be functional tensors in order to participate in functionalization
|
||||
check_all_mutated_args_are_functional = " && ".join(
|
||||
["true"]
|
||||
|
|
@ -556,6 +562,14 @@ def emit_inplace_functionalization_body(
|
|||
for a in non_mutated_names
|
||||
]
|
||||
)
|
||||
|
||||
check_any_non_mutated_tensors_are_xla = " || ".join(
|
||||
["false"]
|
||||
+ [
|
||||
f"{a}.device().type() == c10::DeviceType::XLA"
|
||||
for a in non_mutated_tensor_names
|
||||
]
|
||||
)
|
||||
# These are used in the cases where we don't functionalize and redispatch to the inplace op
|
||||
# case 1: we hit an inplace op that doesn't have an out-of-place equivalent
|
||||
# case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
|
||||
|
|
@ -619,7 +633,9 @@ def emit_inplace_functionalization_body(
|
|||
}}
|
||||
{unwrap_tensor_args_str}
|
||||
if (!({check_all_mutated_args_are_functional})) {{
|
||||
if (({check_any_non_mutated_args_are_functional})) {{
|
||||
// We want to disable this check if there are any XLA tensors.
|
||||
// cpu_tensor.copy_(xla_tensor) is valid code.
|
||||
if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
|
||||
// case 1: trying to mutate a non functional tensor with a functional tensor is an error
|
||||
TORCH_INTERNAL_ASSERT(false,
|
||||
"mutating a non-functional tensor with a functional tensor is not allowed.",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user