mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTI] Fix a None as index codegen issue (#118187)
Summary: Fix a ABI-compatible codegen issue when index_put has None in its indices. Differential Revision: [D53047489](https://our.internmc.facebook.com/intern/diff/D53047489) Pull Request resolved: https://github.com/pytorch/pytorch/pull/118187 Approved by: https://github.com/chenyang78 ghstack dependencies: #118168, #118169
This commit is contained in:
parent
d1e661a1ce
commit
ee1dbb2acf
|
|
@ -20,6 +20,7 @@ from torch.testing._internal import common_utils
|
|||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
from torch.testing._internal.common_quantization import skip_if_no_torchvision
|
||||
from torch.testing._internal.common_utils import (
|
||||
DeterministicGuard,
|
||||
IS_CI,
|
||||
IS_FBCODE,
|
||||
IS_WINDOWS,
|
||||
|
|
@ -1324,6 +1325,9 @@ class AOTInductorTestsTemplate:
|
|||
self.check_model(Model(), inputs)
|
||||
|
||||
def test_index_put_fallback(self):
|
||||
# index_put falls back in the deterministic mode
|
||||
with DeterministicGuard(True):
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -1334,7 +1338,9 @@ class AOTInductorTestsTemplate:
|
|||
indices: Tuple[torch.Tensor],
|
||||
values: torch.Tensor,
|
||||
):
|
||||
return torch.index_put(self_tensor, indices, values, accumulate=True)
|
||||
return torch.index_put(
|
||||
self_tensor, indices, values, accumulate=True
|
||||
)
|
||||
|
||||
inputs = (
|
||||
torch.ones(4, device=self.device, dtype=torch.int64),
|
||||
|
|
@ -1597,6 +1603,24 @@ class AOTInductorTestsTemplate:
|
|||
)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_index_put_with_none_index(self):
|
||||
# index_put falls back in the deterministic mode
|
||||
with DeterministicGuard(True):
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, i1, i2, y):
|
||||
return torch.ops.aten.index_put(
|
||||
x, (None, None, i1, i2), y, accumulate=True
|
||||
)
|
||||
|
||||
example_inputs = (
|
||||
torch.rand(8, 192, 30, 30, device=self.device),
|
||||
torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
|
||||
torch.ones(3, 14, dtype=torch.int64, device=self.device),
|
||||
torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
|
||||
)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
|
||||
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
|
||||
|
||||
|
|
@ -1677,6 +1701,7 @@ CPU_TEST_FAILURES = {
|
|||
"test_scatter_fallback": fail_stack_allocation(is_skip=True),
|
||||
"test_scatter_reduce_fallback": fail_stack_allocation(is_skip=True),
|
||||
"test_index_put_fallback": fail_stack_allocation(is_skip=True),
|
||||
"test_index_put_with_none_index": fail_stack_allocation(is_skip=True),
|
||||
# C++ compile error, need for aoti_torch___scaled_dot_product_flash_attention_for_cpu
|
||||
"test_sdpa": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
"test_sdpa_2": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
|
|
|
|||
|
|
@ -1352,7 +1352,9 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
self.closed_bracket = "}"
|
||||
self.comment = "//"
|
||||
self.namespace = "at::"
|
||||
self.none_str = "at::Tensor()"
|
||||
self.none_str = (
|
||||
"nullptr" if config.aot_inductor.abi_compatible else "at::Tensor()"
|
||||
)
|
||||
self.extern_call_ops = set()
|
||||
self.size = "sizes()"
|
||||
self.stride = "strides()"
|
||||
|
|
|
|||
|
|
@ -698,10 +698,11 @@ AOTITorchError aoti_torch_index_put_out(
|
|||
const AtenTensorHandle values,
|
||||
bool accumulate) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
c10::List<std::optional<at::Tensor>> indices_;
|
||||
c10::List<c10::optional<at::Tensor>> indices_;
|
||||
indices_.reserve(num_indices);
|
||||
for (size_t i = 0; i < num_indices; i++) {
|
||||
indices_.emplace_back(*tensor_handle_to_tensor_pointer(indices[i]));
|
||||
indices_.emplace_back(
|
||||
pointer_to_optional(tensor_handle_to_tensor_pointer(indices[i])));
|
||||
}
|
||||
at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
|
||||
at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user