*_scatter ops should preserve input stride/storage_offset (#91029)

It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride-debugging check. The bug only surfaces as a silent correctness issue if you call .backward() on the affected function.
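
To make the invariant concrete, here is a minimal sketch (not part of the original commit message) using the public `torch.select_scatter` API; the stride assertions reflect the behavior this PR introduces:

```python
import torch

# A non-contiguous base: transposing a 4x4 tensor gives strides (1, 4).
base = torch.ones(4, 4).t()
src = torch.zeros(4)

out = torch.select_scatter(base, src, 0, 0)

# *_scatter ops now mirror the base's layout exactly, which is what
# functionalization (and the autograd-generated backward) relies on.
assert out.size() == base.size()
assert out.stride() == base.stride()          # (1, 4), not the contiguous (4, 1)
assert out.storage_offset() == base.storage_offset()
```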

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
Brian Hirsh 2022-12-22 15:37:24 +00:00 committed by PyTorch MergeBot
parent a32916190d
commit c47bdd7522
22 changed files with 353 additions and 184 deletions

View File

@ -146,10 +146,14 @@ functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_st
void FunctionalTensorWrapper::commit_update() {
auto storage_impl = functional_storage_impl();
storage_impl->add_update(value_, view_metas_);
// Invariant: commit_update() is called during an inplace operation.
// Tensor inputs to the operation are synced before running the op,
// so the current tensor must be up-to-date with its alias at this point.
generation_ = storage_impl->generation();
// As an optimization, we used to mark the tensor here as "up-to-date".
// That way, code like:
// x = torch.ones(1'000'000)
// x[0].add_(1)
// doesn't result in an unnecessary materialization of the base.
// However, this optimization temporarily leaves the slice with an incorrect
// stride/storage_offset, and DCE should handle that optimization anyway.
// generation_ = storage_impl->generation();
}
bool FunctionalTensorWrapper::is_up_to_date() const {

View File

@ -16,8 +16,8 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
return MemOverlap::No;
}
auto strides = t->strides();
auto sizes = t->sizes();
auto strides = t->sym_strides();
auto sizes = t->sym_sizes();
for (const auto i : c10::irange(strides.size())) {
if (strides[i] == 0 && sizes[i] > 1) {
return MemOverlap::Yes;
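
For context, the condition above (a zero stride on a dimension of size greater than one) is exactly what `expand` produces. A small sketch, not part of the diff; `torch._debug_has_internal_overlap` is the private Python hook for the same check:

```python
import torch

t = torch.ones(1, 3).expand(4, 3)   # size (4, 3), strides (0, 1)
# All four rows alias the same three elements, so in-place writes into t would
# hit overlapping memory. This is the MemOverlap::Yes case, and it is why
# clone_preserve_strides() falls back to a plain clone() for such inputs.
print(t.stride())                             # (0, 1)
print(torch._debug_has_internal_overlap(t))   # 1, i.e. MemOverlap::Yes
```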

View File

@ -5,13 +5,14 @@
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchRulesHelper.h>
#include <iostream>
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <iostream>
#include <torch/library.h>
namespace at { namespace functorch {
@ -1074,6 +1075,12 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
// as_strided_scatter does not work with the for-loop fallback today,
// because as_strided_scatter will return an output that matches
// the strides/storage_offset of its input.
// With the for-loop fallback, each input tensor is a slice into
// the larger batched tensor.
m.impl("as_strided_scatter", torch::CppFunction::makeFromBoxedFunction<&vmapErrorFallback>());
}
}}
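
From the Python side, the registration above turns a vmapped `as_strided_scatter` into a loud error rather than a silently wrong for-loop fallback. A hedged sketch (assuming the `functorch.vmap` entry point of this era; `torch.func.vmap` in newer releases):

```python
import torch
from functorch import vmap  # torch.func.vmap in newer releases

base = torch.zeros(3, 4)
src = torch.ones(3, 2)

def f(b, s):
    # Per-sample: scatter a 2-element view (stride 1, offset 0) into b.
    return torch.as_strided_scatter(b, s, (2,), (1,), 0)

try:
    vmap(f)(base, src)
except RuntimeError as e:
    # Expected to mention that aten::as_strided_scatter "requires special
    # handling, and does not yet have a batching rule".
    print(e)
```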

View File

@ -396,5 +396,9 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
}
}
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_CHECK(false, "Error: ", op.operator_name(), " requires special handling, and does not yet have a batching rule. Feel free to file a github issue!");
}
}
} // namespace at

View File

@ -32,6 +32,8 @@ namespace functorch {
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
// The vmap fallback emits a warning by default, but it may be disabled if
// the user finds it to be too annoying.
TORCH_API bool isVmapFallbackWarningEnabled();

View File

@ -8,6 +8,7 @@
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/TensorShape.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
@ -278,32 +279,6 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return self;
}
// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198
at::Tensor clone_preserve_strides(const at::Tensor& self) {
TORCH_INTERNAL_ASSERT(self.has_storage());
// In cases where the input tensor has internal memory overlap, we cannot actually
// preserve the strides/storage_offset of the input tensor, because
// *_scatter ops will try to copy_() into the cloned tensor.
// However, this should **never** show up in functionalized user code;
// most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
//
// The one place that this does come up is in autograd - if there's a select_scatter
// in the forward, then autograd will generate one for the backward.
// If the input to the select_scatter is grad_output, then this could be an expanded tensor
// with internal overlap.
//if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
// return self.clone();
//}
auto dtype_size = self.dtype().itemsize();
auto nbytes = self.storage().sym_nbytes();
TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
auto numel = nbytes / dtype_size;
auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
auto clone = self_full_size.clone();
auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
return out;
}
Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
// copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
// (1) It isn't exposed to the frontend (no python bindings)

View File

@ -3801,22 +3801,58 @@ std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tenso
return outputs;
}
// Clones a tensor by cloning the underlying storage that it came from,
// which allows us to replicate the exact strides/storage_offset in the cloned tensor.
// Note [*_scatter ops preserve strides]
// In order for functionalization to preserve stride correctness, the *_scatter
// operators that it calls must preserve the striding behavior of their inputs.
// Specifically, the output of *_scatter(base, mutated_view, ...)
// should have identical size/stride/storage_offset to "base".
at::Tensor clone_preserve_strides(const at::Tensor& self) {
TORCH_INTERNAL_ASSERT(self.has_storage());
// In cases where the input tensor has internal memory overlap, we cannot actually
// preserve the strides/storage_offset of the input tensor, because
// *_scatter ops will try to copy_() into the cloned tensor.
// However, this should **never** show up in functionalized user code;
// most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
//
// The one place that this does come up is in autograd - if there's a select_scatter
// in the forward, then autograd will generate one for the backward.
// If the input to the select_scatter is grad_output, then this could be an expanded tensor
// with internal overlap.
if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
return self.clone();
}
auto dtype_size = self.dtype().itemsize();
auto nbytes = self.storage().sym_nbytes();
TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
auto numel = nbytes / dtype_size;
auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
auto clone = self_full_size.clone();
auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
return out;
}
at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
auto output = self.clone();
// See Note [*_scatter ops preserve strides]
auto output = clone_preserve_strides(self);
auto slice = output.slice(dim, start, end, step);
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
return output;
}
at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
auto output = self.clone();
auto output = clone_preserve_strides(self);
auto slice = output.select_symint(dim, index);
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
return output;
}
at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
auto output = self.clone();
// See Note [*_scatter ops preserve strides]
auto output = clone_preserve_strides(self);
auto slice = output.diagonal(offset, dim1, dim2);
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
@ -3825,7 +3861,8 @@ at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64
at::Tensor as_strided_scatter_symint(const at::Tensor& self, const at::Tensor& src, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
// See Note [as_strided_scatter backward support]
TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
auto output = self.clone();
// See Note [*_scatter ops preserve strides]
auto output = clone_preserve_strides(self);
auto slice = output.as_strided_symint(size, stride, std::move(storage_offset));
TORCH_CHECK(slice.sym_sizes() == src.sym_sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sym_sizes(), ", slice size = ", slice.sym_sizes());
slice.copy_(src);
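
A small illustration of what `clone_preserve_strides` buys the *_scatter kernels in this file (a sketch, not part of the diff): the output now carries the base's storage_offset and strides instead of a fresh contiguous layout.

```python
import torch

buf = torch.zeros(10)
base = buf[2:8]                        # size (6,), storage_offset 2
src = torch.ones(3)

out = torch.slice_scatter(base, src, 0, 0, 3)

# The whole original storage is cloned and re-viewed, so the output matches
# the base's layout exactly rather than being a fresh contiguous tensor.
assert out.size() == base.size()
assert out.stride() == base.stride()
assert out.storage_offset() == base.storage_offset() == 2
```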

View File

@ -5,6 +5,9 @@
namespace at {
namespace native {
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
inline bool cat_should_skip_tensor(const Tensor& t) {
return t.numel() == 0 && t.dim() == 1;
}

View File

@ -613,8 +613,8 @@ def forward(self, primals_1):
t_1 = torch.ops.aten.t.default(clone); clone = None
select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None
t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None
t_3 = torch.ops.aten.t.default(t_2); t_2 = None
return [t_3, 3, 3, 1, 3, 0]""")
t_4 = torch.ops.aten.t.default(t_2); t_2 = None
return [t_4, 3, 3, 1, 3, 0]""")
def test_view_and_inplace_view(self):
def f(a, b):
@ -683,11 +683,12 @@ def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
t_1 = torch.ops.aten.t.default(as_strided_6); as_strided_6 = None
mul_1 = torch.ops.aten.mul.Tensor(t_1, 3); t_1 = None
return [mul, mul_1, 4, 1, 0]""")
return [as_strided_3, mul_1, 4, 1, 0]""")
def test_input_mutation_aliases_other_input(self):
def f(a, b):
@ -712,10 +713,11 @@ def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
return [add, add_1]""")
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
return [as_strided_2, add_1]""")
def test_input_mutation_aliases_other_input2(self):
def f(a, b):
@ -736,10 +738,11 @@ def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None
as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None
return [add, add_1]""")
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
return [as_strided_2, add_1]""")
def test_input_mutation_aliases_and_output_alias(self):
def f(a, b):
@ -758,9 +761,11 @@ def forward(self, primals_1):
self.assertExpectedInline(fw_graph.code.strip(), """\
def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
return [add, 4, 1, 0]""")
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
return [as_strided_2, 4, 1, 0]""")
def test_input_aliased_with_mutation_output_alias(self):
def f(a, b, c):
@ -783,10 +788,12 @@ def forward(self, primals_1):
self.assertExpectedInline(fw_graph.code.strip(), """\
def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
return [mul, add, 4, 1, 0]""")
return [as_strided_2, add, 4, 1, 0]""")
def test_input_metadata_mutation_aliases(self):
def f(a, b):
@ -829,11 +836,12 @@ def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
add = torch.ops.aten.add.Tensor(as_strided_2, 1); as_strided_2 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
return [mul, add, add_1]""")
return [as_strided_2, add, add_1]""")
def test_input_mutation_aliases_bases_out_of_order(self):
# This tests our calling convention: if b and d are aliased, then the outer calling convention
@ -864,12 +872,13 @@ def forward(self, primals_1, primals_2, primals_3):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = None
as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
t_1 = torch.ops.aten.t.default(as_strided_4); as_strided_4 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = t_1 = None
return [add, add_2, 4, 1, 0, 4, 1, 0]""")
return [as_strided_2, add_2, 4, 1, 0, 4, 1, 0]""")
# Mondo test that tests a combination of:
# input is mutated, that aliases another input (so we make a synthetic base)
@ -913,10 +922,11 @@ def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None
as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
add = torch.ops.aten.add.Tensor(as_strided_4, mul); as_strided_4 = None
return [mul, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""")
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = None
return [as_strided_2, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""")
def test_no_grad_input_output(self):
def f(a, b):

View File

@ -3717,7 +3717,7 @@ class TestFunctionalize(TestCase):
z = y2[0]
z.add_(tmp)
return y
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True)
# See https://github.com/pytorch/functorch/issues/780
def test_linear(self, device):
@ -3840,6 +3840,7 @@ def forward(self, x_1) -> torch.Tensor:
view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = None
return view_copy_1
""")
@ -3883,6 +3884,7 @@ def forward(self, inpt_1) -> torch.Tensor:
view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4])
return view_copy_2
""")
@ -3912,6 +3914,7 @@ def forward(self, inpt_1) -> torch.Tensor:
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4])
return (view_copy_2, getitem)
""")
@ -3934,6 +3937,7 @@ def forward(self, x_1) -> torch.Tensor:
view = torch.ops.aten.view.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = None
return view_1
""")

View File

@ -460,6 +460,7 @@ class TestOperators(TestCase):
# AssertionError: Tensor-likes are not close!
xfail('as_strided'),
xfail('as_strided', 'partial_views'),
xfail('as_strided_scatter'),
decorate('linalg.det', 'singular',
decorator=expectedFailureIf(IS_MACOS and IS_X86)),
}))
@ -965,6 +966,8 @@ class TestOperators(TestCase):
xfail('tensor_split'), # data_ptr composite compliance
xfail('quantile'), # at::equal batching rule (cpu), also, in-place vmap (cuda)
skip('as_strided'), # Test runner cannot handle this
# requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
xfail('as_strided_scatter'),
xfail('nn.functional.gaussian_nll_loss'), # .item or data-dependent control flow
xfail('scatter'), # forward-mode AD does not support at::scatter
xfail('nanquantile'), # at::equal batching rule (cpu), also, in-place vmap (cuda)

View File

@ -3266,6 +3266,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('nn.functional.alpha_dropout', ''), # randomness
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
xfail('as_strided'), # Our test runner can't handle this; manual test exists
xfail('as_strided_scatter'), # no batching rule implemented, default doesn't work
skip('new_empty_strided'), # empty tensor data is garbage so it's hard to make comparisons with it
xfail('nn.functional.fractional_max_pool3d'), # randomness
xfail('nn.functional.fractional_max_pool2d'), # randomness

View File

@ -11,17 +11,14 @@ from torch.utils._pytree import tree_map, tree_map_only, tree_flatten
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
from torch.multiprocessing.reductions import StorageWeakRef
import unittest
def are_aliased(x, y):
if x._base is None and y._base is None:
return False
if x._base is not None and y._base is None:
return x._base is y
if x._base is None and y._base is not None:
return y._base is x
return x._base is y._base
x_storage = StorageWeakRef(x.storage())
y_storage = StorageWeakRef(y.storage())
return x_storage == y_storage
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
@ -196,23 +193,26 @@ def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(view_copy); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
sum_1 = torch.ops.aten.sum.default(relu)
view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None
view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
sum_1 = torch.ops.aten.sum.default(view_copy_3)
ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_3); new_empty_strided = view_copy_3 = None
view_copy_4 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
view_copy_5 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format)
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
copy_1 = torch.ops.aten.copy.default(view_copy_5, threshold_backward); view_copy_5 = threshold_backward = None
view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None
detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None
view_copy_7 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None
view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None
view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None
view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None
copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None
view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128])
view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None
detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = None
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None
return detach_copy_1
""") # noqa: B950
@ -234,10 +234,11 @@ def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2])
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
return add
return view_copy_2
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -249,10 +250,11 @@ def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(arg0_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
return add
return view_2
""")
def test_simple_out(self):
@ -351,6 +353,7 @@ def forward(self, arg0_1):
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1])
return view_copy_1
""")
@ -365,9 +368,19 @@ def forward(self, arg0_1):
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add_.Tensor(view, 1)
view_1 = torch.ops.aten.view.default(view, [3]); view = None
view_2 = torch.ops.aten.view.default(view_1, [-1])
return view_1
""")
def test_advanced_indexing_correct_strides(self):
def f(a):
# This test requires that *_scatter ops are able to return
# non-contiguous tensors.
b = a.clone()[:, 1]
c = torch.ones_like(b, dtype=torch.bool)
d = b.masked_fill_(c, 0)
return d
self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
def test_tensor_list_mixed_functional_nonfunctional(self):
nonfunctional_tensor = torch.ones(2, dtype=torch.long)
@ -456,6 +469,7 @@ def forward(self, arg0_1):
as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1)
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
return as_strided_scatter
""")
@ -522,8 +536,10 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone = torch.ops.aten.clone.default(arg0_1)
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone); clone = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = None
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return mul
""")
@ -536,8 +552,9 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone = torch.ops.aten.clone.default(arg0_1)
diagonal = torch.ops.aten.diagonal.default(clone); clone = None
diagonal = torch.ops.aten.diagonal.default(clone)
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = None
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return mul
""")
@ -561,6 +578,7 @@ def forward(self, arg0_1):
diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
return diagonal_scatter
""")
@ -590,11 +608,15 @@ def forward(self, arg0_1):
split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
getitem_4 = split_copy_2[0]
getitem_5 = split_copy_2[1]; split_copy_2 = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
return add
return diagonal_copy_1
""") # noqa: B950
def test_view_inplace(self):
@ -619,8 +641,10 @@ def forward(self, arg0_1):
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_3
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_4
""") # noqa: B950
def test_optional_tensor_list(self):
@ -643,9 +667,10 @@ def forward(self, arg0_1):
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2])
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
return index_put
return view_copy_2
""") # noqa: B950
def test_scalars(self):
@ -667,9 +692,10 @@ def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
mul = torch.ops.aten.mul.Tensor(add, 2)
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
return div
""")
@ -773,33 +799,41 @@ def forward(self, arg0_1):
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4])
view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None
view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0); view_copy_3 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = add_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8])
view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None
select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None
view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None
view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]); view_copy_4 = None
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]); view_copy_6 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0); view_copy_7 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None
return add_1
select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4])
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]); view_copy_5 = None
view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]); view_copy_11 = None
transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0); view_copy_12 = None
unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0); transpose_copy_4 = None
squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); unsqueeze_copy_4 = None
split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None
getitem_4 = split_copy_2[0]
getitem_5 = split_copy_2[1]; split_copy_2 = None
view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = None
return getitem_2
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -818,10 +852,7 @@ def forward(self, arg0_1):
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
getitem = split[0]
getitem_1 = split[1]; split = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = None
view_2 = torch.ops.aten.view.default(add, [8]); add = None
view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None
transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None
@ -832,11 +863,22 @@ def forward(self, arg0_1):
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None
view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
view_6 = torch.ops.aten.view.default(view_5, [8])
view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None
select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None
transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0); transpose_3 = None
squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None
split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None
getitem_2 = split_1[0]
getitem_3 = split_1[1]; split_1 = None
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None
select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
return getitem
return getitem_2
""")
def test_reapply_views_simple(self):
@ -856,10 +898,11 @@ def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(arg0_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
return add
return view_2
""")
def test_aliases_maintained_after_pass_when_reapplying_views(self):
@ -904,10 +947,14 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None
return add
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
return diagonal_copy_2
""")
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
@ -917,10 +964,12 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
copy = torch.ops.aten.copy_.default(diagonal, arg0_1)
add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None
return diagonal
diagonal = torch.ops.aten.diagonal.default(zeros)
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
return diagonal_2
""")
# Test 2: copy_() with same dtype, different shape
@ -932,10 +981,14 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None
return add
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
return diagonal_copy_2
""")
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
@ -945,10 +998,12 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
copy = torch.ops.aten.copy_.default(diagonal, arg0_1)
add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None
return diagonal
diagonal = torch.ops.aten.diagonal.default(zeros)
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
return diagonal_2
""")
# Test 3: copy_() with different dtype, same shape
@ -960,10 +1015,14 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None
return add
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
return diagonal_copy_2
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
@ -973,10 +1032,12 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
copy = torch.ops.aten.copy_.default(diagonal, arg0_1)
add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None
return diagonal
diagonal = torch.ops.aten.diagonal.default(zeros)
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
return diagonal_2
""") # noqa: B950
# Test 4: copy_() with different dtype, different shape
@ -988,10 +1049,14 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None
return add
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
return diagonal_copy_2
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
@ -1001,10 +1066,12 @@ def forward(self, arg0_1):
def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
copy = torch.ops.aten.copy_.default(diagonal, arg0_1)
add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None
return diagonal
diagonal = torch.ops.aten.diagonal.default(zeros)
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
return diagonal_2
""") # noqa: B950
def test_expand_symint(self):
@ -1042,6 +1109,7 @@ def forward(self, arg0_1):
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
return diagonal_scatter
""")
@ -1054,6 +1122,7 @@ def forward(self, arg0_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
diagonal = torch.ops.aten.diagonal.default(add)
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
diagonal_1 = torch.ops.aten.diagonal.default(add)
return add
""")
@ -1086,9 +1155,12 @@ def forward(self, arg0_1):
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_2, 1); as_strided_copy_2 = None
view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = None
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None
return add_2
""") # noqa: B950
@ -1108,10 +1180,13 @@ def forward(self, arg0_1):
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 4])
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
add_2 = torch.ops.aten.add_.Tensor(as_strided_2, 1)
return as_strided_2
view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = None
view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None
add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1)
return as_strided_3
""")
def test_resize_larger_valid(self):
@ -1143,6 +1218,7 @@ def forward(self, arg0_1):
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25])
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
return (view_copy_1, add_1)
""")
@ -1158,6 +1234,7 @@ def forward(self, arg0_1):
view = torch.ops.aten.view.default(add, [25]); add = None
fill = torch.ops.aten.fill_.Scalar(view, 1)
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
view_2 = torch.ops.aten.view.default(view_1, [25])
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
return (view_1, add_1)
""")
@ -1241,6 +1318,7 @@ def forward(self, arg0_1):
select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5)
return select_scatter
""") # noqa: B950
@ -1253,6 +1331,7 @@ def forward(self, arg0_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select = torch.ops.aten.select.int(zeros, 0, 5)
fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
select_1 = torch.ops.aten.select.int(zeros, 0, 5)
return zeros
""")
@ -1290,16 +1369,18 @@ def forward(self, arg0_1, arg1_1, arg2_1):
view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
alias_copy_1 = torch.ops.aten.alias_copy.default(arg2_1)
alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None
alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1)
alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
copy_1 = torch.ops.aten.copy.default(alias_copy_1, mean_1); alias_copy_1 = mean_1 = None
copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None
alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None
alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4)
view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
alias_copy_2 = torch.ops.aten.alias_copy.default(copy); copy = None
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_2); arg1_1 = alias_copy_2 = None
alias_copy_3 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_3); arg2_1 = alias_copy_3 = None
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None
return view_copy_5
""") # noqa: B950
@ -1327,16 +1408,18 @@ def forward(self, arg0_1, arg1_1, arg2_1):
view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None
alias_1 = torch.ops.aten.alias.default(arg2_1)
alias_1 = torch.ops.aten.alias.default(copy); copy = None
alias_2 = torch.ops.aten.alias.default(alias_1)
alias_3 = torch.ops.aten.alias.default(arg2_1)
view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
copy_1 = torch.ops.aten.copy.default(alias_1, mean_1); alias_1 = mean_1 = None
copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None
alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None
alias_5 = torch.ops.aten.alias.default(alias_4)
view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
alias_2 = torch.ops.aten.alias.default(copy); copy = None
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_2); arg1_1 = alias_2 = None
alias_3 = torch.ops.aten.alias.default(copy_1); copy_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_3); arg2_1 = alias_3 = None
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None
return view_5
""") # noqa: B950

View File

@ -154,7 +154,10 @@ def forward(self, a__1):
view_4 = torch.ops.aten.view.default(view_2, []); view_2 = None
view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
view_6 = torch.ops.aten.view.default(view_5, [-1])
add_1 = torch.ops.aten.add_.Tensor(view_5, view_6); view_6 = None
select_2 = torch.ops.aten.select.int(view_6, 0, 0); view_6 = None
view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = None
view_8 = torch.ops.aten.view.default(view_5, [-1])
add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = None
return view_5
""")
@ -187,6 +190,9 @@ def forward(self, a__1):
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = None
slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = None
return clone
""")
@ -347,6 +353,8 @@ def forward(self):
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None
slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = None
return zeros
""")

View File

@ -1344,6 +1344,8 @@ def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
# This can happen if we functionalize a fn that returns a global,
# which was never wrapped properly.
return maybe_tensor
# Sync any pending updates on the output tensor
torch._sync(maybe_tensor)
return _unwrap_functional_tensor(maybe_tensor, reapply_views)
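The sync added above matters because, under functionalization, mutations are queued as pending updates on the wrapper's storage rather than applied eagerly. A minimal sketch of that behavior, using the private helpers the functionalization tests rely on (torch._to_functional_tensor, torch._sync, torch._from_functional_tensor); the printed value is illustrative:

import torch

x_func = torch._to_functional_tensor(torch.zeros(4))
torch._enable_functionalization(reapply_views=True)
try:
    # The inplace add on the view is recorded as a pending update on x_func's storage.
    x_func.view(-1).add_(1)
finally:
    torch._disable_functionalization()

torch._sync(x_func)  # materialize the queued update before unwrapping
print(torch._from_functional_tensor(x_func))  # expected: tensor([1., 1., 1., 1.])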

View File

@ -1654,12 +1654,12 @@ def meta_select(self, dim, index):
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
return torch.empty_like(self)
return utils.clone_preserve_strides(self)
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
return torch.empty_like(self)
return utils.clone_preserve_strides(self)
# TODO: Deduplicate this with canonicalize_dim
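With the meta kernels above returning a stride-preserving clone, a non-contiguous input keeps its layout instead of being flattened to contiguous strides. A quick sketch of the intended behavior (shapes chosen arbitrarily):

import torch

base = torch.empty(4, 6, device="meta").t()   # shape (6, 4), strides (1, 6)
src = torch.empty(4, device="meta")           # matches base.select(0, 0)
out = torch.select_scatter(base, src, 0, 0)
assert out.stride() == base.stride()          # (1, 6) rather than the contiguous (4, 1)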

View File

@ -1828,7 +1828,7 @@ def _as_strided_scatter_meta(
lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
)
return _clone_meta(input)
return utils.clone_preserve_strides(input)
_as_strided_scatter_doc = """

View File

@ -1658,3 +1658,22 @@ def device_or_default(device: Optional[torch.device]) -> torch.device:
def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
return layout if layout is not None else torch.strided
def clone_preserve_strides(x):
needed_size = compute_required_storage_length(
x.size(), x.stride(), x.storage_offset()
)
# Our eager implementations for *_scatter ops are all primitives w.r.t. autograd,
# so these as_strided() calls are not seen by autograd.
# We need to mimic this behavior in our ref/prim implementations.
# TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
# We should revisit this when we add a compositional as_strided op,
# and also as part of https://github.com/pytorch/pytorch/issues/90507
try:
old = torch._C._dispatch_tls_is_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView)
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, True)
buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
finally:
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, old)
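A small usage sketch of the helper above; the import path is assumed from the surrounding hunk, which sits next to device_or_default/layout_or_default in torch._prims_common:

import torch
from torch._prims_common import clone_preserve_strides  # assumed location

x = torch.arange(12.).reshape(3, 4)
v = x[1:, 1:]                      # strides (4, 1), storage_offset 5
c = clone_preserve_strides(v)
assert c.stride() == v.stride()
assert c.storage_offset() == v.storage_offset()
assert torch.equal(c, v)           # same values, freshly allocated storage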

View File

@ -275,6 +275,7 @@ def backwards_not_supported(prim):
def redispatch_prim(args, kwargs):
g = torch._C._AutoDispatchBelowAutograd()
try:
old = torch._C._dispatch_tls_is_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView)
return prim(*args, **kwargs)
finally:
del g

View File

@ -3699,7 +3699,7 @@ def diagonal_scatter(
dim1: int = 0,
dim2: int = 1,
) -> TensorLikeType:
out = input.clone()
out = utils.clone_preserve_strides(input)
diag = out.diagonal(offset, dim1, dim2)
check(
diag.shape == src.shape,

View File

@ -10824,6 +10824,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950
DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'),
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
# AssertionError: Tensor-likes are not close! (new_empty_strided.default)
DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)),
@ -18436,6 +18437,8 @@ python_ref_db = [
"_refs.as_strided_scatter",
torch_opinfo_name="as_strided_scatter",
supports_nvfuser=False,
# returns a view of an intermediate tensor (as_strided)
validate_view_consistency=False,
),
PythonRefInfo(
"_refs.broadcast_shapes",
@ -18509,6 +18512,8 @@ python_ref_db = [
torch_opinfo_name="diagonal_scatter",
supports_out=True,
supports_nvfuser=False,
# returns a view of an intermediate tensor (as_strided)
validate_view_consistency=False,
),
PythonRefInfo(
"_refs.diag_embed",

View File

@ -351,9 +351,9 @@ def emit_view_functionalization_body(
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
if (compute_reference_meta) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
// This function adds the above view meta to the current tensor and replays them off the base,
@ -387,9 +387,9 @@ def emit_view_functionalization_body(
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
if (compute_reference_meta) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
{return_type} tmp_output;
@ -500,7 +500,8 @@ def wrap_propagate_mutations_and_return(
updates.append(
f"""\
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});"""
at::functionalization::impl::commit_update({outer_arg});
at::functionalization::impl::sync({outer_arg});"""
)
# Finally, we return:
@ -611,9 +612,9 @@ def emit_inplace_functionalization_body(
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technically all support meta tensors).
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
}}
{unwrap_tensor_args_str}
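The kind of error this meta pre-dispatch is meant to surface is one that only the inplace form raises; a hedged sketch, assuming functorch.functionalize is available:

import torch
from functorch import functionalize

def f(x):
    # The inplace add cannot broadcast a shape-[4] result into a shape-[1] tensor;
    # the functional aten.add it gets rewritten to would silently produce shape [4].
    x.add_(torch.ones(4))
    return x

try:
    functionalize(f)(torch.ones(1))
except RuntimeError as e:
    print("caught:", e)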