pytorch/test/test_foreach.py
Nikita Shulga 6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of printing the static_assert message, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00

1022 lines
50 KiB
Python

# Owner(s): ["module: mta"]
from contextlib import nullcontext
from numbers import Number
import random
import re
import torch
import unittest
import itertools
from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_utils import \
TestCase, run_tests, TEST_WITH_ROCM, skipIfTorchDynamo, parametrize, gradcheck
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, onlyCUDA, ops, OpDTypes)
from torch.testing._internal.common_methods_invocations import (
foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db,
foreach_reduce_op_db, foreach_lerp_op_db)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and, integral_types, complex_types,
floating_types_and, floating_types, integral_types_and,
)
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
class RegularFuncWrapper:
    """Apply a per-tensor reference function element-wise over "foreach"-style inputs.

    ``inputs`` is a sequence of equally long argument lists; the wrapped
    function is called once per position with the arguments found there.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, values=None, **kwargs):
        """Return ``[func(*args_i, ...)]`` for every position ``i``.

        Args:
            inputs: sequence of argument lists. When it has exactly two entries
                and the second one is a ``Number``, that entry is replaced
                *in place* by a list repeating the number once per element of
                ``inputs[0]``.
            values: optional per-position ``value=`` keyword. A single
                ``Number`` is repeated for every position; anything else is
                indexed by position. Requires ``len(inputs) == 3``.
            **kwargs: forwarded unchanged to every call of the wrapped function.
        """
        if values is None:
            if len(inputs) == 2 and isinstance(inputs[1], Number):
                # binary op with tensorlist and scalar.
                inputs[1] = [inputs[1]] * len(inputs[0])
            return [self.func(*args, **kwargs) for args in zip(*inputs)]

        assert len(inputs) == 3
        if isinstance(values, Number):
            values = [values] * len(inputs[0])
        results = []
        for position, args in enumerate(zip(*inputs)):
            results.append(self.func(*args, value=values[position], **kwargs))
        return results
class ForeachFuncWrapper:
    """Call a foreach function and, on CUDA, check which kernel path it took."""

    def __init__(self, func):
        self.func = func
        # Some foreach functions don't have in-place implementations.
        self.is_inplace = func is not None and func.__name__.endswith('_')

    def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
        """Run ``self.func(*inputs, **kwargs)`` and return its result.

        ``kwargs`` must contain ``zero_size``; it is removed before the call.
        When ``is_cuda`` is true and the profiler supports CUDA activity, the
        call is profiled and an assertion checks that a
        ``multi_tensor_apply_kernel`` event showed up exactly when
        ``is_fastpath and not zero_size``.

        Returns ``inputs[0]`` for in-place functions, otherwise whatever the
        wrapped function returned.
        """
        zero_size = kwargs.pop("zero_size")
        can_profile = (
            is_cuda
            and torch.autograd.kineto_available()
            and torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities()
        )
        if not can_profile:
            result = self.func(*inputs, **kwargs)
        else:
            with torch.profiler.profile() as prof:
                result = self.func(*inputs, **kwargs)
            event_keys = [event.key for event in prof.key_averages()]
            mta_called = any("multi_tensor_apply_kernel" in key for key in event_keys)
            assert mta_called == (is_fastpath and (not zero_size))
        # note(mkozuki): inplace foreach functions are void functions.
        if self.is_inplace:
            return inputs[0]
        return result
class InplaceForeachVersionBumpCheck:
    """Context manager checking that no tensor's ``_version`` decreases inside the block.

    The version of every tensor in ``tensorlist`` is recorded on construction
    and compared, tensor by tensor, against the value seen on exit.

    Args:
        testcase: test case whose ``assertGreaterEqual`` reports failures.
        tensorlist: tensors expected to be mutated in place by the guarded call.
    """

    def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:
        self._testcase = testcase
        self._tensorlist = tensorlist
        self._orig_version_counts = [t._version for t in tensorlist]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
        # Compare per tensor: comparing the two lists directly is lexicographic,
        # so a bump of an earlier tensor would hide a regression of a later one.
        for index, (tensor, orig_version) in enumerate(zip(self._tensorlist, self._orig_version_counts)):
            self._testcase.assertGreaterEqual(
                tensor._version, orig_version,
                msg=f"version counter of tensor #{index} went backwards",
            )
def get_transform_func(num_tensors, dtype, device, is_fastpath):
    """Build a callable that swaps every tensor for a fresh grad-requiring one.

    The returned function leaves non-tensor arguments untouched. Tensors are
    replaced by a new ``(num_tensors, num_tensors)`` tensor created with
    ``make_tensor`` using the given ``dtype``/``device``, ``requires_grad=True``
    and ``noncontiguous=not is_fastpath``.
    """
    shape = (num_tensors, num_tensors)
    want_noncontiguous = not is_fastpath

    def transform(t):
        if torch.is_tensor(t):
            return make_tensor(
                shape, dtype=dtype, device=device,
                requires_grad=True, noncontiguous=want_noncontiguous,
            )
        return t

    return transform
def assert_multiple_grad_fns(tensors, test_case):
    """Assert through ``test_case`` that every tensor has its own distinct ``grad_fn``."""
    grad_fns = [t.grad_fn for t in tensors]
    distinct_grad_fns = set(grad_fns)
    test_case.assertEqual(len(distinct_grad_fns), len(tensors), msg=f"{grad_fns}")
def clone(arg):
    """Recursively copy ``arg`` for use as independent reference inputs.

    Lists and tuples become lists of cloned elements, tensors become detached
    copies with ``requires_grad`` switched on, and anything else is returned
    as is.
    """
    if isinstance(arg, (list, tuple)):
        return list(map(clone, arg))
    if not torch.is_tensor(arg):
        return arg
    detached_copy = arg.clone().detach()
    return detached_copy.requires_grad_()
# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
class TestForeach(TestCase):
@property
def is_cuda(self):
    """Whether ``self.device_type`` is ``'cuda'``."""
    device_type = self.device_type
    return device_type == 'cuda'
def _get_funcs(self, op):
    """Wrap the four callables of ``op`` for use by the comparison helpers.

    Returns a 4-tuple ``(foreach, reference, inplace_foreach, inplace_reference)``
    where the foreach variants are ``ForeachFuncWrapper`` instances and the
    references are ``RegularFuncWrapper`` instances.
    """
    out_of_place = ForeachFuncWrapper(op.method_variant)
    out_of_place_ref = RegularFuncWrapper(op.ref)
    in_place = ForeachFuncWrapper(op.inplace_variant)
    in_place_ref = RegularFuncWrapper(op.ref_inplace)
    return out_of_place, out_of_place_ref, in_place, in_place_ref
def _binary_test(
    self,
    dtype, op, ref, inputs, is_fastpath, is_inplace,
    *,
    alpha, scalar_self_arg: bool, zero_size: bool,
):
    """Run a binary foreach op and compare it with its per-tensor reference.

    Args:
        dtype: dtype under test; only used to pick tolerances in the alpha pass.
        op: ``ForeachFuncWrapper`` around the foreach function.
        ref: ``RegularFuncWrapper`` around the reference function.
        inputs: two-element sequence of operands passed to ``op``.
        is_fastpath: forwarded to ``op`` for its kernel-path check.
        is_inplace: when true, the reference runs on detached clones of
            ``inputs[0]`` so that the foreach call does not affect it.
        alpha: when not ``None`` (and ``scalar_self_arg`` is false) the
            comparison is repeated with ``alpha=alpha``.
        scalar_self_arg: when true, ``inputs[0]`` is passed as the first
            argument of ``ref.func`` for each element of ``inputs[1]``.
        zero_size: when true, the op is only executed; nothing is compared.

    If ``op`` raises ``RuntimeError``, the reference must raise the same
    exception type with the same message.
    """
    if zero_size:
        # Zero-size inputs: only check that the call runs (and, for in-place
        # ops, that version counters do not go backwards).
        with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
            op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
        return
    # In-place ops mutate inputs[0], so give the reference its own copies.
    ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
    try:
        with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
            actual = op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
    except RuntimeError as e:
        # The reference is expected to fail in exactly the same way.
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            if not scalar_self_arg:
                ref(ref_inputs)
            else:
                [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
    else:
        expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        self.assertEqual(actual, expected)
    if alpha is not None and not scalar_self_arg:
        # Second pass: same comparison with the `alpha` keyword.
        kwargs = {'alpha': alpha}
        # NOTE(review): for in-place ops this makes the reference run on the very
        # tensors `op` has already mutated (no clones here) — confirm intended.
        ref_inputs = inputs
        try:
            # `zero_size` is consumed by the wrapper and must not reach `ref`.
            op_kwargs = {}
            op_kwargs.update(kwargs)
            op_kwargs['zero_size'] = zero_size
            with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
                actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                # Looser absolute tolerance for half/bfloat16 when TEST_WITH_ROCM is set.
                self.assertEqual(expected, actual, atol=1.e-3, rtol=default_tolerances(dtype)[0])
            else:
                self.assertEqual(expected, actual)
@ops(foreach_binary_op_db)
@parametrize("is_fastpath", (True, False))
def test_binary_op(self, device, dtype, op, is_fastpath):
    """Check binary foreach ops against their references, forward and backward.

    For every sample: compares the out-of-place and in-place variants via
    ``_binary_test``, then (floating dtypes, non-zero-size, autograd-capable
    ops) compares gradients, and once per test exercises the "scalar as
    first argument" form when the op supports it.
    """
    # The scalar-self-arg check is run at most once per test invocation.
    scalar_self_arg_test_complete = False
    for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
        (rhs_arg,) = sample.args
        zero_size = sample.kwargs.pop("zero_size")
        # `{} or x` always evaluates to `x`, so this aliases sample.kwargs.
        kwargs = {} or sample.kwargs
        alpha = kwargs.pop("alpha", None)
        disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
        wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        # Forward comparison: out-of-place first, then in-place.
        self._binary_test(
            dtype, wrapped_op, ref, [sample.input, rhs_arg],
            is_fastpath and not disable_fastpath, False,
            alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
        )
        self._binary_test(
            dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
            is_fastpath and not disable_fastpath, True,
            alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
        )

        if op.supports_autograd and dtype in floating_types() and not zero_size:
            # Backward comparison on fresh grad-requiring inputs.
            transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
            tensors = transformed_sample.input
            (rhs_arg,) = transformed_sample.args
            ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
            try:
                # is_cuda/is_fastpath are False so the wrapper skips its profiler check.
                sum(
                    wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
                ).mean().backward()
            except RuntimeError:
                with self.assertRaises(RuntimeError):
                    sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
            else:
                sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
                if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
                    self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
            # In-place backward: the trailing clone() yields non-leaf tensors
            # that can be mutated in place while tracking gradients.
            tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
            ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
            inplace_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
            assert_multiple_grad_fns(tensors, self)

            # note(crcrpar): the following ops' reference torch functions don't have the overload with Scalar/ScalarList.
            is_foreach_max_min_imum_with_scalar_or_scalarlist = (
                inplace_op.func in (torch._foreach_minimum_, torch._foreach_maximum_)
                and (
                    isinstance(rhs_arg, Number) or (isinstance(rhs_arg, list) and isinstance(rhs_arg[0], Number))
                )
            )
            if not is_foreach_max_min_imum_with_scalar_or_scalarlist:
                inplace_ref([ref_tensors, rhs_arg])
                torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
                torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
        if (
            op.supports_scalar_self_arg
            and isinstance(rhs_arg, Number)
            and not scalar_self_arg_test_complete
            and not zero_size
        ):
            # Scalar passed as the first operand, tensor list as the second.
            scalar_self_arg_test_complete = True
            self._binary_test(
                dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
                alpha=alpha, scalar_self_arg=True, zero_size=False,
            )
            if op.supports_autograd and dtype == torch.float32 and not zero_size:
                transformed_sample = sample.transform(
                    get_transform_func(len(sample.input), dtype, device, is_fastpath))
                tensors = transformed_sample.input
                (rhs_arg,) = transformed_sample.args
                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                sum(wrapped_op(
                    [rhs_arg, tensors], is_cuda=False, is_fastpath=False, zero_size=False
                )).mean().backward()
                sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@ops(foreach_pointwise_op_db)
@parametrize("is_fastpath", (True, False))
def test_pointwise_op(self, device, dtype, op, is_fastpath):
    """Check three-operand pointwise foreach ops against their references.

    For every sample this covers: forward results (out-of-place and
    in-place), gradients, ``values`` given as a tensor of scalars including
    the expected error messages for malformed tensors, and operands that
    require implicit broadcasting.
    """
    for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
        assert isinstance(sample.args, tuple)
        assert len(sample.args) == 2
        inputs = [sample.input, *sample.args]
        zero_size = sample.kwargs.pop("zero_size")
        kwargs = sample.kwargs
        disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
        wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        values = kwargs.pop("values")
        # Forward comparison: out-of-place first, then in-place.
        self._pointwise_test(
            wrapped_op, ref, inputs, is_fastpath and not disable_fastpath, False, values=values, zero_size=zero_size
        )
        self._pointwise_test(
            inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath,
            True, values=values, zero_size=zero_size)

        if op.supports_autograd and dtype in floating_types() and not zero_size:
            # Backward comparison on fresh grad-requiring inputs.
            transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
            tensors = transformed_sample.input
            rhs_arg = transformed_sample.args
            ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
            try:
                # is_cuda/is_fastpath are False so the wrapper skips its profiler check.
                sum(
                    wrapped_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
                ).mean().backward()
            except RuntimeError:
                with self.assertRaises(RuntimeError):
                    sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
            else:
                sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
                for op_list, ref_list in zip(rhs_arg, ref_rhs_arg):
                    if isinstance(op_list, list) and isinstance(op_list[0], torch.Tensor):
                        self.assertEqual([t.grad for t in op_list], [t.grad for t in ref_list])
            # In-place backward: the trailing clone() yields non-leaf tensors
            # that can be mutated in place while tracking gradients.
            tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
            ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
            inplace_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
            assert_multiple_grad_fns(tensors, self)
            inplace_ref([ref_tensors, *rhs_arg])
            torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
            torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])

        if is_fastpath and isinstance(values, list) and not zero_size:
            # Work on detached copies so earlier in-place calls do not leak in.
            sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
            inputs = [sample.input, *sample.args]
            tensor_values = torch.tensor(values)

            # 1D Tensor of scalars
            for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
                self._pointwise_test(
                    op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                    values=tensor_values, zero_size=False)
                # 0-dim tensor of scalars.
                self._pointwise_test(
                    op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                    values=tensor_values[0],
                    custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
                    zero_size=False,
                )
                if self.is_cuda:
                    # Scalars tensor placed on the GPU.
                    self._pointwise_test(
                        op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                        values=tensor_values.cuda(),
                        custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
                        zero_size=False,
                    )
                # Scalars tensor truncated to two entries.
                self._pointwise_test(
                    op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                    values=tensor_values[:2],
                    custom_values_err=f"Expected length of scalars to match input of length {len(values)} but got 2 instead.",
                    zero_size=False,
                )
                # Column slice of a 2x2 tensor as scalars.
                self._pointwise_test(
                    op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                    values=torch.tensor([[0, 1], [2, 3]])[:, 1],
                    custom_values_err="Expected scalars to be contiguous.",
                    zero_size=False,
                )

        if not zero_size:
            # Tests of implicit broadcasting
            N = len(sample.input)
            inputs = [
                [make_tensor((N, N), device=device, dtype=dtype, noncontiguous=not is_fastpath) for _ in range(N)],
                [
                    make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=not is_fastpath)
                    for i in range(N)
                ],
                [
                    make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=not is_fastpath)
                    for i in range(N)
                ],
            ]
            # Here the fast path is expected only when `disable_fastpath` is set
            # (`is_fastpath and disable_fastpath`, not `... and not ...`).
            self._pointwise_test(
                wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
                values=values, zero_size=zero_size)
            self._pointwise_test(
                inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
                is_inplace=True, values=values, zero_size=zero_size)
def _pointwise_test(
    self,
    op, ref, inputs, is_fastpath, is_inplace,
    *,
    values=None, custom_values_err=None, zero_size,
):
    """Run a pointwise foreach op and compare it with its per-tensor reference.

    Args:
        op: ``ForeachFuncWrapper`` around the foreach function.
        ref: ``RegularFuncWrapper`` around the reference function.
        inputs: list of three operand lists.
        is_fastpath: forwarded to ``op`` for its kernel-path check.
        is_inplace: when true, the reference runs on detached clones of
            ``inputs[0]`` and the version counters of ``inputs[0]`` are checked.
        values: optional scalars appended as a fourth argument in a second pass.
        custom_values_err: when given, a ``RuntimeError`` raised in the second
            pass must carry exactly this message instead of matching the reference.
        zero_size: when true, the op is only executed; nothing is compared.
    """
    if zero_size:
        op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
        return

    if is_inplace:
        # The foreach call mutates inputs[0]; keep the reference independent.
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]]
    else:
        ref_inputs = inputs

    try:
        version_check = InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()
        with version_check:
            actual = op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
    except RuntimeError as err:
        with self.assertRaisesRegex(type(err), re.escape(str(err))):
            ref(ref_inputs)
    else:
        expected = ref(ref_inputs)
        self.assertEqual(expected, actual)

    if values is None:
        return

    try:
        actual = op(inputs + [values], self.is_cuda, is_fastpath, zero_size=zero_size)
    except RuntimeError as err:
        if custom_values_err is not None:
            self.assertEqual(re.escape(str(err)), re.escape(custom_values_err))
        else:
            # Match with error messages from regular non-foreach reference if no
            # custom error message was provided.
            with self.assertRaisesRegex(type(err), re.escape(str(err))):
                ref(ref_inputs, values=values)
    else:
        expected = ref(ref_inputs, values=values)
        self.assertEqual(expected, actual)
# note(mkozuki): why `try-except` for both fastpath?
# - inputs for fastpath can be integer tensors.
# - this is because opinfo dtypes are configured for out-place implementation
# - for integer inputs, trigonometric functions and exponential function returns float outputs,
#   which causes "result type Float can't be cast to the desired type" error.
# Thus, `try-except` is used even if `is_fastpath` is `True`.
def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath, **kwargs):
    """Compare an in-place unary foreach op with its in-place reference.

    The reference runs on detached clones of ``inputs``. If the foreach call
    raises ``RuntimeError``, the reference must raise the same type with the
    same message; otherwise both sets of tensors must end up equal.
    """
    reference_inputs = []
    for tensors in inputs:
        reference_inputs.append([t.clone().detach() for t in tensors])

    try:
        with InplaceForeachVersionBumpCheck(self, inputs[0]):
            inplace(inputs, self.is_cuda, is_fastpath, **kwargs)
    except RuntimeError as err:
        expected_message = re.escape(str(err))
        with self.assertRaisesRegex(type(err), expected_message):
            inplace_ref(reference_inputs)
        return

    inplace_ref(reference_inputs)
    self.assertEqual(reference_inputs, inputs)
@ops(foreach_unary_op_db)
@parametrize("is_fastpath", (True, False))
def test_unary_op(self, device, dtype, op, is_fastpath):
    """Check unary foreach ops against their references, forward and backward.

    Besides comparing results and gradients, the autograd part verifies the
    ``grad_fn`` structure: the out-of-place outputs share a single ``grad_fn``
    while the in-place results each have their own, and only the nodes
    actually needed fire during backward.
    """
    # `_foreach_zero` is treated as having no out-of-place variant.
    out_place_defined = op.name != "_foreach_zero"
    wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
    samples = op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)
    disable_fastpath = op.name == "_foreach_abs" and dtype in complex_types()
    for sample in samples:
        zero_size = sample.kwargs.pop('zero_size')
        inputs = [sample.input]
        if zero_size:
            # Zero-size inputs: only check that the calls run.
            # NOTE(review): `disable_fastpath` here is whatever the previous
            # iteration left behind (it is reassigned further down) — confirm intended.
            if out_place_defined:
                wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
            inplace_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
            continue
        inputs = [sample.input]
        # Short-circuit: the sample's "disable_fastpath" entry is only popped
        # when the `_foreach_abs`/complex condition is false.
        disable_fastpath = (op.name == "_foreach_abs" and dtype in complex_types()) or sample.kwargs.pop(
            "disable_fastpath"
        )
        if out_place_defined:
            self.assertEqual(
                ref(inputs),
                wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
            )
        self._inplace_unary_test(
            inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath, zero_size=zero_size
        )
        if op.supports_autograd and dtype in floating_types() and not zero_size:
            tensors = [t.clone().detach().requires_grad_() for t in sample.input]
            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
            if out_place_defined:
                out = wrapped_op.func(tensors)
                # tensors have different shapes
                torch.cat([t.view(-1) for t in out]).mean().backward()
                torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
                # All out-of-place outputs are expected to share one grad_fn.
                self.assertEqual(len({t.grad_fn for t in out}), 1)

            # Leaves plus non-leaf clones; the clones are what gets mutated in place.
            inplace_input_tensors = [t.clone().detach().requires_grad_() for t in tensors]
            inplace_inputs = [t.clone() for t in inplace_input_tensors]
            # set both to False to skip multi_tensor_apply_kernel check
            inplace_op([inplace_inputs], False, False, zero_size=zero_size)
            assert_multiple_grad_fns(inplace_inputs, self)

            # per-tensor `grad_fn` check.
            # `hook_buffer` collects the index of every tensor whose grad_fn hook ran.
            hook_buffer = []

            def get_grad_fn_hook(i):

                def hook(grad_inputs, grad_outputs) -> None:
                    hook_buffer.append(i)

                return hook

            for i, t in enumerate(inplace_inputs):
                t.grad_fn.register_hook(get_grad_fn_hook(i))

            # Differentiating only the first result must trigger only its hook.
            _ = torch.autograd.grad(
                inplace_inputs[0],
                inputs=(inplace_input_tensors[0],),
                grad_outputs=(torch.rand_like(inplace_inputs[0]),),
                retain_graph=True,
            )
            self.assertEqual(hook_buffer, [0])
            hook_buffer.clear()

            # tensors have different shapes.
            sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inplace_inputs]).sum()
            grad_output = torch.rand_like(sum_of_cloned_tensors)
            grad_inputs = torch.autograd.grad(
                sum_of_cloned_tensors,
                inputs=tuple(inplace_input_tensors),
                grad_outputs=(grad_output,),
                retain_graph=False,
            )
            # Every hook is expected to fire once, from the last tensor to the first.
            self.assertEqual(hook_buffer, list(reversed(range(len(inplace_inputs)))))
            # Same computation through the reference, using the same grad_output.
            ref_inplace_input_tensors = [t.clone().detach().requires_grad_() for t in inplace_input_tensors]
            ref_inplace_inputs = [t.clone() for t in ref_inplace_input_tensors]
            ref_output = inplace_ref([ref_inplace_inputs])
            ref_grad_inputs = torch.autograd.grad(
                torch.cat([t.view(-1) for t in ref_output]).sum(),
                inputs=tuple(ref_inplace_input_tensors),
                grad_outputs=(grad_output,),
            )
            self.assertEqual(grad_inputs, ref_grad_inputs)
@ops(foreach_reduce_op_db)
@parametrize("is_fastpath", (True, False))
def test_reduce_op(self, device, dtype, op, is_fastpath):
    """Compare foreach reductions (and their gradients) with the reference."""
    for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
        sample_kwargs = sample.kwargs
        norm_ord = sample_kwargs.pop("ord")
        zero_size = sample_kwargs.pop("zero_size")
        disable_fastpath = sample_kwargs.pop("disable_fastpath", False)

        inputs = (sample.input,)
        wrapped_op, ref, _, _ = self._get_funcs(op)
        expected = ref(inputs, ord=norm_ord)
        actual = wrapped_op(
            inputs, self.is_cuda, is_fastpath and not disable_fastpath, ord=norm_ord,
            zero_size=zero_size,
        )
        self.assertEqual(expected, actual)

        if not (op.supports_autograd and dtype in floating_types() and not zero_size):
            continue

        # Gradient comparison on fresh grad-requiring inputs; is_cuda/is_fastpath
        # are False so the wrapper skips its profiler check.
        transform = get_transform_func(len(sample.input), dtype, device, is_fastpath)
        tensors = sample.transform(transform).input
        ref_tensors = clone(tensors)
        foreach_total = sum(wrapped_op((tensors,), False, False, ord=norm_ord, zero_size=zero_size))
        foreach_total.backward()
        ref_total = sum(ref((ref_tensors,), ord=norm_ord))
        ref_total.backward()
        self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
    """Adding a scalar to a list holding one empty tensor leaves it unchanged."""
    # NOTE(review): `device` and `dtype` are not used when building the input,
    # so every parametrization runs on the same default tensor — confirm intended.
    # TODO: enable empty list case
    tensor_lists = ([torch.randn([0])],)
    for tensors in tensor_lists:
        out_of_place = torch._foreach_add(tensors, 1)
        self.assertEqual(out_of_place, tensors)

        torch._foreach_add_(tensors, 1)
        self.assertEqual(out_of_place, tensors)
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
    """Tensor-list/scalar ops must handle an expanded (self-overlapping) tensor."""
    foreach_op = op.method_variant
    ref = op.ref
    base = torch.ones(1, 1, device=device, dtype=dtype)
    tensors = [base.expand(2, 1, 3)]

    is_bool_subtraction = ref == torch.sub and dtype == torch.bool
    if not is_bool_subtraction:
        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)
        return

    # Subtraction on bool tensors is rejected by both implementations.
    expected_message = re.escape(_BOOL_SUB_ERR_MSG)
    with self.assertRaisesRegex(RuntimeError, expected_message):
        [ref(t, 1) for t in tensors]
    with self.assertRaisesRegex(RuntimeError, expected_message):
        foreach_op(tensors, 1)
@ops(foreach_binary_op_db, allowed_dtypes=[torch.float])
def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
    """A tensor list mixing float and long tensors must not raise with a scalar operand."""
    foreach_op = op.method_variant
    float_tensor = torch.tensor([1.1], dtype=torch.float, device=device)
    long_tensor = torch.tensor([1], dtype=torch.long, device=device)
    tensors = [float_tensor, long_tensor]

    caught = None
    try:
        foreach_op(tensors, 1)
    except RuntimeError as err:
        caught = err
    self.assertIsNone(caught)
@skipIfTorchDynamo("Different error msgs, TODO")
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
def test_binary_op_list_error_cases(self, device, dtype, op):
    """Check the errors raised by tensor-list/tensor-list ops on malformed inputs.

    Each scenario is exercised with the out-of-place variant first and the
    in-place variant second.
    """
    foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
    foreach_variants = (foreach_op, foreach_op_)

    def expect_error_from_both(pattern, lhs, rhs):
        for fn in foreach_variants:
            with self.assertRaisesRegex(RuntimeError, pattern):
                fn(lhs, rhs)

    tensors1 = []
    tensors2 = []

    # Empty lists
    expect_error_from_both("There were no tensor arguments to this function", tensors1, tensors2)

    # One empty list
    tensors1.append(torch.tensor([1], device=device, dtype=dtype))
    expect_error_from_both("Tensor list must have same number of elements as scalar list.", tensors1, tensors2)

    # Lists have different amount of tensors
    tensors2.extend(torch.tensor([1], device=device) for _ in range(2))
    expect_error_from_both("Tensor lists must have the same number of tensors, got 1 and 2", tensors1, tensors2)

    # Corresponding tensors with different sizes that aren't compatible with broadcast
    # If sizes are different then foreach chooses slow path, thus error messages are expected
    # to be the same as torch regular function.
    tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
    tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
    for fn, reference in ((foreach_op, ref), (foreach_op_, ref_)):
        try:
            fn(tensors1, tensors2)
        except RuntimeError as err:
            with self.assertRaisesRegex(type(err), re.escape(str(err))):
                [reference(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

    # different devices
    if self.device_type != "cuda" or torch.cuda.device_count() <= 1:
        return

    tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
    tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
    if dtype == torch.bool and foreach_op == torch._foreach_sub:
        expect_error_from_both(re.escape(_BOOL_SUB_ERR_MSG), [tensor1], [tensor2])
        return

    device_mismatch = "Expected all tensors to be on the same device"
    with self.assertRaisesRegex(RuntimeError, device_mismatch):
        foreach_op([tensor1], [tensor2])

    # In-place integer/bool division is expected to fail on the result type instead.
    if dtype in integral_types_and(torch.bool) and foreach_op == torch._foreach_div:
        inplace_pattern = "result type"
    else:
        inplace_pattern = device_mismatch
    with self.assertRaisesRegex(RuntimeError, inplace_pattern):
        foreach_op_([tensor1], [tensor2])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
def test_binary_op_list_slow_path(self, device, dtype, op):
    """Run binary foreach ops on layouts that are tested with ``is_fastpath=False``.

    Covers 0-stride, transposed, non-contiguous and sliced operands, each
    with the out-of-place variant followed by the in-place variant.
    """
    foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)

    def check_both_variants(inputs):
        self._binary_test(
            dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
            zero_size=False, alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
            zero_size=False, alpha=None, scalar_self_arg=False)

    # 0-strides
    tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
    tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
    check_both_variants(([tensor1], [tensor2]))

    # different strides
    tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
    tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
    check_both_variants(([tensor1], [tensor2.t()]))

    # non contiguous
    tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
    tensor2 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
    self.assertFalse(tensor1.is_contiguous())
    self.assertFalse(tensor2.is_contiguous())
    check_both_variants(([tensor1], [tensor2]))

    # sliced tensor
    tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
    tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[:, :, :, ::7]
    check_both_variants(([tensor1], [tensor2]))
@ops(foreach_binary_op_db, dtypes=floating_types_and(torch.half, torch.bfloat16))
def test_binary_op_float_inf_nan(self, device, dtype, op):
    """Binary foreach ops must agree with the reference on inf/nan operands."""
    inf, nan = float("inf"), float("nan")
    lhs_values = (inf, -inf, nan, nan)
    rhs_values = (-inf, inf, inf, nan)
    inputs = (
        [torch.tensor([value], device=device, dtype=dtype) for value in lhs_values],
        [torch.tensor([value], device=device, dtype=dtype) for value in rhs_values],
    )
    wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
    self._binary_test(
        dtype, wrapped_op, ref, inputs, True, False, zero_size=False, alpha=None, scalar_self_arg=False
    )
    self._binary_test(
        dtype, inplace_op, inplace_ref, inputs, True, True, zero_size=False, alpha=None, scalar_self_arg=False
    )
# note: The three tests below (suffixed with `_tensors_on_different_devices`)
# check whether foreach works with lists of tensors on different devices
# where tensors of the same index are on the same device, e.g., ['cuda', 'cpu'].
@onlyCUDA
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
    """Unary foreach ops on a list mixing a CUDA and a CPU tensor.

    The foreach result (or the ``RuntimeError`` it raises) must match the
    reference, for both the out-of-place and the in-place variant.
    """
    # `_foreach_zero` is treated as having no out-of-place variant.
    out_place_defined = op.name != "_foreach_zero"
    method, ref, inplace_method, ref_inplace = self._get_funcs(op)
    # tensors: ['cuda', 'cpu']
    tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2]))[0].input
    tensors[1] = tensors[1].to("cpu")
    if out_place_defined:
        try:
            # is_cuda/is_fastpath are False so the wrapper skips its profiler check.
            actual = method((tensors,), False, False, zero_size=False)
        except RuntimeError as e:
            # Escape the message: it is matched as a regular expression and may
            # contain metacharacters such as parentheses or brackets.
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref((tensors,))
        else:
            expected = ref((tensors,))
            self.assertEqual(expected, actual)

    try:
        inplace_method((tensors,), False, False, zero_size=False)
    except RuntimeError as e:
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            ref_inplace((tensors,))
    else:
        if out_place_defined:
            # NOTE(review): `expected` is only bound when the out-of-place call
            # above succeeded — confirm that case cannot diverge.
            self.assertEqual(expected, tensors)
        else:
            self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
@onlyCUDA
@ops(foreach_binary_op_db)
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
    """Binary foreach ops on lists whose i-th tensors share a device.

    Both operand lists are laid out as ['cuda', 'cpu']; results (or raised
    ``RuntimeError``s) must match the per-tensor reference.
    """
    cuda_sample = list(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))[0]
    cpu_sample = list(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))[0]
    # `tensors1`: ['cuda', 'cpu']
    # `tensors2`: ['cuda', 'cpu']
    tensors1, tensors2 = zip(cuda_sample.input, cpu_sample.input)

    foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
    native_op, native_op_ = op.ref, op.ref_inplace

    def per_tensor(fn):
        return [fn(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

    try:
        actual = foreach_op(tensors1, tensors2)
    except RuntimeError as err:
        with self.assertRaisesRegex(type(err), re.escape(str(err))):
            per_tensor(native_op)
    else:
        expected = per_tensor(native_op)
        self.assertEqual(expected, actual)

    try:
        foreach_op_(tensors1, tensors2)
    except RuntimeError as err:
        with self.assertRaisesRegex(type(err), re.escape(str(err))):
            per_tensor(native_op_)
    else:
        # The in-place result is compared with the out-of-place one from above.
        self.assertEqual(actual, tensors1)
@onlyCUDA
@ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
    """Pointwise foreach ops on three lists that are each laid out as ['cuda', 'cpu']."""
    # first tensorlist is zero-size when float32
    sample_index = 1 if dtype == torch.float32 else 0
    cuda_samples = list(op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True))
    cuda_tensors = cuda_samples[sample_index].input
    cpu_samples = list(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))
    cpu_tensors = cpu_samples[0].input

    # tensors1: ['cuda', 'cpu']
    # tensors2: ['cuda', 'cpu']
    # tensors3: ['cuda', 'cpu']
    tensors1, tensors2, tensors3 = zip(cuda_tensors, cpu_tensors)

    foreach_op = op.method_variant
    foreach_op_ = op.inplace_variant
    native_op = op.ref

    actual = foreach_op(tensors1, tensors2, tensors3)
    expected = [native_op(*cuda_tensors), native_op(*cpu_tensors)]
    self.assertEqual(expected, actual)

    # note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
    foreach_op_(tensors1, tensors2, tensors3)
    self.assertEqual(expected, tensors1)
# note: BFloat16 has the same number of exponent bits as FP32
# so if squared L2 norm overflows in BF16, then it also overflows in FP32.
@onlyCUDA
@ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16))
def test_foreach_l2_large_value_input(self, device, dtype, op):
    """L2 norm of inputs whose squared sum exceeds the largest finite ``dtype`` value."""
    norm_ord = 2
    num_tensors = 10
    max_value = torch.finfo(dtype).max
    scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
    # NOTE(review): `requries_grad` looks like a misspelling of `requires_grad`;
    # it is passed through unchanged here — confirm what sample_inputs does with it.
    first_sample = list(
        op.sample_inputs(device, dtype, requries_grad=True, num_input_tensors=[num_tensors], low=1)
    )[0]
    inputs = ([t * scaler for t in first_sample.input],)
    # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
    self.assertTrue(scaler * scaler * num_tensors > max_value)

    fn, ref_fn, *_ = self._get_funcs(op)
    actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=norm_ord, zero_size=False)
    expect = ref_fn(inputs, ord=norm_ord)

    if dtype != torch.float16:
        self.assertTrue(all(torch.isinf(e) for e in expect))
    else:
        # making sure the reference L2 norm values are in the range of FP16.
        self.assertFalse(any(torch.isinf(e) for e in expect))
    self.assertEqual(expect, actual, equal_nan=False)
@parametrize("is_fastpath", (True, False))
@ops(foreach_lerp_op_db)
def test_lerp(self, device, dtype, op, is_fastpath):
    # For every sample: compare the out-of-place and in-place foreach op with the
    # reference, then (for floating dtypes with autograd support) compare gradients.
    for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
        wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)

        sample_args = list(sample.args)
        inputs = [sample.input, sample_args[0]]
        zero_size = sample.kwargs.pop("zero_size")
        kwargs = {"zero_size": zero_size}
        ref_kwargs = {}
        weight = sample_args[1]
        if not isinstance(weight, list):
            # A non-list weight is forwarded as the `weight` keyword argument.
            kwargs["weight"] = weight
            ref_kwargs["weight"] = weight
        else:
            # A list weight travels as a third tensorlist.
            inputs.append(weight)

        expects_runtime_error = (
            dtype in integral_types()
            or dtype == torch.bool
            or (not self.is_cuda and dtype == torch.half)
        )
        if expects_runtime_error:
            with self.assertRaises(RuntimeError):
                wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
            # The whole test ends here; the remaining samples are not visited.
            return

        actual = wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
        expected = ref(inputs, **ref_kwargs)
        self.assertEqual(actual, expected)

        # Run the in-place variant on clones of the first tensorlist.
        cloned_self = [t.clone() for t in inputs[0]]
        inplace_inputs = [cloned_self, *inputs[1:]]
        with InplaceForeachVersionBumpCheck(self, cloned_self):
            inplace_actual = inplace_op(inplace_inputs, self.is_cuda, is_fastpath, **kwargs)
        self.assertEqual(inplace_actual, expected)

        if not (op.supports_autograd and dtype in floating_types() and not zero_size):
            continue

        # ---- gradient checks ----
        transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
        grad_args = list(transformed_sample.args)
        grad_inputs = [transformed_sample.input, grad_args[0]]
        if isinstance(grad_args[1], list):
            grad_inputs.append(grad_args[1])
            grad_kwargs, grad_ref_kwargs = {}, {}
        else:
            grad_kwargs = grad_ref_kwargs = {"weight": grad_args[1]}
        other_operands = grad_inputs[1:]

        # Out-of-place: backprop through both implementations and compare grads.
        ref_tensors = clone(transformed_sample.input)
        foreach_out = wrapped_op(
            (transformed_sample.input, *other_operands), False, False, **grad_kwargs, zero_size=zero_size
        )
        sum(foreach_out).mean().backward()
        ref_out = ref((ref_tensors, *other_operands), **grad_ref_kwargs)
        sum(ref_out).mean().backward()
        self.assertEqual(
            [t.grad for t in transformed_sample.input],
            [t.grad for t in ref_tensors],
        )

        # In-place: operate on clones of fresh leaves so the leaves can receive grads.
        leaves = [t.clone().detach().requires_grad_() for t in transformed_sample.input]
        ref_leaves = [t.clone().detach().requires_grad_() for t in leaves]
        tensors = [t.clone() for t in leaves]
        inplace_op((tensors, *other_operands), False, False, **grad_kwargs, zero_size=False)
        ref_tensors = [t.clone() for t in ref_leaves]
        inplace_ref((ref_tensors, *other_operands), **grad_ref_kwargs)
        assert_multiple_grad_fns(tensors, self)

        # tensors have different shapes.
        torch.autograd.backward(torch.cat([t.clone().view(-1) for t in tensors]).sum(), inputs=tensors)
        torch.autograd.backward(torch.cat([t.clone().view(-1) for t in ref_tensors]).sum(), inputs=ref_tensors)
        self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@onlyCUDA
@ops(foreach_reduce_op_db)
def test_foreach_reduce_large_input(self, device, dtype, op):
    # test inputs larger than kChunkSize = 65536
    norm_ord = 2
    numel = 2 * 65536
    # The fast path is expected only for ord 1 or 2 on floating dtypes (incl. half / bfloat16).
    expect_fastpath = norm_ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16)

    big_tensor = make_tensor((numel,), dtype=dtype, device=device, noncontiguous=False)
    inputs = ([big_tensor],)

    wrapped_op, ref, _, _ = self._get_funcs(op)
    expected = ref(inputs, ord=norm_ord)
    actual = wrapped_op(inputs, self.is_cuda, expect_fastpath, ord=norm_ord, zero_size=False)
    self.assertEqual(expected, actual)
@onlyCUDA
@ops(
    foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
    dtypes=(torch.float,),
)
def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
    # In-place foreach ops must reject leaf tensors that require grad, and must
    # attach a grad_fn only to those non-leaf inputs that require grad.
    inplace_op = op.inplace_variant
    if inplace_op is None:
        self.skipTest("no in-place op available")

    samples = list(op.sample_inputs(dtype=dtype, device=device, num_input_tensors=[2], same_size=True))
    sample = samples[0]

    # Turn on requires_grad for the first tensor, then additionally for the second;
    # the in-place call has to fail after each step.
    leaf_error = "a leaf Variable that requires grad"
    for position in (0, 1):
        sample.input[position].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, leaf_error):
            inplace_op(sample.input, *sample.args)

    # Only the first fresh leaf requires grad; the op runs on non-leaf clones.
    leaves = [t.clone().detach().requires_grad_(i == 0) for i, t in enumerate(sample.input)]
    tensors = [leaf.clone() for leaf in leaves]
    inplace_op(tensors, *sample.args)
    self.assertIsNotNone(tensors[0].grad_fn)
    self.assertIsNone(tensors[1].grad_fn)
@onlyCUDA
@ops(
    foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
    dtypes=(torch.float,),
)
def test_outplace_with_invalid_grads(self, device, dtype, op):
    # Backpropagate through only the first of two outputs and check that only
    # the first input receives a gradient.
    if op.name == "_foreach_zero":
        self.skipTest(f"{op.name} does not have out-place implementation")

    func, *_ = self._get_funcs(op)
    samples = list(
        op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True)
    )
    sample = samples[0]
    self.assertTrue(all(t.requires_grad for t in sample.input))

    call_kwargs = sample.kwargs
    call_kwargs.pop("disable_fastpath")
    is_addc_op = func.func in (torch._foreach_addcmul, torch._foreach_addcdiv)
    if is_addc_op and call_kwargs.get("values") is None:
        # Drop the `values` entry for addcmul / addcdiv when it carries no value.
        call_kwargs.pop("values")

    first_out, _second_out = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **call_kwargs)
    first_out.backward(torch.ones_like(first_out))
    self.assertIsNotNone(sample.input[0].grad)
    self.assertIsNone(sample.input[1].grad)
@ops(
    foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
    dtypes=OpDTypes.supported,
    allowed_dtypes=(torch.float64, torch.complex128),
)
def test_inplace_forward_mode_AD(self, device, dtype, op):
    # Run forward-mode `gradcheck` around the in-place variant for every sample;
    # samples known to fail in their forward pass must raise the matching error.
    if not op.supports_forward_ad:
        self.skipTest("forward AD not supported")

    # note(crcrpar): The combinations below are failing in its forward path,
    # which is before forward-mode AD happens. This function gates the combinations where
    # - subtraction with Scalar/ScalarList of boolean value:
    # - combinations where the in-place op in questions tries to write out complex result
    #   into float storage (= `self`)
    def check_sample_eligibility(op, sample, dtype):
        def first_arg_holds(scalar_type):
            # True when the first positional arg is, or is a list containing, `scalar_type`.
            first_arg = sample.args[0]
            if isinstance(first_arg, list):
                return any(isinstance(a, scalar_type) for a in first_arg)
            return isinstance(first_arg, scalar_type)

        if op.name == "_foreach_sub" and first_arg_holds(bool):
            return False, _BOOL_SUB_ERR_MSG

        if dtype == torch.float64 and sample.args and first_arg_holds(complex):
            errors_by_op = (
                (
                    ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"),
                    "result type ComplexDouble can't be cast to the desired output type Double",
                ),
                (
                    ("_foreach_clamp_max", "_foreach_clamp_min"),
                    "clamp is not supported for complex types",
                ),
                (
                    ("_foreach_pow",),
                    "Found dtype Double but expected ComplexDouble",
                ),
            )
            for op_names, message in errors_by_op:
                if op.name in op_names:
                    return False, message
        return True, ""

    gradcheck_options = {
        "raise_exception": True,
        "check_forward_ad": True,
        "check_backward_ad": False,
        "check_batched_grad": False,
    }

    for sample in op.sample_inputs(
        device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
    ):
        # Call `clone` to avoid inplace modifications likewise
        # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
        def inplace_func(*tensorlist):
            alpha_kwargs = {}
            if "alpha" in sample.kwargs:
                alpha_kwargs["alpha"] = sample.kwargs["alpha"]
            clones = tuple(t.clone() for t in tensorlist)
            op.inplace_variant(clones, *sample.args, **alpha_kwargs)
            # NOTE(review): this returns the untouched inputs, not the mutated
            # clones -- confirm that is the intended gradcheck target.
            return tensorlist

        working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
        if working_sample:
            expectation = nullcontext()
        else:
            expectation = self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern))
        with expectation:
            gradcheck(inplace_func, sample.input, **gradcheck_options)
@unittest.skipIf(not (torch.cuda.is_available() and torch.cuda.device_count() > 1), "requires multiple GPUs")
def test_tensors_grouping(self):
    # Build three parallel lists (tensors, Nones, tensors) whose tensors sit on
    # randomly chosen CUDA devices with randomly chosen dtypes, group them with
    # `_group_tensors_by_device_and_dtype`, and check that each group is
    # homogeneous and that the returned indices map back to the original lists.
    #
    # All checks use `assertEqual`: the `assertEquals` alias is deprecated and
    # was removed from `unittest.TestCase` in Python 3.12.
    num_tensors_per_list = 10
    num_devices = torch.cuda.device_count()
    dtypes = (torch.float16, torch.float32, torch.float64)
    list1 = [
        torch.tensor(
            i,
            device=torch.device("cuda", random.randint(0, num_devices - 1)),
            dtype=dtypes[random.randint(0, 2)],
        ) for i in range(num_tensors_per_list)
    ]
    list2 = [None for _ in list1]
    list3 = [torch.rand_like(t) for t in list1]
    nested_tensorlists = [list1, list2, list3]
    grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(nested_tensorlists, with_indices=True)
    num_tensors_seen = 0
    for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
        # Every tensor in a group must match the group's (device, dtype) key.
        for t in itertools.chain(l1, l3):
            self.assertEqual(t.device, device)
            self.assertEqual(t.dtype, dtype)
            num_tensors_seen += 1
        self.assertEqual(len(l1), len(l2))
        self.assertTrue(all(p is None for p in l2))
        # `indices[i]` is the position of the i-th grouped entry in the input lists.
        for i, index in enumerate(indices):
            self.assertEqual(l1[i], list1[index])
            self.assertEqual(l2[i], list2[index])
            self.assertEqual(l3[i], list3[index])
    # list1 and list3 each contribute `num_tensors_per_list` tensors.
    self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
# Hand TestForeach and this module's globals to the device-type test instantiation helper.
instantiate_device_type_tests(TestForeach, globals())
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()