[2/n] Support module.to("cuda:0") in FakeTensorMode on cuda-less machine (#163433)

Summary:
This supports exporting a CUDA model on a CPU-only machine under fake tensor mode. Users commonly need to move the module and sample inputs to the CUDA device with a .to("cuda:0") or .to("cuda") call; this diff makes that work.

I expect the following pattern to work:

```
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple(x.to("cuda:0") for x in sample_inputs)

    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```

Before:
Calling module.to("cuda:0") under fake tensor mode left the parameters on the `meta` device.

After:
Parameters end up on "cuda:0".
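
For example, the new behavior can be checked directly on a CPU-only machine (a minimal sketch; `FakeTensor` and `FakeTensorMode` are imported from `torch._subclasses.fake_tensor`):

```
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode(allow_non_fake_inputs=True):
    # The module is created under the mode, so its parameters are fake tensors.
    m = torch.nn.Linear(2, 2).to("cuda:0")
    for p in m.parameters():
        assert isinstance(p, FakeTensor)
        assert p.device == torch.device("cuda:0")  # was `meta` before this diff
```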

Test Plan: buck2 run fbcode//caffe2/test:fake_tensor -- --r test_move_module

Reviewed By: mikaylagawarecki

Differential Revision: D80102876

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163433
Approved by: https://github.com/albanD
Author: Sherlock Huang
Date: 2025-09-22 20:16:32 +00:00
Committed by: PyTorch MergeBot
Commit: 6f9aef5fef (parent: d15048493c)
3 changed files with 44 additions and 13 deletions

```
@@ -86,9 +86,7 @@ def _test_export_helper(self, dtype, op):
         target_device = "cuda:1"

         def to_fake_device(x):
-            x = converter.from_real_tensor(mode, x)
-            x.fake_device = torch.device(target_device)
-            return x
+            return x.to(target_device)

         # Limit to first 100 inputs so tests don't take too long
         for sample_input in itertools.islice(sample_inputs_itr, 100):
@@ -173,9 +171,7 @@ converter = mode.fake_tensor_converter
         target_device = "cuda:1"

         def to_fake_device(x):
-            x = converter.from_real_tensor(mode, x)
-            x.fake_device = torch.device(target_device)
-            return x
+            return x.to(target_device)

         # Limit to first 100 inputs so tests don't take too long
         for sample_input in itertools.islice(sample_inputs_itr, 100):
```
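
With this change the test helper no longer needs to patch `fake_device` by hand; a fake tensor can be moved across devices with a plain `.to(...)` call. A rough standalone sketch of what the helper now relies on (assumes a machine without real CUDA):

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_non_fake_inputs=True)
with mode:
    fake = mode.from_tensor(torch.rand(3))  # fake tensor reporting device cpu
    moved = fake.to("cuda:1")               # no real CUDA device required
    assert moved.device == torch.device("cuda:1")
```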

```
@@ -214,7 +214,7 @@ class FakeTensorTest(TestCase):
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_op_with_zero_dim_bypassed(self):
         if torch._functorch.config.fake_tensor_propagate_real_tensors:
-            return
+            self.skipTest("Propagate real tensor not supported")
         shape_env = ShapeEnv()
         mode = FakeTensorMode(shape_env=shape_env)
         x = torch.tensor(1.0, device="cuda")
@@ -1516,7 +1516,7 @@ class FakeTensorOperatorInvariants(TestCase):
         # Skip this test, we will try to run CUDA operations to real prop so
         # it clearly will not work on CPU runner
         if torch._functorch.config.fake_tensor_propagate_real_tensors:
-            return
+            self.skipTest("Propagate real tensor not supported")

         with FakeTensorMode(allow_non_fake_inputs=True):
             self.assertEqual(torch.empty(10, device=GPU_TYPE).device.type, GPU_TYPE)
@@ -1528,13 +1528,46 @@ class FakeTensorOperatorInvariants(TestCase):
                 torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE).device.type, GPU_TYPE
             )

+    @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build")
+    def test_move_module_under_fake(self):
+        if torch._functorch.config.fake_tensor_propagate_real_tensors:
+            self.skipTest("Propagate real tensor not supported")
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+                self.buffer = torch.nn.Buffer(torch.rand(2, 2))
+                self.param = torch.nn.Parameter(torch.rand(2, 2))
+
+            def forward(self, x):
+                return self.linear(x) + self.buffer + self.param
+
+        m = Module()
+        input = torch.rand(2, 2)
+        gpu_device = torch.device(GPU_TYPE, 0)
+        with FakeTensorMode(allow_non_fake_inputs=True):
+            m.to(device=gpu_device)
+            arg = input.to(device=gpu_device)
+            out = m(arg)
+
+            for p in m.parameters():
+                self.assertTrue(isinstance(p, FakeTensor))
+                self.assertEqual(p.device, gpu_device)
+            for b in m.buffers():
+                self.assertTrue(isinstance(b, FakeTensor))
+                self.assertEqual(b.device, gpu_device)
+
+            self.assertTrue(isinstance(out, FakeTensor))
+            self.assertEqual(out.device, gpu_device)
+
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_move_meta_tensor(self):
         if torch._functorch.config.fake_tensor_propagate_real_tensors:
-            return
+            self.skipTest("Propagate real tensor not supported")
+
         meta_tensor = torch.ones(2, device="meta")
         gpu_device = torch.device(GPU_TYPE)
         with FakeTensorMode(allow_non_fake_inputs=True):
             self.assertEqual(meta_tensor.to(device="cpu").device.type, "cpu")
             self.assertEqual(meta_tensor.to(device=GPU_TYPE).device.type, GPU_TYPE)
```

```
@@ -929,8 +929,12 @@ class Module:
         for module in self.children():
             module._apply(fn)

+        from torch._subclasses.fake_tensor import FakeTensor
+
         def compute_should_use_set_data(tensor, tensor_applied) -> bool:
-            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
+            if torch._has_compatible_shallow_copy_type(
+                tensor, tensor_applied
+            ) and not isinstance(tensor_applied, FakeTensor):
                 # If the new tensor has compatible tensor type as the existing tensor,
                 # the current behavior is to change the tensor in-place using `.data =`,
                 # and the future behavior is to overwrite the existing tensor. However,
@@ -957,8 +961,6 @@ class Module:
                 param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)

-            from torch._subclasses.fake_tensor import FakeTensor
-
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = (
                 should_use_swap_tensors
```
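
The `Module._apply` change makes `compute_should_use_set_data` return False for `FakeTensor` results, so the fake parameter (carrying the new "cuda:0" device) replaces the old one rather than being assigned into it via `.data =`. The `swap_tensors` named in the surrounding comment is `torch.utils.swap_tensors`, which exchanges two tensors in place; a standalone sketch of that primitive (assumes a PyTorch version that ships it, 2.2+):

```
import torch
from torch.utils import swap_tensors

a = torch.rand(2, 2)
b = torch.zeros(2, 2)
swap_tensors(a, b)  # a and b trade contents while each Python object keeps its identity
assert a.equal(torch.zeros(2, 2))
```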