[inductor] modify index_reduce to pass opinfo tests (#106429)

1. Add a Python meta registration to fix an issue with the forward pass. Previously, the C++ meta registration called [numel()](7b14a14e27/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L329)), which fails (LMK if it's better to fix the C++ implementation to not do this check). A repro sketch follows the list.
2. Modify the backward to fix an issue there. The backward is not a custom op; it's a manual backward implementation. Some cases don't support double backward, and the check for whether double backward is allowed requires a .item() call. To fix the meta/fake tensor case, this PR only sets up the double-backward error when `GradMode::is_enabled()` is true, which shouldn't be the case in PT2 (a sketch of the grad-mode semantics follows the C++ hunk below).
3. Update skips.
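
For (1), a minimal repro sketch of the previously failing case, assuming the symbolic-shapes path exercised by test_proxy_tensor is representative (`f` and its inputs are made up for illustration):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, idx, src):
    return x.index_reduce(0, idx, src, "prod")

x = torch.ones(4, 3)
idx = torch.tensor([0, 2])
src = torch.ones(2, 3)

# tracing_mode="symbolic" gives the inputs symbolic sizes; numel() raises on
# such tensors, so the C++ meta kernel's check used to fail here. With the
# Python meta registration added below, tracing goes through.
gm = make_fx(f, tracing_mode="symbolic")(x, idx, src)
print(gm.graph)
```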

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106429
Approved by: https://github.com/zou3519
David Berard 2023-08-10 18:14:00 +00:00 committed by PyTorch MergeBot
parent a14d99bb6c
commit 393e9eed90
6 changed files with 29 additions and 9 deletions

View File

@@ -2754,7 +2754,6 @@ aot_autograd_failures = {
     xfail('quantile'),
     xfail('nanquantile'),
     xfail('narrow'),
-    xfail('index_reduce'),
     xfail('istft'),
     xfail('linalg.eig'),
     xfail('scatter_reduce', 'prod'),

View File

@@ -208,7 +208,6 @@ inductor_expected_failures_single_sample["cpu"] = {
     "complex": {f16},
     "exponential": {f16},
     "geometric": {f16},
-    "index_reduce": {f16, f32, f64},
     "linalg.eigh": {f32, f64},
     "linalg.eigvalsh": {f32, f64},
     "log_normal": {f16},
@@ -276,7 +275,6 @@ inductor_expected_failures_single_sample["cuda"] = {
     "fft.rfft2": {f16},
     "fft.rfftn": {f16},
     "geometric": {f16, f32, f64, i32, i64},
-    "index_reduce": {f16, f32, f64},
     "kron": {f16},
     "linalg.eig": {f32, f64},
     "linalg.eigh": {f32, f64},

View File

@@ -1520,7 +1520,6 @@ symbolic_tensor_failures = {
     xfail('histc', ''),  # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
     xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
     xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
-    xfail('index_reduce', ''),  # Float
     xfail('isin', ''),  # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('kron', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition

View File

@@ -309,6 +309,32 @@ def meta_unsqueeze_(self, dim):
     return self


+@register_meta(aten.index_reduce.default)
+def meta_index_reduce(
+    self: Tensor,
+    dim: int,
+    index: Tensor,
+    source: torch.Tensor,
+    reduce: str,
+    *,
+    include_self: bool = True,
+) -> Tensor:
+    return torch.empty_like(self, memory_format=torch.contiguous_format)
+
+
+@register_meta(aten.index_reduce_.default)
+def meta_index_reduce_(
+    self: Tensor,
+    dim: int,
+    index: Tensor,
+    source: torch.Tensor,
+    reduce: str,
+    *,
+    include_self: bool = True,
+) -> Tensor:
+    return self
+
+
 # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
 @register_meta(aten.index_select.default)
 def meta_index_select(self, dim, index):
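
The same pattern applies when giving a custom op meta support through the public torch.library API; a hedged sketch (the `demo::my_index_reduce` op and its schema are hypothetical, not part of this PR):

```python
import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
lib.define(
    "my_index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce) -> Tensor"
)

def my_index_reduce_meta(self, dim, index, source, reduce):
    # Like meta_index_reduce above: compute only output metadata, with no
    # data access and no data-dependent checks (e.g., no numel()-based guard).
    return torch.empty_like(self, memory_format=torch.contiguous_format)

lib.impl("my_index_reduce", my_index_reduce_meta, "Meta")

x = torch.empty(8, 4, device="meta")
idx = torch.empty(3, dtype=torch.long, device="meta")
src = torch.empty(3, 4, device="meta")
out = torch.ops.demo.my_index_reduce(x, 0, idx, src, "prod")
assert out.shape == x.shape and out.device.type == "meta"
```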

View File

@@ -6823,7 +6823,9 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
         (grad * masked_src_result).index_select(dim, index),
         (grad * result).index_select(dim, index) /
             source.masked_fill(src_zero, 1));
-    if ((src_num_zeros > 1).any().item<bool>()) {
+    // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
+    // is disabled; this also avoids having the item() call in the usual case.
+    if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
       auto node = std::make_shared<DelayedError>(
           "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
           /* num inputs */ 1);
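
For context on the `GradMode::is_enabled()` gate: the autograd engine runs a backward pass with grad mode set to the value of `create_graph`, so the data-dependent `.item()` check now only runs when a double-backward graph could actually be recorded. A sketch of the two paths (shapes and values chosen arbitrarily):

```python
import torch

a = torch.randn(4, 3, requires_grad=True)
idx = torch.tensor([0, 2, 1, 0])
src = torch.randn(4, 3, requires_grad=True)
out = a.index_reduce(0, idx, src, "prod")

# create_graph=False: the engine disables grad mode while it runs, so the
# gated check (and its .item() call) is skipped - the path PT2 exercises.
g_src, = torch.autograd.grad(out.sum(), src, retain_graph=True)

# create_graph=True: grad mode stays on, the check runs, and a DelayedError
# node may be attached to guard unsupported double-backward cases.
g_src2, = torch.autograd.grad(out.sum(), src, create_graph=True)
```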

View File

@@ -15621,10 +15621,6 @@ op_db: List[OpInfo] = [
     OpInfo('index_reduce',
            dtypes=all_types_and(torch.float16, torch.bfloat16),
            supports_out=True,
-           skips=(
-               # Pre-existing condition (calls .item); needs to be fixed
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
-           ),
            sample_inputs_func=sample_inputs_index_reduce),
     OpInfo('__getitem__',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),