[reland 2] Call jit decomp in VariableType to improve forward AD coverage (#84976)

Reland of https://github.com/pytorch/pytorch/pull/84675
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84976
Approved by: https://github.com/zou3519
soulitzer 2022-09-15 22:46:16 +00:00 committed by PyTorch MergeBot
parent 7dcc723d35
commit 7f88934a8f
23 changed files with 452 additions and 247 deletions
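The user-visible effect of this change is that forward-over-reverse AD now works for ops whose backward functions have no hand-written forward-AD formula but do have a jit decomposition registered. A minimal sketch of that pattern (not part of this commit; assumes the functorch `grad`/`jvp` APIs of this era):

```python
import torch
import torch.nn.functional as F
from functorch import grad, jvp

# Hessian-vector product via forward-over-reverse: jvp differentiates through the
# backward of log_softmax (_log_softmax_backward_data), which has no dedicated
# forward-AD formula and now falls back to the jit decomposition instead of erroring.
x = torch.randn(3, 5)
v = torch.randn(3, 5)

def loss(t):
    return F.log_softmax(t, dim=-1).sum()

_, hvp = jvp(grad(loss), (x,), (v,))
print(hvp.shape)  # torch.Size([3, 5])
```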

View File

@ -166,6 +166,7 @@ core_trainer_sources = [
"torch/csrc/autograd/saved_variable.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/utils/warnings.cpp",
"torch/csrc/autograd/jit_decomp_interface.cpp",
"torch/csrc/jit/frontend/name_mangler.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/pickler.cpp",

View File

@ -133,20 +133,6 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
"please file a bug report instead.");
}
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
IValue tup = stack->back();
stack->pop_back();
for (const auto& elem: tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {

View File

@ -197,12 +197,6 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
#define RUN_JIT_DECOMPOSITION(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
inline void find_and_unpack_tensors(

View File

@ -15,6 +15,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
namespace at { namespace functorch {
@ -510,7 +511,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(chunk, chunk_batching_rule);
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
RUN_JIT_DECOMPOSITION(trace)
m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(repeat, repeat_batch_rule);

View File

@ -389,43 +389,9 @@ WithoutTop::~WithoutTop() {
pushDynamicLayer(std::move(layer_));
}
// NOTE: [forward-mode AD decompositions hack]
//
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
// jvp transform, AND we have a decomposition for the operation, then run
// the decomposition.
//
// Let's break that down. There are a couple of moving pieces.
//
// 0. How do we know what transform we're dispatching on?
// Easy, check the top of the DynamicLayerStack and read the transform.
//
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
// gets dispatched to.
// - register a special kernel to the DynamicLayerFrontMode key
// (see JVP_DECOMP)
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
// an arg indicating we're going to use a decomp
//
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
// We currently use python decompositions that we torchscript.
// Ideally c10::OperatorHandle would have a field like this
// to identify the operator.
// The stuff here should map 1:1 with the operator name.
// aten::nll_loss_backward -> nll_loss_backward
// aten::add.Tensor -> add_Tensor
static void call_decomposition_for_jvp(
static void dynamicLayerFrontFallback(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
run_jit_decomposition(op, stack);
}
static void dynamicLayerFrontFallbackOperator(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
bool decomp_jvp) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@ -434,13 +400,6 @@ static void dynamicLayerFrontFallbackOperator(
dump_local_tls();
}
#endif
// Hack: if jvp and we have a decomposition registered, then do the decomposition
if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
decomp_jvp) {
return call_decomposition_for_jvp(op, stack);
}
// Save the current LocalDispatchKeySet (to the current DynamicLayer).
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
@ -460,16 +419,6 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
return c10::impl::ForceDispatchKeyGuard(key_set);
}
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, false);
}
void dynamicLayerFrontFallBackWithDecomp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, true);
}
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& layer = dynamicLayerStackAccessor().back();
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
@ -486,24 +435,5 @@ TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerBackMode, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
#define JVP_DECOMP(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
#define JVP_DECOMP2(op, overload) \
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerFrontMode, m) {
JVP_DECOMP(nll_loss_backward);
JVP_DECOMP(nll_loss2d_backward);
JVP_DECOMP(_log_softmax_backward_data);
JVP_DECOMP(_softmax_backward_data);
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
JVP_DECOMP(cudnn_batch_norm_backward);
}
}
} // namespace at

View File

@ -1124,9 +1124,6 @@ class TestOperators(TestCase):
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
xfail('normal', ''),
xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data
xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data
xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
xfail('cholesky', ''), # NYI: forward-AD for cholesky
xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp
@ -1134,10 +1131,7 @@ class TestOperators(TestCase):
xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('symeig', ''), # NYI: forward AD for symeig
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
@ -1151,7 +1145,6 @@ class TestOperators(TestCase):
xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce
xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
skip('as_strided_scatter', ''), # seems flaky
xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce
@ -1207,37 +1200,8 @@ class TestOperators(TestCase):
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
return expected
# HACK: obviously pytorch should also have the same coverage
# For things that do have the same coverage, we test that jvp x vjp
# are the same between PyTorch and functorch. For things that don't,
# we check that jacfwd(vjp) and jacrev(vjp) are the same. This results
# in slower tests.
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
'nn.functional.nll_loss',
'softmax',
'log_softmax',
'nn.functional.cross_entropy',
'nn.functional.layer_norm',
'nn.functional.batch_norm',
}
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
self.assertFalse(op.supports_fwgrad_bwgrad,
f"{op.name} now supports forward over reverse without a decomposition. " +
"Please remove the decomposition version")
def is_differentiable(t):
return isinstance(t, torch.Tensor) and t.dtype == torch.float32
args = (cotangents, *primals)
if op.name == 'nn.functional.binary_cross_entropy':
argnums = (0, 1) # targets is float32 but isn't differentiable
atol_rtol = 1.5e-4, 1.3e-06
else:
argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
atol_rtol = None
self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
else:
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
@skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
# Following operators take too long, hence skipped
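The removed `FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH` branch above compared `jacfwd` and `jacrev` of the vjp for ops where only functorch had forward-AD coverage. A rough, illustrative sketch of that check (simplified from the test; `log_softmax` stands in for the affected ops):

```python
import torch
import torch.nn.functional as F
from functorch import jacfwd, jacrev, vjp

x = torch.randn(3, 5)
cotangent = torch.randn(3, 5)

def vjp_of_op(cot, primal):
    # vjp of log_softmax at `primal`, applied to `cot`
    _, vjp_fn = vjp(lambda t: F.log_softmax(t, dim=-1), primal)
    return vjp_fn(cot)[0]

# Forward-over-reverse (jacfwd) and reverse-over-reverse (jacrev) of the vjp should
# agree; jacfwd needs forward AD for _log_softmax_backward_data.
assert torch.allclose(
    jacfwd(vjp_of_op)(cotangent, x), jacrev(vjp_of_op)(cotangent, x), atol=1e-4
)
```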

View File

@ -1957,7 +1957,20 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: auto_element_wise
# HACK: This is just auto_element_wise followed by a view_as. The reason we have
# this is because forward AD was complaining here about the shapes not being the same:
# the primal/tangent are 0-D/1-D respectively. This started happening after moving the
# jvp decomposition mechanism from functorch to core, possibly due to a batching rule.
# In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual
# formula.
#
# We'd like to avoid keeping the entire jvp decomposition mechanism in functorch,
# just for this single decomposition, but we also want to prevent any cases from regressing:
# e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA).
#
# We should either figure out what is going on with vmap or perhaps fwd AD could
# be more tolerant about 0-dim vs 1-dim tensors
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p)
- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())

View File

@ -31,6 +31,7 @@ from torchgen.api import cpp
from torchgen.api.autograd import (
DifferentiableInput,
dispatch_strategy,
ForwardDerivative,
gen_differentiable_outputs,
is_differentiable,
NativeFunctionWithDifferentiabilityInfo,
@ -599,8 +600,14 @@ at::redispatch::${api_name}(${unpacked_args})"""
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
"""\
auto ${tmp_var} = ([&]() {
${guard}
return ${base_type_call};
if (${try_jit_decomposition_bool} && ${any_has_forward_grad}) {
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
} else {
${guard}
return ${base_type_call};
}
})();
"""
)
@ -644,6 +651,12 @@ isFwGradDefined(${req_inp})\
"""
)
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
"""\
isFwGradDefinedTensorList(${req_inp})\
"""
)
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
"""\
auto ${inp}_t_raw = toNonOptFwGrad(${inp});
@ -976,6 +989,23 @@ def emit_body(
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
)
if requires_derivative and not len(fw_derivatives) == 0:
assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
differentiable_outputs
), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
try_jit_decomposition = (
requires_derivative
and len(fw_derivatives) == 0
and (not modifies_arguments(f))
and (not returns_void)
)
def emit_save_inputs() -> List[str]:
setup: List[str] = []
if info is None or not info.has_derivatives:
@ -1342,7 +1372,9 @@ def emit_body(
)
return call
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
@ -1356,13 +1388,47 @@ def emit_body(
else:
guard = "at::AutoDispatchBelowADInplaceOrView guard;"
try_jit_decomposition_bool = "true" if try_jit_decomposition else "false"
any_has_forward_grad = (
get_any_has_fw_grad_cond(derivative=None)
if requires_derivative
else "false"
)
return_types = ", ".join(
[cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
)
if len(f.func.returns) > 1:
return_types = f"std::tuple<{return_types}>"
arg_names = [
a.name
for a in cpp.arguments(
f.func.arguments,
faithful=True,
symint=True,
method=False,
cpp_no_default_args=set(),
)
]
if not modifies_arguments(f) and not returns_void:
# Just to keep things simple here, we only care about this path
# and always emit the if/else for now
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard
base_type_call=base_type_call,
tmp_var=TMP_VAR,
guard=guard,
try_jit_decomposition_bool=try_jit_decomposition_bool,
any_has_forward_grad=any_has_forward_grad,
op_name=cpp.name(f.func),
op_overload=f.func.name.overload_name,
return_types=return_types,
arg_names=arg_names,
)
call += wrap_output(f, unpacked_bindings, TMP_VAR)
else:
assert not try_jit_decomposition
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
base_type_call=base_type_call, guard=guard
)
@ -1410,38 +1476,14 @@ def emit_body(
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
for derivative in fw_derivatives:
assert derivative.required_inputs_fw_grad is not None
requires_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
if not requires_fw_grad:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
requires_fw_grad = "true"
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && ({requires_fw_grad})"
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
content.append(
f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
f"(void){get_any_has_forward_grad_name(derivative.var_names)};"
)
return content
def emit_check_inplace() -> List[str]:
@ -1564,46 +1606,83 @@ def emit_body(
content.append("\n".join(fw_grad_setters))
return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
#
# Produces a condition string (e.g., "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
if derivative is None:
# (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# to check if there is a decomposition registered for jvp
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
to_check.append(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
req_inp=inp.name
)
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
return f'({" || ".join(to_check)})'
else:
# (2) If derivative is provided, use that information to determine which inputs
# to check fw_grad for
assert derivative.required_inputs_fw_grad is not None
if len(derivative.required_inputs_fw_grad) == 0:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
any_has_fw_grad = "true"
else:
any_has_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
any_has_fw_grad = f"({any_has_fw_grad})"
return any_has_fw_grad
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
def get_msg() -> str:
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
return msg
res = ""
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
arg=inp.name, cond=cond, name=name, msg=get_msg()
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
if len(to_check) > 0:
cond = " || ".join(to_check)
res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(
cond=cond, name=name, msg=get_msg()
)
return res
cond = get_any_has_fw_grad_cond(derivative=None)
return (
FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
if cond != ""
else ""
)
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)
@ -1617,7 +1696,7 @@ def emit_body(
body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables(f))
body.append(emit_call(f, unpacked_bindings))
body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
if requires_derivative:
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called
@ -1627,20 +1706,11 @@ def emit_body(
if is_out_fn:
body.append(emit_forbid_fw_derivatives(is_out_fn=True))
else:
if requires_derivative:
body.extend(emit_fw_derivatives())
if len(fw_derivatives) == 0:
body.append(emit_forbid_fw_derivatives())
if requires_derivative and not try_jit_decomposition:
if len(fw_derivatives) > 0:
body.extend(emit_fw_derivatives())
else:
assert sum(
len(derivative.var_names) for derivative in fw_derivatives
) == len(differentiable_outputs), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
body.append(emit_forbid_fw_derivatives())
if requires_derivative:
# Save only after the forward AD has been set up
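As a toy illustration of the condition-building in `get_any_has_fw_grad_cond` (this is not the torchgen code, only a sketch assuming inputs reduce to Tensor/TensorList kinds):

```python
from typing import List, Tuple

def any_has_fw_grad_cond(inputs: List[Tuple[str, str]]) -> str:
    # Each differentiable input contributes one check: Tensor inputs use
    # isFwGradDefined, TensorList inputs use isFwGradDefinedTensorList.
    checks = []
    for name, kind in inputs:
        if kind == "Tensor":
            checks.append(f"isFwGradDefined({name})")
        elif kind == "TensorList":
            checks.append(f"isFwGradDefinedTensorList({name})")
        else:
            raise RuntimeError(f"Unsupported input type for {name}")
    return f'({" || ".join(checks)})'

print(any_has_fw_grad_cond([("self", "Tensor"), ("tensors", "TensorList")]))
# prints: (isFwGradDefined(self) || isFwGradDefinedTensorList(tensors))
```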

View File

@ -893,6 +893,25 @@ def compiled_with_cxx11_abi():
from torch._ops import ops
from torch._classes import classes
# Importing torch._decomp.decompositions_for_jvp registers the jvp
# decompositions with the jit registry
# (decompositions_for_jvp depends on torch.ops, so we place it after)
#
# FIXME: We specify that __debug__ must be True because
# if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
# following error:
#
# Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
# type Tuple[Tensor, Tensor]:
# File ".../torch/_decomp/__init__.py", line 1585
# else:
# buffer = z
# return min - torch.log1p(z), buffer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__ and not torch._C._is_deploy_enabled(): # type: ignore[attr-defined]
from torch._decomp import decompositions_for_jvp
del decompositions_for_jvp
# quantization depends on torch.fx
# Import quantization
from torch import quantization as quantization

View File

@ -10,7 +10,7 @@ decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten
# NOTE: [forward-mode AD decompositions hack]
# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
# IF any inputs have forward grad
@ -23,9 +23,15 @@ aten = torch.ops.aten
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: what if a jit decomposition exists, should we just use it?
# or do we want to have an explicit white list like functorch had
# using special JVP_DECOMP DynamicLayerFront kernel
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So an open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit whitelist controlling when
# to enable the decomposition.
def maybe_register_decomposition(op):
@ -179,7 +185,7 @@ def native_layer_norm_backward(
if len(outer_dim_indices) > 0:
d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
else:
d_bias = grad_out
d_bias = grad_out.clone()
elif bias is not None:
d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
else:
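For context, a jvp decomposition is just the backward op re-expressed in terms of ops that already have forward-AD formulas, so forward AD can be computed by differentiating through the decomposition. An illustrative (not verbatim) example for `_log_softmax_backward_data`:

```python
import torch
from torch import Tensor

# Illustrative only: the real decomposition lives in torch/_decomp and is registered
# with the jit registry via maybe_register_decomposition.
def log_softmax_backward_data_decomp(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
) -> Tensor:
    grad_input = grad_output - torch.exp(output) * grad_output.sum(dim, keepdim=True)
    return grad_input.to(input_dtype)

# Sanity check against autograd's own backward of log_softmax.
x = torch.randn(3, 5, requires_grad=True)
y = torch.log_softmax(x, dim=-1)
g = torch.ones_like(y)
(ref,) = torch.autograd.grad(y, x, g)
assert torch.allclose(ref, log_softmax_backward_data_decomp(g, y.detach(), -1, x.dtype), atol=1e-6)
```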

View File

@ -880,7 +880,7 @@ def add(
# TODO: add docstring
atan2 = _make_elementwise_binary_reference(
prims.atan2,
prims.atan2, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
@ -888,33 +888,33 @@ atan2 = _make_elementwise_binary_reference(
# TODO: add docstring
bitwise_and = _make_elementwise_binary_reference(
prims.bitwise_and,
prims.bitwise_and, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
bitwise_left_shift = _make_elementwise_binary_reference(
prims.shift_left,
prims.shift_left, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.bitwise_left_shift, # prim/aten name mismatch
)
# TODO: add docstring
bitwise_or = _make_elementwise_binary_reference(
prims.bitwise_or,
prims.bitwise_or, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
bitwise_right_shift = _make_elementwise_binary_reference(
prims.shift_right_arithmetic,
prims.shift_right_arithmetic, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.bitwise_right_shift, # prim/aten name mismatch
)
# TODO: add docstring
bitwise_xor = _make_elementwise_binary_reference(
prims.bitwise_xor,
prims.bitwise_xor, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@ -971,7 +971,7 @@ def div(
# TODO: add docstring
eq = _make_elementwise_binary_reference(
prims.eq,
prims.eq, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
@ -1135,7 +1135,7 @@ floor_divide = _make_elementwise_binary_reference(
# TODO: add docstring
fmax = _make_elementwise_binary_reference(
prims.fmax,
prims.fmax, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmax,
supports_lhs_python_scalar=False,
@ -1144,7 +1144,7 @@ fmax = _make_elementwise_binary_reference(
# TODO: add docstring
fmin = _make_elementwise_binary_reference(
prims.fmin,
prims.fmin, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmin,
supports_lhs_python_scalar=False,
@ -1153,7 +1153,7 @@ fmin = _make_elementwise_binary_reference(
# TODO: add docstring
fmod = _make_elementwise_binary_reference(
prims.fmod,
prims.fmod, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmod,
supports_lhs_python_scalar=False,
@ -1162,7 +1162,7 @@ fmod = _make_elementwise_binary_reference(
# TODO: add docstring
gcd = _make_elementwise_binary_reference(
prims.gcd,
prims.gcd, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.gcd,
supports_lhs_python_scalar=False,
@ -1171,14 +1171,14 @@ gcd = _make_elementwise_binary_reference(
# TODO: add docstring
ge = _make_elementwise_binary_reference(
prims.ge,
prims.ge, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
gt = _make_elementwise_binary_reference(
prims.gt,
prims.gt, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
@ -1201,21 +1201,21 @@ heaviside = _make_elementwise_binary_reference(
)
hypot = _make_elementwise_binary_reference(
prims.hypot,
prims.hypot, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
)
igamma = _make_elementwise_binary_reference(
prims.igamma,
prims.igamma, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
)
igammac = _make_elementwise_binary_reference(
prims.igammac,
prims.igammac, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
@ -1320,7 +1320,7 @@ lcm = _make_elementwise_binary_reference(
# TODO: add docstring
le = _make_elementwise_binary_reference(
prims.le,
prims.le, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
@ -1383,39 +1383,39 @@ logical_xor = _make_elementwise_binary_reference(
# TODO: add docstring
lt = _make_elementwise_binary_reference(
prims.lt,
prims.lt, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
maximum = _make_elementwise_binary_reference(
prims.maximum,
prims.maximum, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
minimum = _make_elementwise_binary_reference(
prims.minimum,
prims.minimum, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
mul = _make_elementwise_binary_reference(
prims.mul,
prims.mul, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
ne = _make_elementwise_binary_reference(
prims.ne,
prims.ne, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
nextafter = _make_elementwise_binary_reference(
prims.nextafter,
prims.nextafter, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
@ -1423,7 +1423,7 @@ nextafter = _make_elementwise_binary_reference(
# TODO: add docstring
remainder = _make_elementwise_binary_reference(
prims.remainder,
prims.remainder, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.remainder,
)
@ -1484,7 +1484,7 @@ def sub(
# TODO: add docstring
true_divide = _make_elementwise_binary_reference(
prims.div,
prims.div, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
aten_op=None, # CompositeImplicitAutograd
)

View File

@ -61,7 +61,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
zeta = _make_elementwise_binary_reference(
prims.zeta,
prims.zeta, # type: ignore[has-type]
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
aten_op=torch.ops.aten.special_zeta,
)

View File

@ -368,7 +368,7 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype
dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
for index_o, d_o in enumerate(dual_outputs):
val, res = fwAD.unpack_dual(d_o)
if check_grad_dtypes and val.is_complex() != res.is_complex():
if check_grad_dtypes and res is not None and val.is_complex() != res.is_complex():
raise GradcheckError('Forward AD gradient has dtype mismatch.')
if res is None:

View File

@ -1285,6 +1285,14 @@ Call this whenever a new thread is created in order to propagate values from
std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
});
py_module.def("_is_deploy_enabled", []() {
#if defined(USE_DEPLOY)
return true;
#else
return false;
#endif
});
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator =
(THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);

View File

@ -2,6 +2,9 @@
#include <c10/util/irange.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
@ -11,6 +14,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>
#include <array>
@ -456,5 +460,58 @@ inline std::vector<c10::ScalarType> to_args_scalartypes(
return args_scalartypes;
}
namespace impl {
namespace {
// If run_jit_decomposition were not a member function, we would be able
// to pass this as a template parameter to c10::BoxedKernel::makeFromFunction.
// However, member functions cannot be passed this way; instead we wrap our
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor.
class WrapperFunctor final : public c10::OperatorKernel {
public:
WrapperFunctor(JitDecompInterface* impl) : impl_(impl){};
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet ks,
torch::jit::Stack* stack) {
impl_->run_jit_decomposition(op, stack);
}
JitDecompInterface* impl_;
};
} // namespace
template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
c10::string_view name,
const c10::OperatorHandle& opHandle,
c10::DispatchKeySet dispatchKeySet,
Args&&... args) {
// see NOTE: [Jit Decomposition Interface]
JitDecompInterface* impl = getJitDecompImpl();
TORCH_CHECK_NOT_IMPLEMENTED(
impl && impl->has_jit_decomposition(opHandle.schema()),
"Trying to use forward AD with ",
name,
" that does not support it because it has not been implemented yet.\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation.\n"
"Note that forward AD support for some operators require PyTorch to be built with "
"TorchScript and for JIT to be enabled. "
"If the environment var PYTORCH_JIT=0 is set or if the library is not built with TorchScript, "
"some operators may no longer be used with forward AD.");
return c10::KernelFunction::makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunctor(
std::make_unique<WrapperFunctor>(impl)))
.call<Return, Args...>(
opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
} // namespace impl
} // namespace autograd
} // namespace torch

View File

@ -100,5 +100,23 @@ inline bool isFwGradDefined(const c10::optional<at::Tensor>& t) {
return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
}
inline bool isFwGradDefinedTensorList(const at::TensorList& variables) {
bool ret = false;
for (auto& variable : variables) {
ret |= isFwGradDefined(variable);
}
return ret;
}
inline bool isFwGradDefinedTensorList(
const c10::List<c10::optional<at::Tensor>> li) {
bool ret = false;
for (auto i : c10::irange(li.size())) {
auto t = li.get(i);
ret |= (t.has_value() && isFwGradDefined(t.value()));
}
return ret;
}
} // namespace autograd
} // namespace torch

View File

@ -0,0 +1,21 @@
#include <torch/csrc/autograd/jit_decomp_interface.h>
namespace torch {
namespace autograd {
namespace impl {
namespace {
JitDecompInterface* impl = nullptr;
}
void setJitDecompImpl(JitDecompInterface* impl_) {
impl = impl_;
}
JitDecompInterface* getJitDecompImpl() {
return impl;
}
} // namespace impl
} // namespace autograd
} // namespace torch

View File

@ -0,0 +1,54 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/function_schema.h>
#include <c10/macros/Export.h>
// NOTE: [Jit Decomposition Interface]
//
// For some context of why we need this at all, see NOTE: [forward-mode AD
// decompositions mechanism]
//
// Introducing that mechanism from the NOTE is problematic because:
// - it relies on TorchScript, so now VariableTypeX.cpp depends on TorchScript.
// - there exist internal builds like lite_trainer, which depend on VariableType
// but do not depend on TorchScript.
//
// For internal builds like lite_trainer builds to pass, and for OSS builds that
// do depend on TorchScript to still support the forward AD decomp mechanism, we
// implement a PImpl pattern to avoid a static dependency in favor of a dynamic
// one
// - during static initialization time, if the library is built with TorchScript,
// setJitDecompImpl is called in decomposition_registry.cpp, setting a global
// ptr to the impl
// - when the program is run, if getJitDecompImpl returns a non-null ptr, we can
// carry on normally; otherwise we gracefully error out
//
// For extra context, see VariableHooksInterface.h, where a similar technique
// is used
namespace torch {
namespace autograd {
namespace impl {
struct TORCH_API JitDecompInterface {
virtual ~JitDecompInterface() = default;
virtual bool has_jit_decomposition(
const c10::FunctionSchema& schema) const = 0;
virtual void run_jit_decomposition(
const c10::OperatorHandle& op,
jit::Stack* stack) const = 0;
};
TORCH_API void setJitDecompImpl(JitDecompInterface* impl);
TORCH_API JitDecompInterface* getJitDecompImpl();
struct TORCH_API JitDecompRegisterer {
explicit JitDecompRegisterer(JitDecompInterface* impl) {
setJitDecompImpl(impl);
}
};
} // namespace impl
} // namespace autograd
} // namespace torch
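The same registration pattern, sketched in Python purely for illustration (the real implementation is the C++ interface above; none of these names are part of the PR):

```python
from typing import Optional

class JitDecompInterface:
    # Abstract interface; the TorchScript side provides the concrete impl.
    def has_jit_decomposition(self, schema) -> bool: ...
    def run_jit_decomposition(self, op, stack) -> None: ...

_impl: Optional[JitDecompInterface] = None  # global, set at static-init time

def set_jit_decomp_impl(impl: JitDecompInterface) -> None:
    global _impl
    _impl = impl

def get_jit_decomp_impl() -> Optional[JitDecompInterface]:
    return _impl

# The caller (VariableType) degrades gracefully when TorchScript isn't linked in:
def try_run_decomp(op, stack) -> None:
    impl = get_jit_decomp_impl()
    if impl is None or not impl.has_jit_decomposition(op.schema()):
        raise NotImplementedError("forward AD not implemented and no jit decomposition available")
    impl.run_jit_decomposition(op, stack)
```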

View File

@ -8,6 +8,7 @@
#include <torch/csrc/jit/serialization/import_source.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>
@ -160,6 +161,47 @@ void RegisterDecomposition(
schema_to_decomposition[&schema] = g;
}
// see NOTE: [Jit Decomposition Interface]
struct JitDecomp final : torch::autograd::impl::JitDecompInterface {
bool has_jit_decomposition(const c10::FunctionSchema& schema) const override;
void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) const override;
};
JitDecomp jitDecomp;
torch::autograd::impl::JitDecompRegisterer registerJitDecomp(&jitDecomp);
void JitDecomp::run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) const {
::torch::jit::run_jit_decomposition(op, stack);
}
bool JitDecomp::has_jit_decomposition(const FunctionSchema& schema) const {
return ::torch::jit::has_jit_decomposition(schema);
}
void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
at::IValue tup = stack->back();
stack->pop_back();
for (const auto& elem : tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
bool has_jit_decomposition(const FunctionSchema& schema) {
return GetDecompositionFunction(schema).has_value();
}
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
auto maybe_func = GetDecompositionFunction(schema);
TORCH_INTERNAL_ASSERT(maybe_func);

View File

@ -25,5 +25,11 @@ TORCH_API Function* GetDecompositionExecutor(const char* schema_literal);
TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema);
TORCH_API void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
} // namespace jit
} // namespace torch

View File

@ -9,7 +9,7 @@ try:
except ImportError:
HAS_SYMPY = False
aten = torch.ops.aten
aten = torch.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "is_symbolic_op", "handle_symbolic_op", "PySymInt", "ShapeEnv",

View File

@ -38,7 +38,6 @@ import torch._refs as refs # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims # noqa: F401
from torch.utils._pytree import tree_flatten
@ -10233,6 +10232,7 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=True,
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
OpInfo('softmax',
aliases=('special.softmax', 'nn.functional.softmax',),
@ -10242,6 +10242,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
# `softmin` supports different dtypes based on whether `dtype` argument,
# is passed or not. Hence two OpInfo entries, one with dtype and other without.
@ -10254,6 +10255,7 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=False,
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo('nn.functional.softmin',
variant_test_name="with_dtype",
@ -10262,6 +10264,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo(
"nn.functional.cross_entropy",
@ -10270,6 +10273,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_cross_entropy,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),
@ -10361,6 +10365,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
assert_jit_shape_analysis=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_native_layer_norm,
error_inputs_func=error_inputs_native_layer_norm,
skips=(
@ -10732,6 +10737,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@ -10750,6 +10756,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
decorators=[
DecorateInfo(
@ -11789,6 +11796,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs_batch_norm,
skips=(
@ -11811,6 +11819,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[onlyCUDA, disablecuDNN],
skips=(
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
@ -14772,6 +14781,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_softmax_variant,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
OpInfo(
'log_softmax',
@ -14781,6 +14791,7 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
UnaryUfuncInfo('logit',
aten_backward_name='logit_backward',
@ -15659,6 +15670,7 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_nll_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
skips=(
# RuntimeError:

View File

@ -990,6 +990,7 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(
@ -1017,6 +1018,7 @@ op_db: List[OpInfo] = [
],
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(
@ -1037,6 +1039,7 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(