[export] Allow comparing device w/o index with device w/ index (#159665)

In the case where the expected device is "cuda" and the given device is "cuda:0", the comparison should succeed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159665
Approved by: https://github.com/yushangdi
angelayi 2025-08-04 17:00:07 +00:00 committed by PyTorch MergeBot
parent 53e47af0f7
commit fc340d0ca3
2 changed files with 40 additions and 0 deletions


@@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::string& name) {
  }
}
template <>
void _assert_match<c10::Device, std::optional<c10::Device>>(
    const c10::Device& original,
    const std::optional<c10::Device>& compared,
    const std::string& name) {
  if (compared) {
    const c10::Device& expected = compared.value();
    if (original.type() != expected.type()) {
      std::stringstream msg;
      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
      throw std::runtime_error(msg.str());
    }
    // Only flag a mismatch when both devices carry an index and the indices
    // differ; an index-less device (e.g., plain "cuda") matches any index.
    if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
      std::stringstream msg;
      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
      throw std::runtime_error(msg.str());
    }
  }
}
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
  _assert_match(tensor.sym_sizes(), sizes, "sizes");
  _assert_match(tensor.sym_strides(), strides, "strides");
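In plain terms, the specialization treats an index-less device as a wildcard: device types must match exactly, but the index is only compared when both sides have one. A minimal Python sketch of the same rule, using a hypothetical devices_match helper that is not part of this PR:

import torch

def devices_match(expected: torch.device, got: torch.device) -> bool:
    # Device types must agree exactly (e.g., cpu vs. cuda is always a mismatch).
    if expected.type != got.type:
        return False
    # An index-less device (index is None, e.g., plain "cuda") matches any index.
    if expected.index is None or got.index is None:
        return True
    return expected.index == got.index

assert devices_match(torch.device("cuda"), torch.device("cuda:0"))        # wildcard index
assert not devices_match(torch.device("cuda:0"), torch.device("cuda:1"))  # index conflict
assert not devices_match(torch.device("cpu"), torch.device("cuda:0"))     # type conflict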


@@ -59,6 +59,7 @@ from torch.export.graph_signature import (
    OutputSpec,
    TensorArgument,
)
from torch.export.passes import move_to_device_pass
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing import FileCheck
@@ -15914,6 +15915,22 @@ def forward(self, x):
            len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
        )
    @requires_cuda
    def test_assert_tensor_metadata_device_index(self):
        class N(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x.float()
                y = y.float()
                return x + y

        inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
        ep = export(N(), inp)
        ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
        ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))

    def test_input_output_no_stacktrace(self):
        class M(torch.nn.Module):
            def forward(self, x):
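The new test exercises the rule end to end: the module is exported with inputs on "cuda" (recorded concretely as "cuda:0"), move_to_device_pass rewrites the ExportedProgram to the index-less "cuda", and the metadata assert then accepts "cuda:0" inputs instead of raising. A minimal sketch of hitting the underlying op directly, assuming torch.ops.aten._assert_tensor_metadata exposes the optional device argument from the C++ signature above:

import torch

t = torch.randn(3, device="cuda:0")
# Expected device "cuda" carries no index, so the indexed tensor device passes.
torch.ops.aten._assert_tensor_metadata(t, device=torch.device("cuda"))
# A device-type mismatch still raises, roughly:
#   torch.ops.aten._assert_tensor_metadata(t, device=torch.device("cpu"))
#   RuntimeError: Tensor device mismatch! Expected: cpu, Got: cuda:0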