mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Allow comparing device w/o index with device w/ index (#159665)
In the case where we have expected device "cuda" and given device "cuda:0" I think we should succeed? Pull Request resolved: https://github.com/pytorch/pytorch/pull/159665 Approved by: https://github.com/yushangdi
This commit is contained in:
parent
53e47af0f7
commit
fc340d0ca3
|
|
@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
void _assert_match<c10::Device, std::optional<c10::Device>>(
|
||||||
|
const c10::Device& original,
|
||||||
|
const std::optional<c10::Device>& compared,
|
||||||
|
const std::string& name) {
|
||||||
|
if (compared) {
|
||||||
|
const c10::Device& expected = compared.value();
|
||||||
|
if (original.type() != expected.type()) {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the expected device doesn't have an index (e.g., just "cuda"),
|
||||||
|
// or if both devices have the same index, consider them equal
|
||||||
|
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||||
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
||||||
_assert_match(tensor.sym_strides(), strides, "strides");
|
_assert_match(tensor.sym_strides(), strides, "strides");
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,7 @@ from torch.export.graph_signature import (
|
||||||
OutputSpec,
|
OutputSpec,
|
||||||
TensorArgument,
|
TensorArgument,
|
||||||
)
|
)
|
||||||
|
from torch.export.passes import move_to_device_pass
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
|
|
@ -15914,6 +15915,22 @@ def forward(self, x):
|
||||||
len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
|
len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_assert_tensor_metadata_device_index(self):
|
||||||
|
class N(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x = x.float()
|
||||||
|
y = y.float()
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
|
||||||
|
ep = export(N(), inp)
|
||||||
|
ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
|
||||||
|
ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
|
||||||
|
|
||||||
def test_input_output_no_stacktrace(self):
|
def test_input_output_no_stacktrace(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user