added fake tensor support for foreach_copy (#149127)
Fixes #149111
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149127
Approved by: https://github.com/jansel, https://github.com/jeromean
This commit is contained in:
parent 7aacbab0b3
commit a9ee797e41
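As context for the diff below: fake tensor propagation previously required the tensor arguments of an op to resolve to a single common device (modulo CPU zero-dim scalars), so tracing _foreach_copy with a CUDA destination list and a CPU source list hit the device-mismatch RuntimeError shown in the last hunk. A minimal sketch of the kind of call this commit unblocks (illustrative only, assuming a CUDA-capable machine; not taken from the PR or issue #149111):

import torch

# Destination tensors on CUDA, source tensors on CPU. Under torch.compile the call
# is traced with fake tensors, which is where the device mix used to be rejected.
dst = [torch.zeros(3, device="cuda"), torch.zeros(5, device="cuda")]
src = [torch.ones(3), torch.ones(5)]  # CPU tensors

@torch.compile
def copy_lists(a, b):
    # functional foreach copy, the op the new test exercises via op.method_variant
    return torch._foreach_copy(a, b)

out = copy_lists(dst, src)  # with this commit, tracing no longer raises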
@@ -42,6 +42,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM,
     TestCase,
 )
+from torch.testing._internal.triton_utils import requires_cuda
 
 
 _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
@@ -1355,6 +1356,39 @@ class TestForeach(TestCase):
         for t, ref_t in zip(out, ref_out):
             self.assertTrue(torch.equal(t, ref_t))
 
+    @requires_cuda
+    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
+    def test_foreach_copy_with_different_device_inputs(self, device, dtype, op):
+        if dtype in (torch.complex128, torch.complex64):
+            self.skipTest("Complex dtype not supported")
+        # check foreach_copy when self and src tensorList have different device
+        foreach_copy = op.method_variant
+        copy_ = op.ref_inplace
+
+        def fn(self_tensor, src_tensor, non_blocking):
+            return foreach_copy(self_tensor, src_tensor, non_blocking)
+
+        fn = torch.compile(fn)
+        for non_blocking in (False,):
+            for sample in op.sample_inputs(
+                device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
+            ):
+                with torch.no_grad():
+                    ref_input = [t.detach().clone() for t in sample.input]
+                    ref_input_cpu = [t.detach().clone().to("cpu") for t in sample.input]
+                    rhs_tensors = [t.detach().clone().to("cpu") for t in sample.args[0]]
+                    self_tensors = [t.detach().clone().to("cpu") for t in sample.input]
+
+                output1 = fn(sample.input, rhs_tensors, non_blocking)
+                for t, s in zip(ref_input, rhs_tensors):
+                    copy_(t, s, non_blocking)
+                self.assertEqual(output1, ref_input)
+
+                output2 = fn(self_tensors, sample.args[0], non_blocking)
+                for t, s in zip(ref_input_cpu, sample.args[0]):
+                    copy_(t, s, non_blocking)
+                self.assertEqual(output2, ref_input_cpu)
+
     # Test reverse-mode & forward-mode AD if supported.
     @onlyCUDA
     @ops(
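The hunks above extend the foreach test suite (class TestForeach): the new test compiles a small wrapper around the op so that fake tensor device propagation runs during tracing, and it covers both directions (CUDA self list with CPU src list, and CPU self list with CUDA src list), comparing the compiled result against applying the per-tensor in-place reference (op.ref_inplace, i.e. plain copy_) to each pair. For orientation, the reference behavior relied on here is ordinary Tensor.copy_, which already accepts a source on a different device; a tiny illustration, not part of the commit:

import torch

cuda_dst = torch.empty(4, device="cuda")
cpu_src = torch.arange(4, dtype=torch.float32)  # CPU tensor
cuda_dst.copy_(cpu_src, non_blocking=False)     # cross-device copy_ works in eager mode

The remaining hunks adjust common-device resolution inside the FakeTensor class.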
@@ -866,8 +866,16 @@ class FakeTensor(Tensor):
         has_scalar_only_inputs = False
         is_cpu_zero_dim = None
 
+        # list of ops which can have args(tensor/tensorList) in mixed device
+        mixed_device_fns = ordered_set(
+            aten._foreach_copy.default,
+        )
+
+        def check_cpu_device(device: torch.device) -> bool:
+            return device.type == "cpu"
+
         def cpu_zero_dim(t: Tensor) -> bool:
-            return t.device.type == "cpu" and t.dim() == 0
+            return check_cpu_device(t.device) and t.dim() == 0
 
         def merge_devices(t: object) -> None:
             nonlocal common_device
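The allowlist above is deliberately narrow: only ops registered in mixed_device_fns may see mixed-device arguments, and even then only when one of the devices involved is CPU (that second condition is enforced in the next hunk). Conceptually, the gate boils down to the following predicate; this is a paraphrase for clarity, not code from the commit, and it assumes the mixed_device_fns set defined above:

import torch

def _allows_mixed_devices(func, device_a: torch.device, device_b: torch.device) -> bool:
    # an op may mix devices only if it is on the allowlist and at least one side is CPU
    return func in mixed_device_fns and (
        device_a.type == "cpu" or device_b.type == "cpu"
    )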
@@ -897,6 +905,14 @@ class FakeTensor(Tensor):
                 is_cpu_zero_dim = t_is_cpu_zero_dim
                 return
 
+            # if still device mismatches we will check ops which can work
+            # on different devices for ex. _foreach_copy, and one of the
+            # device must be cpu in this case we will return from here without
+            # throwing an error
+            if func in mixed_device_fns:
+                if any(map(check_cpu_device, (common_device, t.device))):
+                    return
+
             # mismatching devices of non-zero dim tensors, throw
             # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
             raise RuntimeError(
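The net effect is that common-device resolution simply accepts the mix for _foreach_copy, keeping whichever common device it had already settled on, instead of raising. A rough end-to-end illustration under FakeTensorMode (assumes a CUDA-enabled build; not taken from the PR):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    dst = [torch.empty(4, device="cuda"), torch.empty(2, device="cuda")]
    src = [torch.empty(4, device="cpu"), torch.empty(2, device="cpu")]
    # before this commit: RuntimeError complaining about two different devices
    # after: the call traces; here the common device should resolve to cuda,
    # since the CUDA destination list is encountered first
    out = torch._foreach_copy(dst, src)
    print([t.device for t in out])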