[AOTI] Fix a None as index codegen issue (#118187)
Summary: Fix an ABI-compatible codegen issue when index_put has None in its indices.

Differential Revision: [D53047489](https://our.internmc.facebook.com/intern/diff/D53047489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118187
Approved by: https://github.com/chenyang78
ghstack dependencies: #118168, #118169
This commit is contained in:
parent d1e661a1ce
commit ee1dbb2acf
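
The pattern this fix targets can be illustrated with a small eager-mode module (my own minimal sketch, not code taken from the PR): an aten.index_put call whose indices tuple contains None entries. In ABI-compatible mode, the wrapper previously printed those None entries as at::Tensor(), which does not fit the AtenTensorHandle-based C-shim fallback; the changes below make the generated code pass null handles instead and teach the shim to treat them as absent indices.

    import torch

    # Minimal illustration (not from the PR): index_put with None entries in the
    # indices tuple. Before this fix, the ABI-compatible wrapper emitted
    # `at::Tensor()` for those None entries when calling the C-shim fallback.
    class M(torch.nn.Module):
        def forward(self, x, i1, i2, y):
            return torch.ops.aten.index_put(
                x, (None, None, i1, i2), y, accumulate=True
            )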
@@ -20,6 +20,7 @@ from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_quantization import skip_if_no_torchvision
 from torch.testing._internal.common_utils import (
+    DeterministicGuard,
     IS_CI,
     IS_FBCODE,
     IS_WINDOWS,
@@ -1324,25 +1325,30 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), inputs)
 
     def test_index_put_fallback(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(
-                self,
-                self_tensor: torch.Tensor,
-                indices: Tuple[torch.Tensor],
-                values: torch.Tensor,
-            ):
-                return torch.index_put(self_tensor, indices, values, accumulate=True)
-
-        inputs = (
-            torch.ones(4, device=self.device, dtype=torch.int64),
-            (torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
-            torch.ones(4, device=self.device, dtype=torch.int64),
-        )
-
-        self.check_model(Model(), inputs)
+        # index_put falls back in the deterministic mode
+        with DeterministicGuard(True):
+
+            class Model(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+
+                def forward(
+                    self,
+                    self_tensor: torch.Tensor,
+                    indices: Tuple[torch.Tensor],
+                    values: torch.Tensor,
+                ):
+                    return torch.index_put(
+                        self_tensor, indices, values, accumulate=True
+                    )
+
+            inputs = (
+                torch.ones(4, device=self.device, dtype=torch.int64),
+                (torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
+                torch.ones(4, device=self.device, dtype=torch.int64),
+            )
+
+            self.check_model(Model(), inputs)
 
     def test_convolution(self):
         class Model(torch.nn.Module):
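
As the new comment in the test says, index_put falls back to the ATen op in deterministic mode, which is why the whole test body now runs under DeterministicGuard(True). DeterministicGuard is a test-internal helper; a rough public-API equivalent, sketched below as an assumption rather than the helper's actual implementation, is to toggle torch.use_deterministic_algorithms around the code under test.

    import torch

    # Rough public-API equivalent of DeterministicGuard(True) (my sketch, not the
    # helper's implementation): save, set, and restore the deterministic flag.
    prev = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        # Inside an Inductor-compiled region this setting makes index_put take the
        # fallback path (per the test comment above); here it just runs in eager mode.
        out = torch.index_put(
            torch.ones(4, dtype=torch.int64),
            (torch.tensor([1, 1, 2, 2]),),
            torch.ones(4, dtype=torch.int64),
            accumulate=True,
        )
    finally:
        torch.use_deterministic_algorithms(prev)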
@@ -1597,6 +1603,24 @@ class AOTInductorTestsTemplate:
         )
         self.check_model(Model(), example_inputs)
 
+    def test_index_put_with_none_index(self):
+        # index_put falls back in the deterministic mode
+        with DeterministicGuard(True):
+
+            class Model(torch.nn.Module):
+                def forward(self, x, i1, i2, y):
+                    return torch.ops.aten.index_put(
+                        x, (None, None, i1, i2), y, accumulate=True
+                    )
+
+            example_inputs = (
+                torch.rand(8, 192, 30, 30, device=self.device),
+                torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
+                torch.ones(3, 14, dtype=torch.int64, device=self.device),
+                torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
+            )
+            self.check_model(Model(), example_inputs)
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
 
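
The new test passes leading None entries to aten.index_put. As I read the advanced-indexing semantics, a None entry means that dimension is left un-indexed, i.e. it behaves like a full slice. The sketch below reuses the test's shapes; the read-side equivalence check is my own illustration, not part of the test.

    import torch

    # Shapes copied from test_index_put_with_none_index; the slice equivalence is
    # my own illustration of what a None entry in the indices tuple means.
    x = torch.rand(8, 192, 30, 30)
    i1 = torch.zeros(3, 14, 1, 1, dtype=torch.int64)
    i2 = torch.ones(3, 14, dtype=torch.int64)
    y = torch.randn(8, 192, 3, 14, 3, 14)

    # Accumulating write with the first two dimensions left un-indexed.
    out = torch.ops.aten.index_put(x, (None, None, i1, i2), y, accumulate=True)

    # Read-side counterpart: a None entry behaves like a full slice here.
    torch.testing.assert_close(
        torch.ops.aten.index(x, (None, None, i1, i2)), x[:, :, i1, i2]
    )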
@@ -1677,6 +1701,7 @@ CPU_TEST_FAILURES = {
     "test_scatter_fallback": fail_stack_allocation(is_skip=True),
     "test_scatter_reduce_fallback": fail_stack_allocation(is_skip=True),
     "test_index_put_fallback": fail_stack_allocation(is_skip=True),
+    "test_index_put_with_none_index": fail_stack_allocation(is_skip=True),
     # C++ compile error, need for aoti_torch___scaled_dot_product_flash_attention_for_cpu
     "test_sdpa": fail_with_and_without_stack_allocation(is_skip=True),
     "test_sdpa_2": fail_with_and_without_stack_allocation(is_skip=True),
@@ -1352,7 +1352,9 @@ class CppWrapperCodeGen(WrapperCodeGen):
         self.closed_bracket = "}"
         self.comment = "//"
         self.namespace = "at::"
-        self.none_str = "at::Tensor()"
+        self.none_str = (
+            "nullptr" if config.aot_inductor.abi_compatible else "at::Tensor()"
+        )
         self.extern_call_ops = set()
         self.size = "sizes()"
         self.stride = "strides()"
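
With this change, the wrapper prints a Python None as nullptr (a null AtenTensorHandle) when ABI-compatible mode is on, and keeps at::Tensor() otherwise. A standalone sketch of that decision (illustrative only; the real code simply sets an attribute on CppWrapperCodeGen) is:

    # Standalone sketch of the decision this diff encodes, not the actual codegen.
    def none_literal(abi_compatible: bool) -> str:
        # ABI-compatible output talks to the C shim through AtenTensorHandle, so an
        # absent index is a null handle; otherwise an empty at::Tensor() is valid.
        return "nullptr" if abi_compatible else "at::Tensor()"

    assert none_literal(True) == "nullptr"
    assert none_literal(False) == "at::Tensor()"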
@@ -698,10 +698,11 @@ AOTITorchError aoti_torch_index_put_out(
     const AtenTensorHandle values,
     bool accumulate) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     c10::List<c10::optional<at::Tensor>> indices_;
     indices_.reserve(num_indices);
     for (size_t i = 0; i < num_indices; i++) {
-      indices_.emplace_back(*tensor_handle_to_tensor_pointer(indices[i]));
+      indices_.emplace_back(
+          pointer_to_optional(tensor_handle_to_tensor_pointer(indices[i])));
     }
     at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
     at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);