added fake tensor support for foreach_copy (#149127)

Fixes #149111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149127
Approved by: https://github.com/jansel, https://github.com/jeromean
pralay 2025-03-27 09:26:19 +00:00 committed by PyTorch MergeBot
parent 7aacbab0b3
commit a9ee797e41
2 changed files with 51 additions and 1 deletion
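
For context, here is a minimal sketch of the mixed-device pattern this change lets FakeTensor device propagation (and therefore torch.compile tracing) accept. It is illustrative only, not part of the commit, and assumes a CUDA-enabled build:

import torch

def fn(dst, src):
    # In-place foreach copy: dst[i].copy_(src[i]) for each pair; like Tensor.copy_,
    # the source tensors may live on a different device than the destinations.
    torch._foreach_copy_(dst, src)
    return dst

dst = [torch.zeros(4, device="cuda") for _ in range(3)]  # destinations on CUDA
src = [torch.ones(4) for _ in range(3)]                  # sources on CPU

# Before this commit, tracing this call raised a device-mismatch RuntimeError
# during fake tensor propagation; the new check permits the mismatch for
# aten._foreach_copy when one of the devices is cpu.
compiled = torch.compile(fn)
out = compiled(dst, src)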


@@ -42,6 +42,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM,
     TestCase,
 )
+from torch.testing._internal.triton_utils import requires_cuda
 
 
 _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
@@ -1355,6 +1356,39 @@ class TestForeach(TestCase):
             for t, ref_t in zip(out, ref_out):
                 self.assertTrue(torch.equal(t, ref_t))
 
+    @requires_cuda
+    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
+    def test_foreach_copy_with_different_device_inputs(self, device, dtype, op):
+        if dtype in (torch.complex128, torch.complex64):
+            self.skipTest("Complex dtype not supported")
+        # check foreach_copy when self and src tensorList have different device
+        foreach_copy = op.method_variant
+        copy_ = op.ref_inplace
+
+        def fn(self_tensor, src_tensor, non_blocking):
+            return foreach_copy(self_tensor, src_tensor, non_blocking)
+
+        fn = torch.compile(fn)
+
+        for non_blocking in (False,):
+            for sample in op.sample_inputs(
+                device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
+            ):
+                with torch.no_grad():
+                    ref_input = [t.detach().clone() for t in sample.input]
+                    ref_input_cpu = [t.detach().clone().to("cpu") for t in sample.input]
+                    rhs_tensors = [t.detach().clone().to("cpu") for t in sample.args[0]]
+                    self_tensors = [t.detach().clone().to("cpu") for t in sample.input]
+                output1 = fn(sample.input, rhs_tensors, non_blocking)
+                for t, s in zip(ref_input, rhs_tensors):
+                    copy_(t, s, non_blocking)
+                self.assertEqual(output1, ref_input)
+
+                output2 = fn(self_tensors, sample.args[0], non_blocking)
+                for t, s in zip(ref_input_cpu, sample.args[0]):
+                    copy_(t, s, non_blocking)
+                self.assertEqual(output2, ref_input_cpu)
+
     # Test reverse-mode & forward-mode AD if supported.
     @onlyCUDA
     @ops(


@@ -866,8 +866,16 @@ class FakeTensor(Tensor):
         has_scalar_only_inputs = False
         is_cpu_zero_dim = None
 
+        # list of ops which can have args (tensor/tensorList) on mixed devices
+        mixed_device_fns = ordered_set(
+            aten._foreach_copy.default,
+        )
+
+        def check_cpu_device(device: torch.device) -> bool:
+            return device.type == "cpu"
+
         def cpu_zero_dim(t: Tensor) -> bool:
-            return t.device.type == "cpu" and t.dim() == 0
+            return check_cpu_device(t.device) and t.dim() == 0
 
         def merge_devices(t: object) -> None:
             nonlocal common_device
@@ -897,6 +905,14 @@ class FakeTensor(Tensor):
                 is_cpu_zero_dim = t_is_cpu_zero_dim
                 return
 
+            # if devices still mismatch, check ops which can work with args on
+            # different devices, e.g. _foreach_copy; one of the devices must be
+            # cpu, and in that case we return from here without throwing an
+            # error
+            if func in mixed_device_fns:
+                if any(map(check_cpu_device, (common_device, t.device))):
+                    return
+
             # mismatching devices of non-zero dim tensors, throw
             # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
             raise RuntimeError(