Fixes for PyTorch/XLA functionalization integration (#94537)
Some notable changes include:

- More asserts in `FunctionalTensorWrapper`, so bugs show up more cleanly in cases where we e.g. forget to wrap an output.
- Make the *_scatter ops `CompositeExplicitAutogradNonFunctional`, so we get a better error message and XLA doesn't accidentally try to use them.
- Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns.
- Better erroring: allow XLA to use the CPU fallback from core in a way that always errors on view ops, which XLA should no longer see.
- Update MetaConverter to exclude XLA tensors in raising NotImplemented…
- Add `_propagate_xla_data` op.
- Add meta tensor support for some ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94537
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent f397d1700f
commit 3095c95828
.github/ci_commit_pins/xla.txt (vendored), 2 changed lines

@@ -1 +1 @@
-503401a24e532a9019ef140199319221294045ee
+f9963f6c2d34b9662f93e5518adb15949be05f65
@@ -13,6 +13,7 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else
+#include <ATen/ops/_propagate_xla_data.h>
 #include <ATen/ops/_to_copy.h>
 #endif
@@ -51,6 +52,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
   ),
   value_(value)
 {
+  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
 }
@@ -130,6 +133,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
   ),
   value_(view_value)
 {
+  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
   // Copy the original tensor's ViewMeta vector and push the current one.
   if (!base->view_metas_.empty()) {
@@ -168,7 +173,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
   // So, these ops are special - they're mutation AND view ops. They get special codegen.
   // An example is transpose_, e.g. `a.transpose_()`
   // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
+  at::AutoDispatchSkipFunctionalize guard;
   value_ = meta.forward_fn(value_, meta.out_index);
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
 }

 // Note [Functionalization: Mutation Removal]
@@ -200,15 +207,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
   // TODO: going to need to change this if we want nested functionalize() transforms.
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
   value_ = other;
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
   // We need to propagate that metadata mutation to the wrapper (new size).
-  set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
+  auto sizes_ = value_.sym_sizes();
+  auto strides_ = value_.sym_strides();
+  auto storage_offset_ = value_.sym_storage_offset();
+  set_sizes_and_strides(sizes_, strides_, storage_offset_);
   if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
     // .to() should not re-entrantly go through functionalization.
     at::AutoDispatchSkipFunctionalize guard;
     // and we want _to_copy() to show up in the graph, not the composite .to() operator
     // (this can happen if autograd has already run by the time we enter this code)
     value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
+    TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   }
 }
@@ -243,6 +255,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
   // Then it's safe to throw out the old storage and replace it with the new, larger one.
   storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
   value_ = other;
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   generation_ = 0;
   // And update the metadata on the wrapper to reflect the new sizes and strides
   set_sizes_and_strides(value_.sizes(), value_.strides());
@@ -484,6 +497,24 @@ void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
   }
 }

+void propagate_xla_data(const Tensor& functional_tensor, const Tensor& other) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
+  if (functional_tensor.key_set().has(c10::DispatchKey::XLA)) {
+    at::_propagate_xla_data(at::functionalization::impl::unsafeGetFunctionalWrapper(functional_tensor)
+        ->value(), other);
+  }
+}
+
+void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
+  auto functional_tensor_it = functional_tensor.begin();
+  auto other_it = other.begin();
+  for (const auto i : c10::irange(functional_tensor.size())) {
+    (void)i; // Suppress unused variable warning
+    propagate_xla_data(*functional_tensor_it++, *other_it++);
+  }
+}
+
 void commit_update(const Tensor& functional_tensor) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
   unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
@@ -224,6 +224,15 @@ TORCH_API void replace_(
 TORCH_API void commit_update(const Tensor& functional_tensor);
 TORCH_API void commit_update(ITensorListRef functional_tensor);

+// These two methods are XLA-specific logic and are no-ops
+// for the normal functionalization flow.
+TORCH_API void propagate_xla_data(
+    const Tensor& functional_tensor,
+    const Tensor& other);
+TORCH_API void propagate_xla_data(
+    const ITensorListRef functional_tensor,
+    ITensorListRef other);
+
 Tensor create_functional_tensor_with_view_meta(
     const Tensor& view_to_wrap,
     const Tensor& base,
@@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
 }


-void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
   auto& schema_args = op.schema().arguments();
   const auto num_arguments = schema_args.size();
   auto arguments = torch::jit::last(stack, num_arguments);
@@ -176,10 +176,16 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
     } else {
       dev_str << "<none>";
     }
+    if (error_on_views) {
+      TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
+        "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
+        "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
+    } else {
       TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
         "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
         "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
     }
+    }
     // Case (2): copy case. Copy the cpu output tensor to the original device.

     // We technically might not have a target device, e.g. if you call torch.cat() with an empty list
@@ -11,7 +11,7 @@ namespace at { namespace native {

 // This function implements a boxed fallback to CPU.
 // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
-TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);

 // This is a helper function that backends can use to directly call their boxed CPU fallback
 // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
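As a rough illustration of how a backend might opt into the new flag, the boxed fallback it registers can simply forward to the core helper with error_on_views set. The XLA dispatch key and the wrapper name below are assumptions made for the sketch, not code taken from this PR or from the XLA repository.

// Sketch only (assumed backend-side code): forward the backend's boxed fallback
// to at::native::cpu_fallback and make view ops a hard error instead of a warning.
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

static void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Functionalization should have removed all view ops before we get here,
  // so a view op reaching the fallback indicates a bug and should error out.
  at::native::cpu_fallback(op, stack, /*error_on_views=*/true);
}

TORCH_LIBRARY_IMPL(_, XLA, m) {
  // Route every op without an XLA kernel through the boxed CPU fallback above.
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}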
@@ -320,6 +320,10 @@ void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
   copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
 }

+void _propagate_xla_data(const Tensor& input, const Tensor& output) {
+  TORCH_INTERNAL_ASSERT(input.device().type() == kXLA, "This op should only be called by XLA")
+}
+
 DEFINE_DISPATCH(copy_stub);

 } // namespace native
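The kernel above is deliberately a no-op that only asserts the tensor lives on an XLA device; a backend is expected to register its own implementation that carries backend-internal data across functionalization. A minimal sketch of such a backend-side registration follows; the kernel body and names are illustrative assumptions, not code from the XLA repository.

// Sketch only: how a backend could override aten::_propagate_xla_data to copy
// its internal bookkeeping from `input` to the functionalized `output`.
#include <ATen/ATen.h>
#include <torch/library.h>

static void backend_propagate_xla_data(const at::Tensor& input, const at::Tensor& output) {
  // Backend-specific metadata propagation would go here (assumption);
  // the op returns nothing and must not change user-visible values.
  (void)input;
  (void)output;
}

TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("_propagate_xla_data", TORCH_FN(backend_propagate_xla_data));
}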
@@ -5091,7 +5091,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: slice_scatter
+    CompositeExplicitAutogradNonFunctional: slice_scatter
   autogen: slice_scatter.out
   tags: core
@@ -5100,7 +5100,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: select_scatter_symint
+    CompositeExplicitAutogradNonFunctional: select_scatter_symint
   autogen: select_scatter.out

 - func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
@@ -5108,7 +5108,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: diagonal_scatter
+    CompositeExplicitAutogradNonFunctional: diagonal_scatter
   autogen: diagonal_scatter.out

 - func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
@@ -5116,7 +5116,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: as_strided_scatter_symint
+    CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
   autogen: as_strided_scatter.out

 - func: smm(Tensor self, Tensor mat2) -> Tensor
@@ -14724,4 +14724,7 @@
   python_module: nn
   dispatch:
     CompositeExplicitAutograd: wait_tensor
+
+# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
+- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
   variants: function
@@ -1221,6 +1221,11 @@ class TestMeta(TestCase):
         r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8)
         self.assertEqual(r.device.type, 'meta')

+    def test_nan_to_num(self):
+        t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
+        r = t.nan_to_num()
+        self.assertEqual(r.device.type, 'meta')
+
     @onlyCPU
     def test_meta_autograd_no_error(self):
         lib = torch.library.Library("meta_test", "DEF")
@@ -2676,6 +2676,27 @@ def meta_upsample_bilinear2d_aa(
     )


+# From aten/src/ATen/native/cuda/AmpKernels.cu
+@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
+def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
+    check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
+    check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
+    check(
+        found_inf.dtype.is_floating_point, lambda: "found_inf must be a float tensor."
+    )
+    check(
+        inv_scale.dtype.is_floating_point, lambda: "inv_scale must be a float tensor."
+    )
+
+
+# From aten/src/ATen/native/UnaryOps.cpp
+@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
+@out_wrapper()
+def nan_to_num(self, nan=None, posinf=None, neginf=None):
+    result_size = list(self.size())
+    return self.new_empty(result_size)
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs
@@ -463,7 +463,7 @@ class MetaConverter:
             or (ignore_subclass and isinstance(t, torch.Tensor))
             or isinstance(t, FakeTensor)
         ):
-            if any(
+            if t.device.type != "xla" and any(
                 [
                     t.is_sparse_csr,
                     t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
@@ -51,6 +51,7 @@

 #include <ATen/AccumulateType.h>
 #include <ATen/CompositeExplicitAutogradFunctions.h>
+#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Functions.h>
@@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::select_scatter(
+  auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
       self_meta, src_meta, dim, index);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
+  auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
       self_meta, src_meta, offset, dim1, dim2);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -1355,7 +1356,8 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
+  auto out_meta =
+      at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
           self_meta, src_meta, dim, start, end, step);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -1380,7 +1382,8 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
+  auto out_meta =
+      at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
          self_meta, src_meta, size, stride, storage_offset);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -432,6 +432,7 @@ Tensor internal_new_from_data(
     // to dispatch to it.
     // TODO: arguably it should have an autograd implementation that noops
     at::AutoDispatchBelowADInplaceOrView guard;
+
     return at::lift_fresh(tensor);
   }

@@ -323,10 +323,21 @@ class RegisterDispatchKey:
                 for i, ret_name in enumerate(return_names)
             )
             returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
-        else:
+        elif len(return_names) == 1:
             ret_name = return_names[0]
             updates = f"{copy_op}({func_res}, {ret_name});"
             returns = ret_name
+        else:
+            assert len(f.func.arguments.out) == 1
+            returns = ""
+            out_arg = f.func.arguments.out[0]
+            if out_arg.type.is_list_like():
+                updates = f"""\
+    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
+        {copy_op}({func_res}[i], {out_arg.name}[i]);
+    }}"""
+            else:
+                updates = f"{copy_op}({func_res}, {out_arg.name});"

         functional_sig = self.wrapper_kernel_sig(g.functional)
         wrapper_name = sig.name()
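The new else branch covers out= overloads that return nothing but take a Tensor[] out argument: the generated wrapper runs the functional kernel and then copies each result back into the caller's outputs. A hand-written sketch of that shape, with illustrative names rather than actual torchgen output:

// Illustration only: the copy-back loop the generated wrapper performs for a
// Tensor[] out argument on a void-returning out= op.
#include <ATen/ATen.h>
#include <vector>

void copy_results_into_out(const std::vector<at::Tensor>& tmp_output, at::TensorList out) {
  TORCH_CHECK(tmp_output.size() == out.size(), "out list has the wrong length");
  for (size_t i = 0; i < tmp_output.size(); ++i) {
    // Stands in for the generated `{copy_op}({func_res}[i], {out_arg.name}[i]);`
    out[i].resize_(tmp_output[i].sizes());
    out[i].copy_(tmp_output[i]);
  }
}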
@@ -499,6 +499,7 @@ def wrap_propagate_mutations_and_return(
     ):
         updates.append(
             f"""\
+      at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
       at::functionalization::impl::replace_({outer_arg}, {inner_ret});
       at::functionalization::impl::commit_update({outer_arg});
       at::functionalization::impl::sync({outer_arg});"""
@@ -541,6 +542,11 @@ def emit_inplace_functionalization_body(
         for a in f.func.arguments.flat_all
         if a.type.is_tensor_like() and a.annotation is None
     ]
+    non_mutated_tensor_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
+    ]
     # all mutable inputs must be functional tensors in order to participate in functionalization
     check_all_mutated_args_are_functional = " && ".join(
         ["true"]
@@ -556,6 +562,14 @@ def emit_inplace_functionalization_body(
             for a in non_mutated_names
         ]
     )
+
+    check_any_non_mutated_tensors_are_xla = " || ".join(
+        ["false"]
+        + [
+            f"{a}.device().type() == c10::DeviceType::XLA"
+            for a in non_mutated_tensor_names
+        ]
+    )
     # These are used in the cases where we don't functionalize and redispatch to the inplace op
     # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
     # case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
@@ -619,7 +633,9 @@ def emit_inplace_functionalization_body(
       }}
       {unwrap_tensor_args_str}
       if (!({check_all_mutated_args_are_functional})) {{
-       if (({check_any_non_mutated_args_are_functional})) {{
+       // We want to disable this check if there are any XLA tensors.
+       // cpu_tensor.copy_(xla_tensor) is valid code.
+       if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
         // case 1: trying to mutate a non functional tensor with a functional tensor is an error
         TORCH_INTERNAL_ASSERT(false,
           "mutating a non-functional tensor with a functional tensor is not allowed.",