diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp
index 8739f45d8ad..13bef0a00b9 100644
--- a/aten/src/ATen/native/ComparisonUtils.cpp
+++ b/aten/src/ATen/native/ComparisonUtils.cpp
@@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
   }
 }
 
+template<>
+void _assert_match<c10::Device, std::optional<c10::Device>>(
+    const c10::Device& original,
+    const std::optional<c10::Device>& compared,
+    const std::string& name) {
+  if (compared) {
+    const c10::Device& expected = compared.value();
+    if (original.type() != expected.type()) {
+      std::stringstream msg;
+      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
+      throw std::runtime_error(msg.str());
+    }
+
+    // If the expected device doesn't have an index (e.g., just "cuda"),
+    // or if both devices have the same index, consider them equal
+    if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
+      std::stringstream msg;
+      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
+      throw std::runtime_error(msg.str());
+    }
+  }
+}
+
 void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
   _assert_match(tensor.sym_sizes(), sizes, "sizes");
   _assert_match(tensor.sym_strides(), strides, "strides");
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 5127f45805d..63098e658fc 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -59,6 +59,7 @@ from torch.export.graph_signature import (
     OutputSpec,
     TensorArgument,
 )
+from torch.export.passes import move_to_device_pass
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing import FileCheck
@@ -15914,6 +15915,22 @@ def forward(self, x):
             len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
         )
 
+    @requires_cuda
+    def test_assert_tensor_metadata_device_index(self):
+        class N(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                x = x.float()
+                y = y.float()
+                return x + y
+
+        inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
+        ep = export(N(), inp)
+        ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
+        ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
+
     def test_input_output_no_stacktrace(self):
         class M(torch.nn.Module):
             def forward(self, x):
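
For reference, a minimal end-to-end sketch of the behavior the new _assert_match
specialization enables, mirroring the added test (assumes a CUDA build with this
patch applied; the module name Add is illustrative, not part of the patch):

    import torch
    from torch.export import export
    from torch.export.passes import move_to_device_pass

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x.float() + y.float()

    inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
    ep = export(Add(), inp)

    # Rewrite the graph's device annotations from "cuda:0" to index-less "cuda".
    ep = move_to_device_pass(ep, {"cuda:0": "cuda"})

    # With the specialization, the graph's _assert_tensor_metadata check treats
    # an expected device without an index ("cuda") as matching any CUDA device
    # index, so indexed inputs no longer raise; a device-type mismatch (e.g.
    # passing CPU tensors) still does.
    ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))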