[mta] Backward of unary foreach functions (#89591)

As per the title, this PR defines the backward of unary foreach functions.

This doesn't implement forward-mode automatic differentiation, as [the current codegen](a747326423/tools/autograd/gen_variable_type.py (L1513)) doesn't seem to handle `ArrayRef<Tensor>`.
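
A quick sanity check of what this enables (a minimal sketch assuming a build that includes this PR; `_foreach_exp` is just a representative unary foreach op):

```python
import torch

# Compare gradients through the fused foreach op against per-tensor
# reference gradients computed with the regular unary op.
xs = [torch.randn(3, requires_grad=True) for _ in range(4)]
torch.stack(torch._foreach_exp(xs)).sum().backward()

refs = [x.detach().clone().requires_grad_() for x in xs]
torch.stack([r.exp() for r in refs]).sum().backward()

for x, r in zip(xs, refs):
    torch.testing.assert_close(x.grad, r.grad)
```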

Related:
- https://github.com/pytorch/pytorch/issues/53796
- https://github.com/pytorch/pytorch/issues/58833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89591
Approved by: https://github.com/albanD
Author: Masaki Kozuki, 2023-01-23 08:28:06 +00:00 (committed by PyTorch MergeBot)
parent 32b2d8009a
commit 30876229a7
10 changed files with 292 additions and 16 deletions


@@ -7,6 +7,14 @@ namespace at { namespace native {
namespace {
// TODO(crcrpar): Handle version bump in codegen.
// rel: https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
for (const auto & t : tensors) {
t.unsafeGetTensorImpl()->bump_version();
}
}
// Initializes args and checks if all args are aligned
template<int depth, typename T>
__device__ bool init_args(
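
The `increment_version` helper above exists because the in-place foreach fast-path kernels currently have to bump each tensor's version counter themselves; the codegen does not yet handle this for `TensorList` (see the TODO). The effect is observable from Python; a rough sketch (`_version` is an internal attribute, used here only for illustration):

```python
import torch

# In-place foreach ops should bump each tensor's version counter,
# just as regular in-place ops do.
device = "cuda" if torch.cuda.is_available() else "cpu"
ts = [torch.randn(2, device=device) for _ in range(3)]
before = [t._version for t in ts]
torch._foreach_exp_(ts)
assert all(b < t._version for b, t in zip(before, ts))
```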


@@ -73,6 +73,7 @@ template <typename scalar_t, template<class> class Op> void foreach_unary_op_(Te
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>());
increment_version(tensors);
}
template <template<class> class Op>


@@ -597,6 +597,73 @@ BLAS and LAPACK Operations
triangular_solve
vdot
Foreach Operations
~~~~~~~~~~~~~~~~~~
.. warning::
This API is in beta and subject to future changes.
Forward-mode AD is not supported.
.. autosummary::
:toctree: generated
:nosignatures:
_foreach_abs
_foreach_abs_
_foreach_acos
_foreach_acos_
_foreach_asin
_foreach_asin_
_foreach_atan
_foreach_atan_
_foreach_ceil
_foreach_ceil_
_foreach_cos
_foreach_cos_
_foreach_cosh
_foreach_cosh_
_foreach_erf
_foreach_erf_
_foreach_erfc
_foreach_erfc_
_foreach_exp
_foreach_exp_
_foreach_expm1
_foreach_expm1_
_foreach_floor
_foreach_floor_
_foreach_log
_foreach_log_
_foreach_log10
_foreach_log10_
_foreach_log1p
_foreach_log1p_
_foreach_log2
_foreach_log2_
_foreach_neg
_foreach_neg_
_foreach_tan
_foreach_tan_
_foreach_sin
_foreach_sin_
_foreach_sinh
_foreach_sinh_
_foreach_round
_foreach_round_
_foreach_sqrt
_foreach_sqrt_
_foreach_lgamma
_foreach_lgamma_
_foreach_frac
_foreach_frac_
_foreach_reciprocal
_foreach_reciprocal_
_foreach_sigmoid
_foreach_sigmoid_
_foreach_trunc
_foreach_trunc_
_foreach_zero_
Utilities
----------------------------------
.. autosummary::
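
A small usage sketch for the foreach ops documented above (out-of-place ops return a new list; the trailing-underscore variants mutate their inputs in place and return nothing):

```python
import torch

xs = [torch.rand(5) for _ in range(3)]
ys = torch._foreach_sqrt(xs)   # new tensors; xs is untouched
torch._foreach_sqrt_(xs)       # xs is now modified in place
for x, y in zip(xs, ys):
    torch.testing.assert_close(x, y)
```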


@@ -222,6 +222,24 @@ class TestForeach(TestCase):
inplace_ref(copied_inputs),
self.assertEqual(copied_inputs, inputs)
def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
# note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
if opinfo.name == "_foreach_abs" and dtype in complex_types():
is_fastpath = False
self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
if opinfo.supports_autograd and dtype in floating_types():
tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
tensors = [t.requires_grad_() for t in tensors]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
sum(op.func(tensors)).mean().backward()
sum([ref.func(t) for t in ref_tensors]).mean().backward()
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@skipMeta
@ops(foreach_unary_op_db)
@parametrize("is_fastpath", (True, False))
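
The core of the new autograd branch, pulled out into a self-contained sketch (same-size inputs so the outputs can be summed, mirroring `same_size=True`; `_foreach_sin` stands in for any op in the db):

```python
import torch

# Gradients through the fused foreach op must match gradients computed
# tensor-by-tensor with the reference op.
tensors = [torch.randn(4, requires_grad=True) for _ in range(3)]
ref = [t.detach().clone().requires_grad_() for t in tensors]

sum(torch._foreach_sin(tensors)).mean().backward()
sum(torch.sin(t) for t in ref).mean().backward()

for t, r in zip(tensors, ref):
    torch.testing.assert_close(t.grad, r.grad)
```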


@@ -42,7 +42,7 @@
# to that argument could exist. You should either:
# - Specify the formula for that gradient
# - Specify not_implemented("function_name") as a formula to say that this is not
# implement yet (but might be in the future and the user can request that on an issue)
# implemented yet (but might be in the future and the user can request that on an issue)
# - If that argument is not differentiable, because it is not a floating point dtype or the
# function is not differentiable with respect to that argument for
# example. You should either:


@@ -98,6 +98,23 @@ if (task_should_compute_output({ ${name}_ix })) {
"""
)
# note(crcrpar): `self` argument and other optional positional argument
# of foreach functions are basically a list of n `Tensor`s thus iterating over
# `grads` in order to utilize and apply the existing derivative definitions
# to each `Tensor`(s) of `self`, and the others.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
"""\
if (task_should_compute_output({ ${name}_ix })) {
std::vector<Tensor> grad_result;
grad_result.reserve(grads.size());
for (const auto & i : c10::irange(grads.size())) {
grad_result.emplace_back(${derivative});
}
copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
"""\
if (task_should_compute_output({ ${name}_ix })) {
@@ -709,9 +726,13 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
) in ("Tensor", "Tensor?"):
formula = "any_grad_defined ? (" + formula + ") : Tensor()"
checks_any_grad_defined = True
if info.name.startswith("_foreach_"):
derivative_template = DERIVATIVE_SINGLE_FOREACH
else:
derivative_template = DERIVATIVE_SINGLE
return (
checks_any_grad_defined,
DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
derivative_template.substitute(name=var_names[0], derivative=formula),
)
else:
if "grad_input_mask" in formula:
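
`DERIVATIVE_SINGLE_FOREACH` reuses the single-Tensor formula unchanged and only wraps it in a loop over `grads`. A hypothetical Python rendering of the C++ it expands to for `_foreach_exp`, whose scalar derivative in `derivatives.yaml` is essentially `grad * result`:

```python
from typing import List
import torch

# Simplified equivalent of the generated loop: apply the per-tensor
# derivative formula to each element of the list.
def foreach_exp_backward(grads: List[torch.Tensor],
                         result: List[torch.Tensor]) -> List[torch.Tensor]:
    grad_result = []
    for i in range(len(grads)):
        grad_result.append(grads[i] * result[i])  # "${derivative}" with [i] indexing
    return grad_result
```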


@@ -14003,3 +14003,59 @@ Performs the same operation as :func:`torch.alias`, but all output tensors
are freshly created instead of aliasing the input.
""",
)
for unary_base_func_name in (
"exp",
"sqrt",
"abs",
"acos",
"asin",
"atan",
"ceil",
"cos",
"cosh",
"erf",
"erfc",
"expm1",
"floor",
"log",
"log10",
"log1p",
"log2",
"neg",
"tan",
"tanh",
"sin",
"sinh",
"round",
"lgamma",
"frac",
"reciprocal",
"sigmoid",
"trunc",
"zero",
):
unary_foreach_func_name = f"_foreach_{unary_base_func_name}"
if hasattr(torch, unary_foreach_func_name):
add_docstr(
getattr(torch, unary_foreach_func_name),
r"""
{}(self: List[Tensor]) -> List[Tensor]
Apply :func:`torch.{}` to each Tensor of the input list.
""".format(
unary_foreach_func_name, unary_base_func_name
),
)
unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_"
if hasattr(torch, unary_inplace_foreach_func_name):
add_docstr(
getattr(torch, unary_inplace_foreach_func_name),
r"""
{}(self: List[Tensor]) -> None
Apply :func:`torch.{}` to each Tensor of the input list.
""".format(
unary_inplace_foreach_func_name, unary_base_func_name
),
)
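
Since the docstrings are attached at import time, a quick way to eyeball the result (a sketch; the expected text mirrors the template in the loop above):

```python
import torch

print(torch._foreach_exp.__doc__)
# _foreach_exp(self: List[Tensor]) -> List[Tensor]
# Apply :func:`torch.exp` to each Tensor of the input list.
```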


@@ -8074,25 +8074,26 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
foreach_unary_op_db: List[OpInfo] = [
ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
ForeachFuncInfo(
'neg',
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex(),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8100,6 +8101,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8107,6 +8109,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8114,6 +8117,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8121,6 +8125,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8128,6 +8133,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8135,6 +8141,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8142,6 +8149,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8149,6 +8157,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8156,6 +8165,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8163,6 +8173,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8170,6 +8181,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8177,6 +8189,7 @@ foreach_unary_op_db: List[OpInfo] = [
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
ForeachFuncInfo(
@@ -8186,6 +8199,7 @@ foreach_unary_op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
]
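
With `supports_autograd=True`, these ops can also be checked numerically via `torch.autograd.gradcheck`; a minimal sketch (double precision, the list wrapped so `gradcheck` sees plain tensor arguments; `_foreach_sin` chosen arbitrarily):

```python
import torch
from torch.autograd import gradcheck

def fn(a, b):
    # Wrap the list-in/list-out foreach op for gradcheck.
    return tuple(torch._foreach_sin([a, b]))

inputs = tuple(torch.randn(3, dtype=torch.double, requires_grad=True)
               for _ in range(2))
assert gradcheck(fn, inputs)
```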


@@ -2571,6 +2571,7 @@ class ForeachFuncInfo(OpInfo):
dtypesIfROCM=None,
supports_alpha_param=False,
sample_inputs_func=sample_inputs_foreach,
supports_autograd=False,
**kwargs,
):
super().__init__(
@@ -2579,6 +2580,7 @@ class ForeachFuncInfo(OpInfo):
dtypesIfCUDA=dtypesIfCUDA,
dtypesIfROCM=dtypesIfROCM,
sample_inputs_func=sample_inputs_func,
supports_autograd=supports_autograd,
**kwargs,
)


@@ -1,9 +1,10 @@
import copy
import re
from dataclasses import dataclass
from typing import Dict, List, Match, Optional, Sequence, Set, Tuple
from torchgen.api import cpp
from torchgen.api.types import Binding, NamedCType
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
FunctionSchema,
NativeFunction,
@@ -357,6 +358,94 @@ Attempted to convert a derivative formula for a mutable operator
this is not currently supported (we'd need to fix up the formula in the codegen)."""
return info_dict, False
# (4) Generate derivative information of unary foreach functions if none is defined in `derivatives.yaml`
base_op_name = f.func.name.name
if (
base_op_name.base.startswith("_foreach")
and not base_op_name.inplace
and len(f.func.arguments.post_self_positional) == 0
):
ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
for function_schema in functional_info_by_signature:
if (
function_schema.name.name.base == ref_native_op_name
and not function_schema.name.name.inplace
):
all_saved_inputs = []
all_saved_outputs = []
diff_info_dict = copy.deepcopy(
differentiability_infos[function_schema]
)
diff_info = diff_info_dict["Default"]
modified_derivative_formulas = []
for derivative in diff_info.derivatives:
saved_inputs = []
saved_outputs = []
modified_formula = (
derivative.formula.replace("grad", "grads[i]")
.replace("self", "self[i]")
.replace("result", "result[i]")
)
if "self" in modified_formula:
saved_inputs.append(
SavedAttribute(
nctype=NamedCType(
name="self", type=BaseCType(tensorListT)
),
expr="self",
)
)
all_saved_inputs.append(saved_inputs[-1])
if "result" in modified_formula:
saved_outputs.append(
SavedAttribute(
nctype=NamedCType(
name="result", type=BaseCType(tensorListT)
),
expr="result",
)
)
all_saved_outputs.append(saved_outputs[-1])
modified_derivative = Derivative(
formula=modified_formula,
original_formula=derivative.original_formula,
var_names=("self",),
saved_inputs=tuple(saved_inputs),
saved_outputs=tuple(saved_outputs),
named_gradients=set(),
)
modified_derivative_formulas.append(modified_derivative)
assert f.func.arguments.self_arg is not None
diff_info = DifferentiabilityInfo(
name=base_op_name.base,
func=f,
op=f"Foreach{diff_info.op}",
derivatives=modified_derivative_formulas,
forward_derivatives=[],
all_saved_inputs=tuple(set(all_saved_inputs)),
all_saved_outputs=tuple(set(all_saved_outputs)),
available_named_gradients=(),
used_named_gradients=set(),
args_with_derivatives=[
Binding(
name="self",
nctype=NamedCType(
name="self", type=BaseCType(tensorListT)
),
argument=f.func.arguments.self_arg.argument,
default=None,
)
],
non_differentiable_arg_names=[],
output_differentiability=None,
output_differentiability_conditions=None,
)
diff_info_dict["Default"] = diff_info
if f.func not in differentiability_infos:
differentiability_infos[f.func] = diff_info_dict
functional_info_by_signature[f.func] = diff_info_dict
return diff_info_dict, True
return None, False
result: List[NativeFunctionWithDifferentiabilityInfo] = []
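
The key step above is purely textual: the existing single-Tensor derivative formula is rewritten so `grad`, `self`, and `result` are indexed per element, and the surrounding `DifferentiabilityInfo` is rebuilt around `TensorList`. A standalone sketch of that rewrite (hypothetical helper name):

```python
# Mirror of the replace() chain used above.
def foreachify(formula: str) -> str:
    return (
        formula.replace("grad", "grads[i]")
        .replace("self", "self[i]")
        .replace("result", "result[i]")
    )

# exp's backward in derivatives.yaml is essentially "grad * result", so the
# generated foreach backward evaluates this expression for each list element:
assert foreachify("grad * result") == "grads[i] * result[i]"
```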