mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
1022 lines
50 KiB
Python
1022 lines
50 KiB
Python
# Owner(s): ["module: mta"]
|
|
|
|
from contextlib import nullcontext
|
|
from numbers import Number
|
|
import random
|
|
import re
|
|
import torch
|
|
import unittest
|
|
import itertools
|
|
|
|
from torch.testing import make_tensor
|
|
from torch.testing._comparison import default_tolerances
|
|
from torch.testing._internal.common_utils import \
|
|
TestCase, run_tests, TEST_WITH_ROCM, skipIfTorchDynamo, parametrize, gradcheck
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, dtypes, onlyCUDA, ops, OpDTypes)
|
|
from torch.testing._internal.common_methods_invocations import (
|
|
foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db,
|
|
foreach_reduce_op_db, foreach_lerp_op_db)
|
|
from torch.testing._internal.common_dtype import (
|
|
all_types_and_complex_and, integral_types, complex_types,
|
|
floating_types_and, floating_types, integral_types_and,
|
|
)
|
|
|
|
|
|
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
|
|
|
|
|
|
class RegularFuncWrapper:
|
|
def __init__(self, func):
|
|
self.func = func
|
|
|
|
def __call__(self, inputs, values=None, **kwargs):
|
|
if values is not None:
|
|
assert len(inputs) == 3
|
|
if isinstance(values, Number):
|
|
values = [values for _ in range(len(inputs[0]))]
|
|
return [self.func(*i, value=values[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
|
|
if len(inputs) == 2 and isinstance(inputs[1], Number):
|
|
# binary op with tensorlist and scalar.
|
|
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
|
|
return [self.func(*i, **kwargs) for i in zip(*inputs)]
|
|
|
|
|
|
class ForeachFuncWrapper:
|
|
def __init__(self, func):
|
|
self.func = func
|
|
# Some foreach functions don't have in-place implementations.
|
|
self.is_inplace = False if func is None else func.__name__.endswith('_')
|
|
|
|
def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
|
|
actual = None
|
|
zero_size = kwargs.pop("zero_size")
|
|
if (
|
|
is_cuda and
|
|
torch.autograd.kineto_available() and
|
|
torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities()
|
|
):
|
|
with torch.profiler.profile() as p:
|
|
actual = self.func(*inputs, **kwargs)
|
|
keys = tuple([e.key for e in p.key_averages()])
|
|
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
|
|
assert mta_called == (is_fastpath and (not zero_size))
|
|
else:
|
|
actual = self.func(*inputs, **kwargs)
|
|
# note(mkozuki): inplace foreach functions are void functions.
|
|
return inputs[0] if self.is_inplace else actual
|
|
|
|
|
|
class InplaceForeachVersionBumpCheck:
|
|
|
|
def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:
|
|
self._testcase = testcase
|
|
self._tensorlist = tensorlist
|
|
self._orig_version_counts = [t._version for t in tensorlist]
|
|
|
|
def __enter__(self):
|
|
pass
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
# note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
|
|
self._testcase.assertGreaterEqual([t._version for t in self._tensorlist], self._orig_version_counts)
|
|
|
|
|
|
def get_transform_func(num_tensors, dtype, device, is_fastpath):
|
|
def transform(t):
|
|
if not torch.is_tensor(t):
|
|
return t
|
|
return make_tensor(
|
|
(num_tensors, num_tensors), dtype=dtype, device=device,
|
|
requires_grad=True, noncontiguous=not is_fastpath,
|
|
)
|
|
|
|
return transform
|
|
|
|
|
|
def assert_multiple_grad_fns(tensors, test_case):
|
|
test_case.assertEqual(len({t.grad_fn for t in tensors}), len(tensors), msg=f"{[t.grad_fn for t in tensors]}")
|
|
|
|
|
|
def clone(arg):
|
|
if isinstance(arg, (list, tuple)):
|
|
return [clone(a) for a in arg]
|
|
if torch.is_tensor(arg):
|
|
return arg.clone().detach().requires_grad_()
|
|
else:
|
|
return arg
|
|
|
|
|
|
# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
|
|
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
|
|
class TestForeach(TestCase):
|
|
@property
|
|
def is_cuda(self):
|
|
return self.device_type == 'cuda'
|
|
|
|
def _get_funcs(self, op):
|
|
return (
|
|
ForeachFuncWrapper(op.method_variant),
|
|
RegularFuncWrapper(op.ref),
|
|
ForeachFuncWrapper(op.inplace_variant),
|
|
RegularFuncWrapper(op.ref_inplace),
|
|
)
|
|
|
|
def _binary_test(
|
|
self,
|
|
dtype, op, ref, inputs, is_fastpath, is_inplace,
|
|
*,
|
|
alpha, scalar_self_arg: bool, zero_size: bool,
|
|
):
|
|
if zero_size:
|
|
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
|
|
op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
|
|
return
|
|
|
|
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
|
|
try:
|
|
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
|
|
actual = op(inputs, self.is_cuda, is_fastpath, zero_size=zero_size)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
if not scalar_self_arg:
|
|
ref(ref_inputs)
|
|
else:
|
|
[ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
|
|
else:
|
|
expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
|
|
self.assertEqual(actual, expected)
|
|
if alpha is not None and not scalar_self_arg:
|
|
kwargs = {'alpha': alpha}
|
|
ref_inputs = inputs
|
|
try:
|
|
op_kwargs = {}
|
|
op_kwargs.update(kwargs)
|
|
op_kwargs['zero_size'] = zero_size
|
|
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
|
|
actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
ref(ref_inputs, **kwargs)
|
|
else:
|
|
expected = ref(ref_inputs, **kwargs)
|
|
if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
|
|
self.assertEqual(expected, actual, atol=1.e-3, rtol=default_tolerances(dtype)[0])
|
|
else:
|
|
self.assertEqual(expected, actual)
|
|
|
|
@ops(foreach_binary_op_db)
|
|
@parametrize("is_fastpath", (True, False))
|
|
def test_binary_op(self, device, dtype, op, is_fastpath):
|
|
scalar_self_arg_test_complete = False
|
|
for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
|
|
(rhs_arg,) = sample.args
|
|
zero_size = sample.kwargs.pop("zero_size")
|
|
kwargs = {} or sample.kwargs
|
|
alpha = kwargs.pop("alpha", None)
|
|
disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
|
|
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
|
self._binary_test(
|
|
dtype, wrapped_op, ref, [sample.input, rhs_arg],
|
|
is_fastpath and not disable_fastpath, False,
|
|
alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
|
|
)
|
|
self._binary_test(
|
|
dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
|
|
is_fastpath and not disable_fastpath, True,
|
|
alpha=alpha, zero_size=zero_size, scalar_self_arg=False,
|
|
)
|
|
|
|
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
|
transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
|
tensors = transformed_sample.input
|
|
(rhs_arg,) = transformed_sample.args
|
|
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
|
|
try:
|
|
sum(
|
|
wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
|
).mean().backward()
|
|
except RuntimeError:
|
|
with self.assertRaises(RuntimeError):
|
|
sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
|
|
else:
|
|
sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
|
|
self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
|
|
tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
|
ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
|
inplace_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
|
assert_multiple_grad_fns(tensors, self)
|
|
|
|
# note(crcrpar): the following ops' reference torch functions don't have the overload with Scalar/ScalarList.
|
|
is_foreach_max_min_imum_with_scalar_or_scalarlist = (
|
|
inplace_op.func in (torch._foreach_minimum_, torch._foreach_maximum_)
|
|
and (
|
|
isinstance(rhs_arg, Number) or (isinstance(rhs_arg, list) and isinstance(rhs_arg[0], Number))
|
|
)
|
|
)
|
|
if not is_foreach_max_min_imum_with_scalar_or_scalarlist:
|
|
inplace_ref([ref_tensors, rhs_arg])
|
|
torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
|
|
torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
if (
|
|
op.supports_scalar_self_arg
|
|
and isinstance(rhs_arg, Number)
|
|
and not scalar_self_arg_test_complete
|
|
and not zero_size
|
|
):
|
|
scalar_self_arg_test_complete = True
|
|
self._binary_test(
|
|
dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
|
|
alpha=alpha, scalar_self_arg=True, zero_size=False,
|
|
)
|
|
if op.supports_autograd and dtype == torch.float32 and not zero_size:
|
|
transformed_sample = sample.transform(
|
|
get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
|
tensors = transformed_sample.input
|
|
(rhs_arg,) = transformed_sample.args
|
|
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
|
|
sum(wrapped_op(
|
|
[rhs_arg, tensors], is_cuda=False, is_fastpath=False, zero_size=False
|
|
)).mean().backward()
|
|
sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
|
|
@ops(foreach_pointwise_op_db)
|
|
@parametrize("is_fastpath", (True, False))
|
|
def test_pointwise_op(self, device, dtype, op, is_fastpath):
|
|
for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
|
|
assert isinstance(sample.args, tuple)
|
|
assert len(sample.args) == 2
|
|
inputs = [sample.input, *sample.args]
|
|
zero_size = sample.kwargs.pop("zero_size")
|
|
kwargs = sample.kwargs
|
|
disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
|
|
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
|
values = kwargs.pop("values")
|
|
self._pointwise_test(
|
|
wrapped_op, ref, inputs, is_fastpath and not disable_fastpath, False, values=values, zero_size=zero_size
|
|
)
|
|
self._pointwise_test(
|
|
inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath,
|
|
True, values=values, zero_size=zero_size)
|
|
|
|
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
|
transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
|
tensors = transformed_sample.input
|
|
rhs_arg = transformed_sample.args
|
|
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
|
|
try:
|
|
sum(
|
|
wrapped_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
|
).mean().backward()
|
|
except RuntimeError:
|
|
with self.assertRaises(RuntimeError):
|
|
sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
|
|
else:
|
|
sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
for op_list, ref_list in zip(rhs_arg, ref_rhs_arg):
|
|
if isinstance(op_list, list) and isinstance(op_list[0], torch.Tensor):
|
|
self.assertEqual([t.grad for t in op_list], [t.grad for t in ref_list])
|
|
tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
|
ref_tensors = [t.clone().detach().requires_grad_().clone() for t in tensors]
|
|
inplace_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False, zero_size=zero_size)
|
|
assert_multiple_grad_fns(tensors, self)
|
|
inplace_ref([ref_tensors, *rhs_arg])
|
|
torch.autograd.backward(sum([t.clone() for t in tensors]).sum(), inputs=tensors)
|
|
torch.autograd.backward(sum([t.clone() for t in ref_tensors]).sum(), inputs=ref_tensors)
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
|
|
if is_fastpath and isinstance(values, list) and not zero_size:
|
|
sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
|
|
inputs = [sample.input, *sample.args]
|
|
tensor_values = torch.tensor(values)
|
|
# 1D Tensor of scalars
|
|
for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
|
|
self._pointwise_test(
|
|
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
|
|
values=tensor_values, zero_size=False)
|
|
self._pointwise_test(
|
|
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
|
|
values=tensor_values[0],
|
|
custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
|
|
zero_size=False,
|
|
)
|
|
if self.is_cuda:
|
|
self._pointwise_test(
|
|
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
|
|
values=tensor_values.cuda(),
|
|
custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
|
|
zero_size=False,
|
|
)
|
|
self._pointwise_test(
|
|
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
|
|
values=tensor_values[:2],
|
|
custom_values_err=f"Expected length of scalars to match input of length {len(values)} but got 2 instead.",
|
|
zero_size=False,
|
|
)
|
|
self._pointwise_test(
|
|
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
|
|
values=torch.tensor([[0, 1], [2, 3]])[:, 1],
|
|
custom_values_err="Expected scalars to be contiguous.",
|
|
zero_size=False,
|
|
)
|
|
|
|
if not zero_size:
|
|
# Tests of implicit broadcasting
|
|
N = len(sample.input)
|
|
inputs = [
|
|
[make_tensor((N, N), device=device, dtype=dtype, noncontiguous=not is_fastpath) for _ in range(N)],
|
|
[
|
|
make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=not is_fastpath)
|
|
for i in range(N)
|
|
],
|
|
[
|
|
make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=not is_fastpath)
|
|
for i in range(N)
|
|
],
|
|
]
|
|
self._pointwise_test(
|
|
wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
|
|
values=values, zero_size=zero_size)
|
|
self._pointwise_test(
|
|
inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
|
|
is_inplace=True, values=values, zero_size=zero_size)
|
|
|
|
def _pointwise_test(
|
|
self,
|
|
op, ref, inputs, is_fastpath, is_inplace,
|
|
*,
|
|
values=None, custom_values_err=None, zero_size,
|
|
):
|
|
kwargs = {'zero_size': zero_size}
|
|
if zero_size:
|
|
op(inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
return
|
|
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs
|
|
try:
|
|
with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()):
|
|
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
ref(ref_inputs)
|
|
else:
|
|
expected = ref(ref_inputs)
|
|
self.assertEqual(expected, actual)
|
|
if values is not None:
|
|
try:
|
|
actual = op(inputs + [values], self.is_cuda, is_fastpath, **kwargs)
|
|
except RuntimeError as e:
|
|
# Match with error messages from regular non-foreach reference if no
|
|
# custom error message was provided.
|
|
if custom_values_err is None:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
ref(ref_inputs, values=values)
|
|
else:
|
|
self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
|
|
else:
|
|
expected = ref(ref_inputs, values=values)
|
|
self.assertEqual(expected, actual)
|
|
|
|
# note(mkozuki): why `try-except` for both fastpath?
|
|
# - inputs for fastpath can be integer tensors.
|
|
# - this is because opinfo dtypes are configured for out-place implementation
|
|
# - for integer inputs, trigonometric functions and exponential function returns float outputs,
|
|
# which causes "result type Float can't be case to the desired type" error.
|
|
# Thus, `try-except` is used even if `is_fastpath` is `True`.
|
|
def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath, **kwargs):
|
|
copied_inputs = [[t.clone().detach() for t in tensors] for tensors in inputs]
|
|
try:
|
|
with InplaceForeachVersionBumpCheck(self, inputs[0]):
|
|
inplace(inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
inplace_ref(copied_inputs)
|
|
else:
|
|
inplace_ref(copied_inputs)
|
|
self.assertEqual(copied_inputs, inputs)
|
|
|
|
@ops(foreach_unary_op_db)
|
|
@parametrize("is_fastpath", (True, False))
|
|
def test_unary_op(self, device, dtype, op, is_fastpath):
|
|
out_place_defined = op.name != "_foreach_zero"
|
|
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
|
samples = op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)
|
|
disable_fastpath = op.name == "_foreach_abs" and dtype in complex_types()
|
|
for sample in samples:
|
|
zero_size = sample.kwargs.pop('zero_size')
|
|
inputs = [sample.input]
|
|
if zero_size:
|
|
if out_place_defined:
|
|
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
|
|
inplace_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size)
|
|
continue
|
|
inputs = [sample.input]
|
|
disable_fastpath = (op.name == "_foreach_abs" and dtype in complex_types()) or sample.kwargs.pop(
|
|
"disable_fastpath"
|
|
)
|
|
if out_place_defined:
|
|
self.assertEqual(
|
|
ref(inputs),
|
|
wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, zero_size=zero_size),
|
|
)
|
|
self._inplace_unary_test(
|
|
inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath, zero_size=zero_size
|
|
)
|
|
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
|
tensors = [t.clone().detach().requires_grad_() for t in sample.input]
|
|
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
|
|
if out_place_defined:
|
|
out = wrapped_op.func(tensors)
|
|
# tensors have different shapes
|
|
torch.cat([t.view(-1) for t in out]).mean().backward()
|
|
torch.cat([ref.func(t).view(-1) for t in ref_tensors]).mean().backward()
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
self.assertEqual(len({t.grad_fn for t in out}), 1)
|
|
|
|
inplace_input_tensors = [t.clone().detach().requires_grad_() for t in tensors]
|
|
inplace_inputs = [t.clone() for t in inplace_input_tensors]
|
|
# set both to False to skip multi_tensor_apply_kernel check
|
|
inplace_op([inplace_inputs], False, False, zero_size=zero_size)
|
|
assert_multiple_grad_fns(inplace_inputs, self)
|
|
|
|
# per-tensor `grad_fn` check.
|
|
hook_buffer = []
|
|
|
|
def get_grad_fn_hook(i):
|
|
|
|
def hook(grad_inputs, grad_outputs) -> None:
|
|
hook_buffer.append(i)
|
|
|
|
return hook
|
|
|
|
for i, t in enumerate(inplace_inputs):
|
|
t.grad_fn.register_hook(get_grad_fn_hook(i))
|
|
|
|
_ = torch.autograd.grad(
|
|
inplace_inputs[0],
|
|
inputs=(inplace_input_tensors[0],),
|
|
grad_outputs=(torch.rand_like(inplace_inputs[0]),),
|
|
retain_graph=True,
|
|
)
|
|
self.assertEqual(hook_buffer, [0])
|
|
hook_buffer.clear()
|
|
|
|
# tensors have different shapes.
|
|
sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inplace_inputs]).sum()
|
|
grad_output = torch.rand_like(sum_of_cloned_tensors)
|
|
grad_inputs = torch.autograd.grad(
|
|
sum_of_cloned_tensors,
|
|
inputs=tuple(inplace_input_tensors),
|
|
grad_outputs=(grad_output,),
|
|
retain_graph=False,
|
|
)
|
|
self.assertEqual(hook_buffer, list(reversed(range(len(inplace_inputs)))))
|
|
|
|
ref_inplace_input_tensors = [t.clone().detach().requires_grad_() for t in inplace_input_tensors]
|
|
ref_inplace_inputs = [t.clone() for t in ref_inplace_input_tensors]
|
|
ref_output = inplace_ref([ref_inplace_inputs])
|
|
ref_grad_inputs = torch.autograd.grad(
|
|
torch.cat([t.view(-1) for t in ref_output]).sum(),
|
|
inputs=tuple(ref_inplace_input_tensors),
|
|
grad_outputs=(grad_output,),
|
|
)
|
|
self.assertEqual(grad_inputs, ref_grad_inputs)
|
|
|
|
@ops(foreach_reduce_op_db)
|
|
@parametrize("is_fastpath", (True, False))
|
|
def test_reduce_op(self, device, dtype, op, is_fastpath):
|
|
for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
|
|
ord = sample.kwargs.pop("ord")
|
|
zero_size = sample.kwargs.pop("zero_size")
|
|
disable_fastpath = sample.kwargs.pop("disable_fastpath", False)
|
|
|
|
inputs = (sample.input,)
|
|
wrapped_op, ref, _, _ = self._get_funcs(op)
|
|
|
|
self.assertEqual(
|
|
ref(inputs, ord=ord),
|
|
wrapped_op(
|
|
inputs, self.is_cuda, is_fastpath and not disable_fastpath, ord=ord,
|
|
zero_size=zero_size,
|
|
),
|
|
)
|
|
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
|
transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
|
tensors = transformed_sample.input
|
|
ref_tensors = clone(tensors)
|
|
sum(wrapped_op((tensors,), False, False, ord=ord, zero_size=zero_size)).backward()
|
|
sum(ref((ref_tensors,), ord=ord)).backward()
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
|
|
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
|
def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
|
|
# TODO: enable empty list case
|
|
for tensors in [[torch.randn([0])]]:
|
|
res = torch._foreach_add(tensors, 1)
|
|
self.assertEqual(res, tensors)
|
|
|
|
torch._foreach_add_(tensors, 1)
|
|
self.assertEqual(res, tensors)
|
|
|
|
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
|
|
def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
|
|
foreach_op, ref = op.method_variant, op.ref
|
|
tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
|
|
|
|
if ref == torch.sub and dtype == torch.bool:
|
|
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
|
|
[ref(t, 1) for t in tensors]
|
|
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
|
|
foreach_op(tensors, 1)
|
|
return
|
|
|
|
expected = [ref(t, 1) for t in tensors]
|
|
res = foreach_op(tensors, 1)
|
|
self.assertEqual(res, expected)
|
|
|
|
@ops(foreach_binary_op_db, allowed_dtypes=[torch.float])
|
|
def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
|
|
foreach_op = op.method_variant
|
|
tensors = [
|
|
torch.tensor([1.1], dtype=torch.float, device=device),
|
|
torch.tensor([1], dtype=torch.long, device=device),
|
|
]
|
|
runtime_error = None
|
|
try:
|
|
foreach_op(tensors, 1)
|
|
except RuntimeError as e:
|
|
runtime_error = e
|
|
self.assertIsNone(runtime_error)
|
|
|
|
@skipIfTorchDynamo("Different error msgs, TODO")
|
|
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
|
|
def test_binary_op_list_error_cases(self, device, dtype, op):
|
|
foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
|
|
tensors1 = []
|
|
tensors2 = []
|
|
|
|
# Empty lists
|
|
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
|
|
foreach_op(tensors1, tensors2)
|
|
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
|
|
foreach_op_(tensors1, tensors2)
|
|
|
|
# One empty list
|
|
tensors1.append(torch.tensor([1], device=device, dtype=dtype))
|
|
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
|
|
foreach_op(tensors1, tensors2)
|
|
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
|
|
foreach_op_(tensors1, tensors2)
|
|
|
|
# Lists have different amount of tensors
|
|
tensors2.append(torch.tensor([1], device=device))
|
|
tensors2.append(torch.tensor([1], device=device))
|
|
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
|
|
foreach_op(tensors1, tensors2)
|
|
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
|
|
foreach_op_(tensors1, tensors2)
|
|
|
|
# Corresponding tensors with different sizes that aren't compatible with broadcast
|
|
# If sizes are different then foreach chooses slow path, thus error messages are expected
|
|
# to be the same as torch regular function.
|
|
tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
|
|
tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
|
|
try:
|
|
foreach_op(tensors1, tensors2)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
[ref(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
|
|
try:
|
|
foreach_op_(tensors1, tensors2)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
[ref_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
|
|
|
|
# different devices
|
|
if self.device_type == "cuda" and torch.cuda.device_count() > 1:
|
|
tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
|
|
tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
|
|
if dtype == torch.bool and foreach_op == torch._foreach_sub:
|
|
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
|
|
foreach_op([tensor1], [tensor2])
|
|
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
|
|
foreach_op_([tensor1], [tensor2])
|
|
return
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
|
|
foreach_op([tensor1], [tensor2])
|
|
if dtype in integral_types_and(torch.bool) and foreach_op == torch._foreach_div:
|
|
with self.assertRaisesRegex(RuntimeError, "result type"):
|
|
foreach_op_([tensor1], [tensor2])
|
|
else:
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
|
|
foreach_op_([tensor1], [tensor2])
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
|
@ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
|
|
def test_binary_op_list_slow_path(self, device, dtype, op):
|
|
foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
|
|
# 0-strides
|
|
tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
|
|
tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
|
|
inputs = ([tensor1], [tensor2])
|
|
self._binary_test(
|
|
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
self._binary_test(
|
|
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
|
|
# different strides
|
|
tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
|
|
tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
|
|
inputs = ([tensor1], [tensor2.t()])
|
|
self._binary_test(
|
|
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
self._binary_test(
|
|
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
|
|
# non contiguous
|
|
tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
|
|
tensor2 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
|
|
self.assertFalse(tensor1.is_contiguous())
|
|
self.assertFalse(tensor2.is_contiguous())
|
|
inputs = ([tensor1], [tensor2])
|
|
self._binary_test(
|
|
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
self._binary_test(
|
|
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
|
|
# sliced tensor
|
|
tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
|
|
tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[:, :, :, ::7]
|
|
inputs = ([tensor1], [tensor2])
|
|
self._binary_test(
|
|
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
self._binary_test(
|
|
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
|
|
zero_size=False, alpha=None, scalar_self_arg=False)
|
|
|
|
@ops(foreach_binary_op_db, dtypes=floating_types_and(torch.half, torch.bfloat16))
|
|
def test_binary_op_float_inf_nan(self, device, dtype, op):
|
|
inputs = (
|
|
[
|
|
torch.tensor([float("inf")], device=device, dtype=dtype),
|
|
torch.tensor([-float("inf")], device=device, dtype=dtype),
|
|
torch.tensor([float("nan")], device=device, dtype=dtype),
|
|
torch.tensor([float("nan")], device=device, dtype=dtype),
|
|
],
|
|
[
|
|
torch.tensor([-float("inf")], device=device, dtype=dtype),
|
|
torch.tensor([float("inf")], device=device, dtype=dtype),
|
|
torch.tensor([float("inf")], device=device, dtype=dtype),
|
|
torch.tensor([float("nan")], device=device, dtype=dtype),
|
|
],
|
|
)
|
|
op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
|
self._binary_test(dtype, op, ref, inputs, True, False, zero_size=False, alpha=None, scalar_self_arg=False)
|
|
self._binary_test(
|
|
dtype, inplace_op, inplace_ref, inputs, True, True, zero_size=False, alpha=None, scalar_self_arg=False
|
|
)
|
|
|
|
# note: Below three tests (postfixed with `_tensors_on_different_devices`)
|
|
# checks whether foreach works with lists of tensors on different devices
|
|
# but tensors of the same index are on the same device, e.g., ['cuda', 'cpu].
|
|
@onlyCUDA
|
|
@ops(foreach_unary_op_db)
|
|
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
|
|
out_place_defined = op.name != "_foreach_zero"
|
|
method, ref, inplace_method, ref_inplace = self._get_funcs(op)
|
|
# tensors: ['cuda', 'cpu]
|
|
tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2]))[0].input
|
|
tensors[1] = tensors[1].to("cpu")
|
|
if out_place_defined:
|
|
try:
|
|
actual = method((tensors,), False, False, zero_size=False)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), str(e)):
|
|
ref((tensors,))
|
|
else:
|
|
expected = ref((tensors,))
|
|
self.assertEqual(expected, actual)
|
|
|
|
try:
|
|
inplace_method((tensors,), False, False, zero_size=False)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), str(e)):
|
|
ref_inplace((tensors,))
|
|
else:
|
|
if out_place_defined:
|
|
self.assertEqual(expected, tensors)
|
|
else:
|
|
self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
|
|
|
|
@onlyCUDA
|
|
@ops(foreach_binary_op_db)
|
|
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
|
|
# `tensors1`: ['cuda', 'cpu']
|
|
# `tensors2`: ['cuda', 'cpu']
|
|
_cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))[0].input
|
|
_cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))[0].input
|
|
tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))
|
|
|
|
foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
|
|
native_op, native_op_ = op.ref, op.ref_inplace
|
|
try:
|
|
actual = foreach_op(tensors1, tensors2)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
[native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
|
|
else:
|
|
expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
|
|
self.assertEqual(expected, actual)
|
|
try:
|
|
foreach_op_(tensors1, tensors2)
|
|
except RuntimeError as e:
|
|
with self.assertRaisesRegex(type(e), re.escape(str(e))):
|
|
[native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
|
|
else:
|
|
self.assertEqual(actual, tensors1)
|
|
|
|
@onlyCUDA
|
|
@ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
|
|
def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
|
|
# tensors1: ['cuda', 'cpu]
|
|
# tensors2: ['cuda', 'cpu]
|
|
# tensors3: ['cuda', 'cpu]
|
|
# first tensorlist is zero-size when float32
|
|
_cuda_tensors = list(
|
|
op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True)
|
|
)[int(dtype == torch.float32)].input
|
|
_cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))[0].input
|
|
tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))
|
|
|
|
foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
|
|
actual = foreach_op(tensors1, tensors2, tensors3)
|
|
expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
|
|
self.assertEqual(expected, actual)
|
|
|
|
# note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
|
|
foreach_op_(tensors1, tensors2, tensors3)
|
|
self.assertEqual(expected, tensors1)
|
|
|
|
# note: BFloat16 has the same number of exponent bits as FP32
|
|
# so if squared L2 norm overflows in BF16, then it also overflows in FP32.
|
|
@onlyCUDA
|
|
@ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16))
|
|
def test_foreach_l2_large_value_input(self, device, dtype, op):
|
|
ord, N = 2, 10
|
|
max_value = torch.finfo(dtype).max
|
|
scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
|
|
inputs = ([
|
|
t * scaler for t in list(
|
|
op.sample_inputs(device, dtype, requries_grad=True, num_input_tensors=[N], low=1)
|
|
)[0].input
|
|
],)
|
|
# make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
|
|
self.assertTrue(scaler * scaler * N > max_value)
|
|
fn, ref_fn, *_ = self._get_funcs(op)
|
|
actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=ord, zero_size=False)
|
|
expect = ref_fn(inputs, ord=ord)
|
|
if dtype == torch.float16:
|
|
# making sure the reference L2 norm values are in the range of FP16.
|
|
self.assertFalse(any(torch.isinf(e) for e in expect))
|
|
else:
|
|
self.assertTrue(all(torch.isinf(e) for e in expect))
|
|
self.assertEqual(expect, actual, equal_nan=False)
|
|
|
|
@parametrize("is_fastpath", (True, False))
|
|
@ops(foreach_lerp_op_db)
|
|
def test_lerp(self, device, dtype, op, is_fastpath):
|
|
for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
|
|
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
|
|
args = [*sample.args]
|
|
inputs = [sample.input, args[0]]
|
|
zero_size = sample.kwargs.pop("zero_size")
|
|
|
|
kwargs, ref_kwargs = {"zero_size": zero_size}, {}
|
|
if isinstance(args[1], list):
|
|
inputs.append(args[1])
|
|
else:
|
|
kwargs["weight"] = args[1]
|
|
ref_kwargs["weight"] = args[1]
|
|
|
|
if dtype in integral_types() or dtype == torch.bool or (not self.is_cuda and dtype == torch.half):
|
|
with self.assertRaises(RuntimeError):
|
|
wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
return
|
|
actual = wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
expected = ref(inputs, **ref_kwargs)
|
|
self.assertEqual(actual, expected)
|
|
|
|
inplace_inputs = [[t.clone() for t in inputs[0]]] + inputs[1:]
|
|
with InplaceForeachVersionBumpCheck(self, inplace_inputs[0]):
|
|
inplace_actual = inplace_op(inplace_inputs, self.is_cuda, is_fastpath, **kwargs)
|
|
self.assertEqual(inplace_actual, expected)
|
|
|
|
if op.supports_autograd and dtype in floating_types() and not zero_size:
|
|
transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
|
|
args = [*transformed_sample.args]
|
|
inputs = [transformed_sample.input, args[0]]
|
|
|
|
kwargs, ref_kwargs = {}, {}
|
|
if isinstance(args[1], list):
|
|
inputs.append(args[1])
|
|
else:
|
|
kwargs = ref_kwargs = {"weight": args[1]}
|
|
ref_tensors = clone(transformed_sample.input)
|
|
sum(
|
|
wrapped_op((transformed_sample.input, *inputs[1:]), False, False, **kwargs, zero_size=zero_size)
|
|
).mean().backward()
|
|
sum(ref((ref_tensors, *inputs[1:]), **ref_kwargs)).mean().backward()
|
|
self.assertEqual(
|
|
[t.grad for t in transformed_sample.input],
|
|
[t.grad for t in ref_tensors],
|
|
)
|
|
_tensors = [t.clone().detach().requires_grad_() for t in transformed_sample.input]
|
|
_ref_tensors = [t.clone().detach().requires_grad_() for t in _tensors]
|
|
tensors = [t.clone() for t in _tensors]
|
|
inplace_op((tensors, *inputs[1:]), False, False, **kwargs, zero_size=False)
|
|
ref_tensors = [t.clone() for t in _ref_tensors]
|
|
inplace_ref((ref_tensors, *inputs[1:]), **ref_kwargs)
|
|
assert_multiple_grad_fns(tensors, self)
|
|
|
|
# tensors have different shapes.
|
|
torch.autograd.backward(torch.cat([t.clone().view(-1) for t in tensors]).sum(), inputs=tensors)
|
|
torch.autograd.backward(torch.cat([t.clone().view(-1) for t in ref_tensors]).sum(), inputs=ref_tensors)
|
|
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
|
|
|
|
@onlyCUDA
|
|
@ops(foreach_reduce_op_db)
|
|
def test_foreach_reduce_large_input(self, device, dtype, op):
|
|
# test inputs larger than kChunkSize = 65536
|
|
ord, N = 2, 65536 * 2
|
|
disable_fastpath = True
|
|
if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
|
|
disable_fastpath = False
|
|
inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
|
|
wrapped_op, ref, _, _ = self._get_funcs(op)
|
|
self.assertEqual(
|
|
ref(inputs, ord=ord),
|
|
wrapped_op(inputs, self.is_cuda, not disable_fastpath, ord=ord, zero_size=False),
|
|
)
|
|
|
|
@onlyCUDA
|
|
@ops(
|
|
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
|
|
dtypes=(torch.float,),
|
|
)
|
|
def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
|
|
inplace_op = op.inplace_variant
|
|
if inplace_op is None:
|
|
self.skipTest("no in-place op available")
|
|
|
|
sample = list(op.sample_inputs(dtype=dtype, device=device, num_input_tensors=[2], same_size=True))[0]
|
|
sample.input[0].requires_grad_(True)
|
|
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
|
|
inplace_op(sample.input, *sample.args)
|
|
sample.input[1].requires_grad_(True)
|
|
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
|
|
inplace_op(sample.input, *sample.args)
|
|
|
|
_tensors = [t.clone().detach().requires_grad_(i == 0) for i, t in enumerate(sample.input)]
|
|
tensors = [t.clone() for t in _tensors]
|
|
inplace_op(tensors, *sample.args)
|
|
self.assertIsNotNone(tensors[0].grad_fn)
|
|
self.assertIsNone(tensors[1].grad_fn)
|
|
|
|
@onlyCUDA
|
|
@ops(
|
|
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
|
|
dtypes=(torch.float,),
|
|
)
|
|
def test_outplace_with_invalid_grads(self, device, dtype, op):
|
|
if op.name in {"_foreach_zero"}:
|
|
self.skipTest(f"{op.name} does not have out-place implementation")
|
|
func, *_ = self._get_funcs(op)
|
|
sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
|
|
self.assertTrue(all(t.requires_grad for t in sample.input))
|
|
sample.kwargs.pop("disable_fastpath")
|
|
if func.func in (torch._foreach_addcmul, torch._foreach_addcdiv):
|
|
if sample.kwargs.get("values") is None:
|
|
sample.kwargs.pop("values")
|
|
(out1, out2) = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **sample.kwargs)
|
|
out1.backward(torch.ones_like(out1))
|
|
self.assertIsNotNone(sample.input[0].grad)
|
|
self.assertIsNone(sample.input[1].grad)
|
|
|
|
@ops(
|
|
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
|
|
dtypes=OpDTypes.supported,
|
|
allowed_dtypes=(torch.float64, torch.complex128),
|
|
)
|
|
def test_inplace_forward_mode_AD(self, device, dtype, op):
|
|
if not op.supports_forward_ad:
|
|
self.skipTest("forward AD not supported")
|
|
|
|
# note(crcrpar): The combinations below are failing in its forward path,
|
|
# which is before forward-mode AD happens. This function gates the combinations where
|
|
# - subtraction with Scalar/ScalarList of boolean value:
|
|
# - combinations where the in-place op in questions tries to write out complex result
|
|
# into float storage (= `self`)
|
|
def check_sample_eligibility(op, sample, dtype):
|
|
if (
|
|
op.name == "_foreach_sub"
|
|
and (
|
|
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
|
|
or isinstance(sample.args[0], bool)
|
|
)
|
|
):
|
|
return False, _BOOL_SUB_ERR_MSG
|
|
rhs_arg_has_complex_number = sample.args and ((
|
|
isinstance(sample.args[0], list)
|
|
and any(isinstance(a, complex) for a in sample.args[0])
|
|
) or (
|
|
isinstance(sample.args[0], complex)
|
|
))
|
|
if dtype == torch.float64 and rhs_arg_has_complex_number:
|
|
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
|
|
return False, "result type ComplexDouble can't be cast to the desired output type Double"
|
|
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
|
|
return False, "clamp is not supported for complex types"
|
|
if op.name == "_foreach_pow":
|
|
return False, "Found dtype Double but expected ComplexDouble"
|
|
|
|
return True, ""
|
|
|
|
for sample in op.sample_inputs(
|
|
device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
|
|
):
|
|
# Call `clone` to avoid inplace modifications likewise
|
|
# `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
|
|
def inplace_func(*tensorlist):
|
|
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
|
|
op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
|
|
return tensorlist
|
|
|
|
working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
|
|
if not working_sample:
|
|
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
|
|
gradcheck(
|
|
inplace_func,
|
|
sample.input,
|
|
raise_exception=True,
|
|
check_forward_ad=True,
|
|
check_backward_ad=False,
|
|
check_batched_grad=False,
|
|
)
|
|
else:
|
|
gradcheck(
|
|
inplace_func,
|
|
sample.input,
|
|
raise_exception=True,
|
|
check_forward_ad=True,
|
|
check_backward_ad=False,
|
|
check_batched_grad=False,
|
|
)
|
|
|
|
@unittest.skipIf(not (torch.cuda.is_available() and torch.cuda.device_count() > 1), "requires multiple GPUs")
|
|
def test_tensors_grouping(self):
|
|
num_tensors_per_list = 10
|
|
num_devices = torch.cuda.device_count()
|
|
dtypes = (torch.float16, torch.float32, torch.float64)
|
|
list1 = [
|
|
torch.tensor(
|
|
i,
|
|
device=torch.device("cuda", random.randint(0, num_devices - 1)),
|
|
dtype=dtypes[random.randint(0, 2)],
|
|
) for i in range(num_tensors_per_list)
|
|
]
|
|
list2 = [None for _ in list1]
|
|
list3 = [torch.rand_like(t) for t in list1]
|
|
nested_tensorlists = [list1, list2, list3]
|
|
grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(nested_tensorlists, with_indices=True)
|
|
num_tensors_seen = 0
|
|
for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
|
|
for t in itertools.chain(l1, l3):
|
|
self.assertEquals(t.device, device)
|
|
self.assertEquals(t.dtype, dtype)
|
|
num_tensors_seen += 1
|
|
self.assertEqual(len(l1), len(l2))
|
|
self.assertTrue(all(p is None for p in l2))
|
|
for i, index in enumerate(indices):
|
|
self.assertEquals(l1[i], list1[index])
|
|
self.assertEquals(l2[i], list2[index])
|
|
self.assertEquals(l3[i], list3[index])
|
|
self.assertEquals(num_tensors_seen, 2 * num_tensors_per_list)
|
|
|
|
|
|
instantiate_device_type_tests(TestForeach, globals())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|