[AOTI] Fix a None as index codegen issue (#118187)

Summary: Fix an ABI-compatible codegen issue when index_put has None in its indices.

Differential Revision: [D53047489](https://our.internmc.facebook.com/intern/diff/D53047489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118187
Approved by: https://github.com/chenyang78
ghstack dependencies: #118168, #118169
This commit is contained in:
Bin Bao 2024-01-24 10:26:56 -08:00 committed by PyTorch MergeBot
parent d1e661a1ce
commit ee1dbb2acf
3 changed files with 47 additions and 19 deletions

View File

@ -20,6 +20,7 @@ from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_quantization import skip_if_no_torchvision
from torch.testing._internal.common_utils import (
DeterministicGuard,
IS_CI,
IS_FBCODE,
IS_WINDOWS,
@ -1324,6 +1325,9 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), inputs)
def test_index_put_fallback(self):
# index_put falls back in the deterministic mode
with DeterministicGuard(True):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@ -1334,7 +1338,9 @@ class AOTInductorTestsTemplate:
indices: Tuple[torch.Tensor],
values: torch.Tensor,
):
return torch.index_put(self_tensor, indices, values, accumulate=True)
return torch.index_put(
self_tensor, indices, values, accumulate=True
)
inputs = (
torch.ones(4, device=self.device, dtype=torch.int64),
@ -1597,6 +1603,24 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(), example_inputs)
def test_index_put_with_none_index(self):
# Exercises aten.index_put when the indices tuple contains None entries:
# the first two positions are None and only i1 / i2 are index tensors.
# index_put falls back in the deterministic mode
with DeterministicGuard(True):
class Model(torch.nn.Module):
def forward(self, x, i1, i2, y):
# accumulate=True adds the values in y into x at the indexed
# positions instead of overwriting them.
return torch.ops.aten.index_put(
x, (None, None, i1, i2), y, accumulate=True
)
# Inputs are listed in forward() order: x, i1, i2, y.
# NOTE(review): i1 (3, 14, 1, 1) and i2 (3, 14) presumably broadcast to
# (3, 14, 3, 14), matching the trailing dims of y - confirm.
example_inputs = (
torch.rand(8, 192, 30, 30, device=self.device),
torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
torch.ones(3, 14, dtype=torch.int64, device=self.device),
torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
)
# check_model is defined outside this view; presumably it compiles Model
# and compares the result against eager execution - confirm.
self.check_model(Model(), example_inputs)
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
@ -1677,6 +1701,7 @@ CPU_TEST_FAILURES = {
"test_scatter_fallback": fail_stack_allocation(is_skip=True),
"test_scatter_reduce_fallback": fail_stack_allocation(is_skip=True),
"test_index_put_fallback": fail_stack_allocation(is_skip=True),
"test_index_put_with_none_index": fail_stack_allocation(is_skip=True),
# C++ compile error, need for aoti_torch___scaled_dot_product_flash_attention_for_cpu
"test_sdpa": fail_with_and_without_stack_allocation(is_skip=True),
"test_sdpa_2": fail_with_and_without_stack_allocation(is_skip=True),

View File

@ -1352,7 +1352,9 @@ class CppWrapperCodeGen(WrapperCodeGen):
self.closed_bracket = "}"
self.comment = "//"
self.namespace = "at::"
self.none_str = "at::Tensor()"
self.none_str = (
"nullptr" if config.aot_inductor.abi_compatible else "at::Tensor()"
)
self.extern_call_ops = set()
self.size = "sizes()"
self.stride = "strides()"

View File

@ -698,10 +698,11 @@ AOTITorchError aoti_torch_index_put_out(
const AtenTensorHandle values,
bool accumulate) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
c10::List<std::optional<at::Tensor>> indices_;
c10::List<c10::optional<at::Tensor>> indices_;
indices_.reserve(num_indices);
for (size_t i = 0; i < num_indices; i++) {
indices_.emplace_back(*tensor_handle_to_tensor_pointer(indices[i]));
indices_.emplace_back(
pointer_to_optional(tensor_handle_to_tensor_pointer(indices[i])));
}
at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);