mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Allow comparing device w/o index with device w/ index (#159665)
In the case where we have expected device "cuda" and given device "cuda:0" I think we should succeed? Pull Request resolved: https://github.com/pytorch/pytorch/pull/159665 Approved by: https://github.com/yushangdi
This commit is contained in:
parent
53e47af0f7
commit
fc340d0ca3
|
|
@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
|
|||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void _assert_match<c10::Device, std::optional<c10::Device>>(
|
||||
const c10::Device& original,
|
||||
const std::optional<c10::Device>& compared,
|
||||
const std::string& name) {
|
||||
if (compared) {
|
||||
const c10::Device& expected = compared.value();
|
||||
if (original.type() != expected.type()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// If the expected device doesn't have an index (e.g., just "cuda"),
|
||||
// or if both devices have the same index, consider them equal
|
||||
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
||||
_assert_match(tensor.sym_strides(), strides, "strides");
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ from torch.export.graph_signature import (
|
|||
OutputSpec,
|
||||
TensorArgument,
|
||||
)
|
||||
from torch.export.passes import move_to_device_pass
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing import FileCheck
|
||||
|
|
@ -15914,6 +15915,22 @@ def forward(self, x):
|
|||
len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_assert_tensor_metadata_device_index(self):
|
||||
class N(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
x = x.float()
|
||||
y = y.float()
|
||||
return x + y
|
||||
|
||||
inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
|
||||
ep = export(N(), inp)
|
||||
ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
|
||||
ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
|
||||
|
||||
def test_input_output_no_stacktrace(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user