Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[mta] Backward of unary foreach functions (#89591)

As per the title, this PR defines the backward of the unary foreach functions. It does not implement forward-mode automatic differentiation, as [the current codegen](a747326423/tools/autograd/gen_variable_type.py (L1513)) doesn't seem to handle `ArrayRef<Tensor>`.

Related:
- https://github.com/pytorch/pytorch/issues/53796
- https://github.com/pytorch/pytorch/issues/58833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89591
Approved by: https://github.com/albanD
This commit is contained in: parent 32b2d8009a, commit 30876229a7
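To illustrate what the new derivatives enable, here is a minimal sketch (not part of the diff), mirroring the pattern used in the test added below: differentiate through a unary foreach op and compare against the per-tensor reference.

import torch

# Backward through a unary foreach op; assumes a build exposing the private
# torch._foreach_* ops.
tensors = [torch.randn(5, requires_grad=True) for _ in range(3)]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]

# foreach path: one call over the whole list
sum(torch._foreach_exp(tensors)).mean().backward()
# reference path: the same math, one torch.exp call per tensor
sum(torch.exp(t) for t in ref_tensors).mean().backward()

# gradients should match the per-tensor reference
print(all(torch.allclose(a.grad, b.grad) for a, b in zip(tensors, ref_tensors)))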
@@ -7,6 +7,14 @@ namespace at { namespace native {

namespace {

// TODO(crcrpar): Handle version bump in codegen.
// rel: https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
  for (const auto & t : tensors) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}

// Initializes args and checks if all args are aligned
template<int depth, typename T>
__device__ bool init_args(
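For context (not part of the diff): the version counter that bump_version() touches is what autograd uses to detect that a tensor saved for backward was later modified in place. A small sketch of that mechanism from Python, assuming the in-place foreach op bumps the counter as this hunk intends:

import torch

ts = [torch.randn(4) for _ in range(2)]
before = [t._version for t in ts]
torch._foreach_exp_(ts)            # in-place unary foreach op
after = [t._version for t in ts]
# Each tensor's version counter should have advanced; that is what lets autograd
# raise an error if one of these tensors had been saved for a pending backward.
print(before, after)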
@@ -73,6 +73,7 @@ template <typename scalar_t, template<class> class Op> void foreach_unary_op_(Te
      /* r_args_depth */ 1,
      /* res_arg_index */ 0>(),
      Op<opmath_t>());
  increment_version(tensors);
}

template <template<class> class Op>
@@ -597,6 +597,73 @@ BLAS and LAPACK Operations
    triangular_solve
    vdot

Foreach Operations
~~~~~~~~~~~~~~~~~~

.. warning::
    This API is in beta and subject to future changes.
    Forward-mode AD is not supported.

.. autosummary::
    :toctree: generated
    :nosignatures:

    _foreach_abs
    _foreach_abs_
    _foreach_acos
    _foreach_acos_
    _foreach_asin
    _foreach_asin_
    _foreach_atan
    _foreach_atan_
    _foreach_ceil
    _foreach_ceil_
    _foreach_cos
    _foreach_cos_
    _foreach_cosh
    _foreach_cosh_
    _foreach_erf
    _foreach_erf_
    _foreach_erfc
    _foreach_erfc_
    _foreach_exp
    _foreach_exp_
    _foreach_expm1
    _foreach_expm1_
    _foreach_floor
    _foreach_floor_
    _foreach_log
    _foreach_log_
    _foreach_log10
    _foreach_log10_
    _foreach_log1p
    _foreach_log1p_
    _foreach_log2
    _foreach_log2_
    _foreach_neg
    _foreach_neg_
    _foreach_tan
    _foreach_tan_
    _foreach_sin
    _foreach_sin_
    _foreach_sinh
    _foreach_sinh_
    _foreach_round
    _foreach_round_
    _foreach_sqrt
    _foreach_sqrt_
    _foreach_lgamma
    _foreach_lgamma_
    _foreach_frac
    _foreach_frac_
    _foreach_reciprocal
    _foreach_reciprocal_
    _foreach_sigmoid
    _foreach_sigmoid_
    _foreach_trunc
    _foreach_trunc_
    _foreach_zero_

Utilities
----------------------------------
.. autosummary::
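As a quick illustration of the API surface documented above (a sketch using ops from the list): each op has an out-of-place variant returning a new list and an in-place `_` variant mutating the input list.

import torch

xs = [torch.rand(3) + 1.0 for _ in range(4)]   # positive values so sqrt/log are well-defined

ys = torch._foreach_sqrt(xs)    # out-of-place: returns a new List[Tensor]
torch._foreach_log_(xs)         # in-place: mutates each tensor in xs, returns None

print(type(ys), len(ys))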
@@ -222,6 +222,24 @@ class TestForeach(TestCase):
            inplace_ref(copied_inputs),
        self.assertEqual(copied_inputs, inputs)

    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
        inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
        if opinfo.name == "_foreach_abs" and dtype in complex_types():
            is_fastpath = False
        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)

        if opinfo.supports_autograd and dtype in floating_types():
            tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
            tensors = [t.requires_grad_() for t in tensors]
            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]

            sum(op.func(tensors)).mean().backward()
            sum([ref.func(t) for t in ref_tensors]).mean().backward()
            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])

    @skipMeta
    @ops(foreach_unary_op_db)
    @parametrize("is_fastpath", (True, False))
@@ -42,7 +42,7 @@
#   to that argument could exist. You should either:
#   - Specify the formula for that gradient
#   - Specify not_implemented("function_name") as a formula to say that this is not
#     implement yet (but might be in the future and the user can request that on an issue)
#     implemented yet (but might be in the future and the user can request that on an issue)
# - If that argument is not differentiable, because it is not a floating point dtype or the
#   function is not differentiable with respect to that argument for
#   example. You should either:
@@ -98,6 +98,23 @@ if (task_should_compute_output({ ${name}_ix })) {
"""
)

# note(crcrpar): `self` argument and other optional positional argument
# of foreach functions are basically a list of n `Tensor`s thus iterating over
# `grads` in order to utilize and apply the existing derivative definitions
# to each `Tensor`(s) of `self`, and the others.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
    grad_result.emplace_back(${derivative});
  }
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
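Conceptually, DERIVATIVE_SINGLE_FOREACH emits one evaluation of the (rewritten) scalar derivative formula per list element. A rough Python analogue of the generated loop, using _foreach_exp with an illustrative formula grads[i] * result[i] (the real formula comes from derivatives.yaml):

from typing import List
import torch

def foreach_exp_backward(grads: List[torch.Tensor], result: List[torch.Tensor]) -> List[torch.Tensor]:
    # analogue of the generated loop: one ${derivative} evaluation per grads[i]
    return [grads[i] * result[i] for i in range(len(grads))]

# usage sketch
xs = [torch.randn(3) for _ in range(2)]
result = [x.exp() for x in xs]
grads = [torch.ones_like(r) for r in result]
print(foreach_exp_backward(grads, result))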
@@ -709,9 +726,13 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            ) in ("Tensor", "Tensor?"):
                formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                checks_any_grad_defined = True
            if info.name.startswith("_foreach_"):
                derivative_template = DERIVATIVE_SINGLE_FOREACH
            else:
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
                derivative_template.substitute(name=var_names[0], derivative=formula),
            )
        else:
            if "grad_input_mask" in formula:
@@ -14003,3 +14003,59 @@ Performs the same operation as :func:`torch.alias`, but all output tensors
are freshly created instead of aliasing the input.
""",
)

for unary_base_func_name in (
    "exp",
    "sqrt",
    "abs",
    "acos",
    "asin",
    "atan",
    "ceil",
    "cos",
    "cosh",
    "erf",
    "erfc",
    "expm1",
    "floor",
    "log",
    "log10",
    "log1p",
    "log2",
    "neg",
    "tan",
    "tanh",
    "sin",
    "sinh",
    "round",
    "lgamma",
    "frac",
    "reciprocal",
    "sigmoid",
    "trunc",
    "zero",
):
    unary_foreach_func_name = f"_foreach_{unary_base_func_name}"
    if hasattr(torch, unary_foreach_func_name):
        add_docstr(
            getattr(torch, unary_foreach_func_name),
            r"""
{}(self: List[Tensor]) -> List[Tensor]

Apply :func:`torch.{}` to each Tensor of the input list.
""".format(
                unary_foreach_func_name, unary_base_func_name
            ),
        )
    unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_"
    if hasattr(torch, unary_inplace_foreach_func_name):
        add_docstr(
            getattr(torch, unary_inplace_foreach_func_name),
            r"""
{}(self: List[Tensor]) -> None

Apply :func:`torch.{}` to each Tensor of the input list.
""".format(
                unary_inplace_foreach_func_name, unary_base_func_name
            ),
        )
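With that loop in place, the generated docstrings should read roughly as follows (a sketch; the exact text comes from the templates above):

import torch

print(torch._foreach_exp.__doc__)
# _foreach_exp(self: List[Tensor]) -> List[Tensor]
#
# Apply :func:`torch.exp` to each Tensor of the input list.

print(torch._foreach_exp_.__doc__)
# _foreach_exp_(self: List[Tensor]) -> None
#
# Apply :func:`torch.exp` to each Tensor of the input list.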
@@ -8074,25 +8074,26 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):


foreach_unary_op_db: List[OpInfo] = [
    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),

    ForeachFuncInfo(
        'neg',
        dtypes=all_types_and_complex(),
        dtypesIfCUDA=all_types_and_complex(),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8100,6 +8101,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_and_complex_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_and_complex_types_and(torch.half),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8107,6 +8109,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=all_types_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8114,6 +8117,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8121,6 +8125,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8128,6 +8133,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8135,6 +8141,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=all_types_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8142,6 +8149,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_and_complex_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_and_complex_types_and(torch.half),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8149,6 +8157,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=all_types_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8156,6 +8165,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8163,6 +8173,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8170,6 +8181,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.half),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8177,6 +8189,7 @@ foreach_unary_op_db: List[OpInfo] = [
        dtypes=all_types_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),

    ForeachFuncInfo(
@@ -8186,6 +8199,7 @@ foreach_unary_op_db: List[OpInfo] = [
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
        supports_autograd=True,
    ),
]

@@ -2571,6 +2571,7 @@ class ForeachFuncInfo(OpInfo):
        dtypesIfROCM=None,
        supports_alpha_param=False,
        sample_inputs_func=sample_inputs_foreach,
        supports_autograd=False,
        **kwargs,
    ):
        super().__init__(
@@ -2579,6 +2580,7 @@ class ForeachFuncInfo(OpInfo):
            dtypesIfCUDA=dtypesIfCUDA,
            dtypesIfROCM=dtypesIfROCM,
            sample_inputs_func=sample_inputs_func,
            supports_autograd=supports_autograd,
            **kwargs,
        )

@@ -1,9 +1,10 @@
import copy
import re
from dataclasses import dataclass
from typing import Dict, List, Match, Optional, Sequence, Set, Tuple

from torchgen.api import cpp
from torchgen.api.types import Binding, NamedCType
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    FunctionSchema,
    NativeFunction,
@@ -357,6 +358,94 @@ Attempted to convert a derivative formula for a mutable operator
this is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of unary foreach functions if none is defined in `derivatives.yaml`
        base_op_name = f.func.name.name
        if (
            base_op_name.base.startswith("_foreach")
            and not base_op_name.inplace
            and len(f.func.arguments.post_self_positional) == 0
        ):
            ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
            for function_schema in functional_info_by_signature:
                if (
                    function_schema.name.name.base == ref_native_op_name
                    and not function_schema.name.name.inplace
                ):
                    all_saved_inputs = []
                    all_saved_outputs = []
                    diff_info_dict = copy.deepcopy(
                        differentiability_infos[function_schema]
                    )
                    diff_info = diff_info_dict["Default"]
                    modified_derivative_formulas = []
                    for derivative in diff_info.derivatives:
                        saved_inputs = []
                        saved_outputs = []
                        modified_formula = (
                            derivative.formula.replace("grad", "grads[i]")
                            .replace("self", "self[i]")
                            .replace("result", "result[i]")
                        )
                        if "self" in modified_formula:
                            saved_inputs.append(
                                SavedAttribute(
                                    nctype=NamedCType(
                                        name="self", type=BaseCType(tensorListT)
                                    ),
                                    expr="self",
                                )
                            )
                            all_saved_inputs.append(saved_inputs[-1])
                        if "result" in modified_formula:
                            saved_outputs.append(
                                SavedAttribute(
                                    nctype=NamedCType(
                                        name="result", type=BaseCType(tensorListT)
                                    ),
                                    expr="result",
                                )
                            )
                            all_saved_outputs.append(saved_outputs[-1])
                        modified_derivative = Derivative(
                            formula=modified_formula,
                            original_formula=derivative.original_formula,
                            var_names=("self",),
                            saved_inputs=tuple(saved_inputs),
                            saved_outputs=tuple(saved_outputs),
                            named_gradients=set(),
                        )
                        modified_derivative_formulas.append(modified_derivative)
                    assert f.func.arguments.self_arg is not None
                    diff_info = DifferentiabilityInfo(
                        name=base_op_name.base,
                        func=f,
                        op=f"Foreach{diff_info.op}",
                        derivatives=modified_derivative_formulas,
                        forward_derivatives=[],
                        all_saved_inputs=tuple(set(all_saved_inputs)),
                        all_saved_outputs=tuple(set(all_saved_outputs)),
                        available_named_gradients=(),
                        used_named_gradients=set(),
                        args_with_derivatives=[
                            Binding(
                                name="self",
                                nctype=NamedCType(
                                    name="self", type=BaseCType(tensorListT)
                                ),
                                argument=f.func.arguments.self_arg.argument,
                                default=None,
                            )
                        ],
                        non_differentiable_arg_names=[],
                        output_differentiability=None,
                        output_differentiability_conditions=None,
                    )
                    diff_info_dict["Default"] = diff_info
                    if f.func not in differentiability_infos:
                        differentiability_infos[f.func] = diff_info_dict
                        functional_info_by_signature[f.func] = diff_info_dict
                    return diff_info_dict, True

        return None, False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
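To make the rewrite above concrete, here is an illustrative run of the same string substitutions that map a base op's derivative onto its `_foreach_` counterpart. The formulas below are hypothetical stand-ins; the real entries live in `derivatives.yaml`.

# Illustrative only: hypothetical base formulas, not copied from derivatives.yaml.
base_formulas = {
    "exp": "grad * result",
    "neg": "-grad",
}
for name, formula in base_formulas.items():
    modified_formula = (
        formula.replace("grad", "grads[i]")
        .replace("self", "self[i]")
        .replace("result", "result[i]")
    )
    print(f"_foreach_{name}: {modified_formula}")
# _foreach_exp: grads[i] * result[i]
# _foreach_neg: -grads[i]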