pytorch/test/export/test_export_opinfo.py
Sherlock Huang a956c4ab1c Return NoOpDeviceGuardImpl in place of CudaDeviceGuard when the device is not available or in a cpu-only build (#160532)
Summary:

To support exporting a cuda model on a CPU-only machine under fake tensor mode.
Users commonly need to move sample inputs to the cuda device with a .to("cuda:0") or .to("cuda") call.
This diff supports this.
I expect the following pattern to work
```
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```

Test Plan:
CI

Rollback Plan:

Differential Revision: D80181887

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160532
Approved by: https://github.com/henryoier, https://github.com/ezyang
2025-09-13 01:50:51 +00:00

176 lines
5.4 KiB
Python

# Owner(s): ["oncall: export"]
# ruff: noqa: F841
# flake8: noqa
import itertools
import unittest
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_methods_invocations import (
op_db,
skip,
skipOps,
xfail,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils import _pytree as pytree
# OpInfo entries that are marked as expected failures for
# TestExportOpInfo.test_fake_export; this set is unioned with
# fake_export_failures and passed to skipOps on that test.
# following are failing with regular torch.export.export
export_failures = {
xfail("allclose"),
xfail("combinations"),
xfail("corrcoef"),
xfail("cov"),
xfail("equal"),
xfail("linalg.lstsq"),
# NOTE(review): second argument is presumably the OpInfo variant name — confirm
xfail("linalg.lstsq", "grad_oriented"),
xfail("nn.functional.ctc_loss"),
xfail("nn.functional.gaussian_nll_loss"),
xfail("sparse.sampled_addmm"),
xfail("tensor_split"),
}
# Additional entries for TestExportOpInfo.test_fake_export, on top of
# export_failures; the two sets are unioned in the skipOps decorator.
# following are failing fake export on cuda device
fake_export_failures = {
xfail("geqrf"),
xfail("histogram"),
xfail("masked.amax"),
xfail("masked.amin"),
xfail("masked.argmax"),
xfail("masked.argmin"),
xfail("masked.logaddexp"),
xfail("masked.logsumexp"),
xfail("masked.mean"),
xfail("masked.prod"),
xfail("masked.std"),
xfail("masked.sum"),
xfail("masked.var"),
xfail("to_sparse"),
# These use skip() rather than xfail(): an xfail would be reported as an
# unexpected success on builds where the op passes.
# cannot xfail as it is passing for cpu-only build
skip("nn.functional.grid_sample"),
skip("nn.functional.conv2d"),
skip("nn.functional.scaled_dot_product_attention"),
}
# Expected-failure entries that no test in the visible part of this file
# references.
# NOTE(review): presumably intended for a fake-tensor decomposition test —
# confirm against the rest of the test suite before removing or extending.
fake_decomposition_failures = {
xfail("linalg.matrix_rank"),
xfail("nn.functional.binary_cross_entropy_with_logits"),
xfail("nn.functional.instance_norm"),
xfail("nn.functional.multi_margin_loss"),
xfail("repeat_interleave"),
xfail("take"),
}
def _test_export_helper(self, dtype, op):
    """Export ``op`` on fake tensors pinned to a CUDA device and check devices.

    CPU sample inputs of ``op`` are converted to fake tensors whose
    ``fake_device`` is overridden to ``cuda:1``. The op is wrapped in a small
    ``torch.nn.Module`` and passed through ``torch.export.export``. Every
    ``call_function`` node whose ``meta["val"]`` is a ``FakeTensor`` must
    report that same device.

    Args:
        self: test case instance; only ``assertEqual`` is used.
        dtype: dtype forwarded to ``op.sample_inputs``.
        op: object exposing ``sample_inputs(device, dtype, requires_grad=...)``
            and a callable ``op`` attribute.
    """
    samples = op.sample_inputs("cpu", dtype, requires_grad=False)

    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
    tensor_converter = fake_mode.fake_tensor_converter
    # intentionally avoid cuda:0 to flush out some bugs
    target_device = "cuda:1"
    expected_device = torch.device(target_device)

    def to_fake_device(real_tensor):
        # Fakeify the real CPU tensor, then relabel it as living on the target device.
        fake = tensor_converter.from_real_tensor(fake_mode, real_tensor)
        fake.fake_device = torch.device(target_device)
        return fake

    # Limit to first 100 inputs so tests don't take too long
    for sample in itertools.islice(samples, 100):
        positional = (sample.input, *sample.args)

        # hack to skip non-tensor in args, as export doesn't support it
        if not all(isinstance(arg, torch.Tensor) for arg in positional):
            continue

        keyword = sample.kwargs
        if "device" in keyword:
            # Factory-style kwargs must agree with the faked input device.
            keyword["device"] = target_device

        with fake_mode:
            fake_args, fake_kwargs = pytree.tree_map_only(
                torch.Tensor, to_fake_device, (positional, keyword)
            )

            class Module(torch.nn.Module):
                # kwargs are captured by closure; only positional tensors are export inputs.
                def forward(self, *args):
                    return op.op(*args, **fake_kwargs)

            exported = torch.export.export(Module(), fake_args)

            for node in exported.graph.nodes:
                if node.op != "call_function":
                    continue
                value = node.meta.get("val", None)
                if isinstance(value, FakeTensor):
                    self.assertEqual(value.device, expected_device)
class TestExportOpInfo(TestCase):
    """Checks that export works on fake CUDA tensors on hosts without a GPU."""

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestExportOpInfo", "test_fake_export", export_failures | fake_export_failures
    )
    def test_fake_export(self, device, dtype, op):
        """Run the shared export helper for one float32 OpInfo entry."""
        _test_export_helper(self, dtype, op)

    @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build")
    # Every assertion below expects the "no NVIDIA driver" failure path, so the
    # test cannot pass on a host where CUDA is actually usable.
    @unittest.skipIf(
        torch.cuda.is_available(), "requires a host without a usable CUDA driver"
    )
    def test_preserve_original_behavior(self):
        """Real CUDA calls must keep failing before and after using a fake mode.

        On a CUDA build without a driver, each real CUDA entry point raises
        the driver-missing ``RuntimeError``. Using a ``FakeTensorMode`` to
        compute on a fake ``cuda`` tensor in between must not change that.
        """
        no_driver_error = "Found no NVIDIA driver on your system."

        def cuda_calls_behavior_unchanged():
            cpu_x = torch.randn(2)
            cuda_calls = (
                ("Tensor.to", lambda: cpu_x.to("cuda")),
                ("torch.randn", lambda: torch.randn(2, device="cuda")),
                ("get_device_capability", torch.cuda.get_device_capability),
                ("set_device", lambda: torch.cuda.set_device(1)),
                ("current_device", torch.cuda.current_device),
            )
            for label, call in cuda_calls:
                # subTest keeps the failing entry point identifiable in reports.
                with self.subTest(call=label):
                    with self.assertRaisesRegex(RuntimeError, no_driver_error):
                        call()

            self.assertFalse(torch.cuda.is_available())
            self.assertEqual(torch.cuda.device_count(), 0)

        cuda_calls_behavior_unchanged()

        cpu_x = torch.randn(2)
        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
            cuda_x = mode.from_tensor(cpu_x)
            cuda_x.fake_device = torch.device("cuda")
            cuda_y = cuda_x + cuda_x
            self.assertEqual(cuda_y.device.type, "cuda")

        # should fail again after exiting the fake mode, with the identical error message
        cuda_calls_behavior_unchanged()
# Generate the device-specific test class into this module's globals. Only the
# "cpu" variant is instantiated; the tests assign fake CUDA devices to their
# tensors themselves rather than relying on a CUDA test device.
instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")
# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()