[inductor] modify index_reduce to pass opinfo tests (#106429)
1. Add a Python meta registration to fix an issue with the forward pass. Previously, the C++ meta registration called [numel()](7b14a14e27/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L329)), which fails in the meta/fake tensor case (let me know if it's better to fix the C++ implementation to not do this check). A small repro sketch follows the commit message below.
2. Modify the backward to fix an issue there. The backward is not a custom op; it is a custom manual backward implementation. In particular, some situations don't support double backward, and the check for whether double backward is allowed requires a .item() call, which breaks under meta/fake tensors. To fix this, the PR performs that check and registers the double backward error only when `GradMode::is_enabled()`, which shouldn't be the case when running under PT2 (a sketch of the user-facing behavior follows the C++ backward hunk below).
3. Update skips.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106429
Approved by: https://github.com/zou3519
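
As a quick illustration of point 1 (a minimal sketch, not part of the PR; the shapes and variable names here are made up), the meta registration lets `index_reduce` run on meta tensors without real data:

```python
import torch

# Hypothetical repro: all shapes below are illustrative.
x = torch.empty(5, 3, device="meta")
index = torch.empty(4, dtype=torch.long, device="meta")
source = torch.empty(4, 3, device="meta")

# Previously the C++ meta path failed in a numel()-based check; with the
# Python meta registration this returns a meta tensor shaped like `x`.
out = torch.index_reduce(x, 0, index, source, "prod", include_self=True)
print(out.shape, out.device)  # torch.Size([5, 3]) meta
```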
This commit is contained in:
parent a14d99bb6c
commit 393e9eed90
```diff
@@ -2754,7 +2754,6 @@ aot_autograd_failures = {
     xfail('quantile'),
     xfail('nanquantile'),
     xfail('narrow'),
-    xfail('index_reduce'),
     xfail('istft'),
     xfail('linalg.eig'),
     xfail('scatter_reduce', 'prod'),
```
```diff
@@ -208,7 +208,6 @@ inductor_expected_failures_single_sample["cpu"] = {
     "complex": {f16},
     "exponential": {f16},
     "geometric": {f16},
-    "index_reduce": {f16, f32, f64},
     "linalg.eigh": {f32, f64},
     "linalg.eigvalsh": {f32, f64},
     "log_normal": {f16},
```
```diff
@@ -276,7 +275,6 @@ inductor_expected_failures_single_sample["cuda"] = {
     "fft.rfft2": {f16},
     "fft.rfftn": {f16},
     "geometric": {f16, f32, f64, i32, i64},
-    "index_reduce": {f16, f32, f64},
     "kron": {f16},
     "linalg.eig": {f32, f64},
     "linalg.eigh": {f32, f64},
```
```diff
@@ -1520,7 +1520,6 @@ symbolic_tensor_failures = {
     xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
     xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
     xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
-    xfail('index_reduce', ''), # Float
     xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
```
```diff
@@ -309,6 +309,32 @@ def meta_unsqueeze_(self, dim):
     return self
 
 
+@register_meta(aten.index_reduce.default)
+def meta_index_reduce(
+    self: Tensor,
+    dim: int,
+    index: Tensor,
+    source: torch.Tensor,
+    reduce: str,
+    *,
+    include_self: bool = True,
+) -> Tensor:
+    return torch.empty_like(self, memory_format=torch.contiguous_format)
+
+
+@register_meta(aten.index_reduce_.default)
+def meta_index_reduce_(
+    self: Tensor,
+    dim: int,
+    index: Tensor,
+    source: torch.Tensor,
+    reduce: str,
+    *,
+    include_self: bool = True,
+) -> Tensor:
+    return self
+
+
 # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
 @register_meta(aten.index_select.default)
 def meta_index_select(self, dim, index):
```
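For context, a sketch (assuming the registrations above are installed; the function and inputs are invented for illustration) of how this unblocks fake-tensor tracing, e.g. via `make_fx`:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Invented example function; any index_reduce call would do.
def f(x, index, source):
    return x.index_reduce(0, index, source, "amax", include_self=False)

gm = make_fx(f, tracing_mode="fake")(
    torch.randn(5, 3),
    torch.tensor([0, 2, 4, 1]),
    torch.randn(4, 3),
)
print(gm.graph)  # the traced graph now contains index_reduce
```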
```diff
@@ -6823,7 +6823,9 @@ std::tuple<Tensor, Tensor> index_reduce_backward(
       (grad * masked_src_result).index_select(dim, index),
       (grad * result).index_select(dim, index) /
           source.masked_fill(src_zero, 1));
-  if ((src_num_zeros > 1).any().item<bool>()) {
+  // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
+  // is disabled this also avoids having the item() call in the usual case
+  if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
     auto node = std::make_shared<DelayedError>(
         "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
         /* num inputs */ 1);
```
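To make the guarded condition concrete, a hypothetical repro (not from the PR; the values are chosen so that more than one zero in `source` scatters to the same position):

```python
import torch

# Two zeros in `source` both scatter into slot 1 of `x`.
x = torch.ones(3, requires_grad=True)
source = torch.tensor([0.0, 0.0], requires_grad=True)
index = torch.tensor([1, 1])

out = x.index_reduce(0, index, source, "prod")
# create_graph=True keeps GradMode enabled during the backward, so the
# (src_num_zeros > 1) check above still runs and attaches the DelayedError.
(g_source,) = torch.autograd.grad(out.sum(), source, create_graph=True)
# Differentiating through g_source would raise the "Double backward is
# unsupported for source ..." error; a plain first-order backward (GradMode
# disabled) now skips the .item() sync entirely.
# torch.autograd.grad(g_source.sum(), x)
```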
```diff
@@ -15621,10 +15621,6 @@ op_db: List[OpInfo] = [
     OpInfo('index_reduce',
            dtypes=all_types_and(torch.float16, torch.bfloat16),
            supports_out=True,
-           skips=(
-               # Pre-existing condition (calls .item); needs to be fixed
-               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
-           ),
            sample_inputs_func=sample_inputs_index_reduce),
     OpInfo('__getitem__',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
```