Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

[reland 2] Call jit decomp in VariableType to improve forward AD coverage (#84976)

Reland of https://github.com/pytorch/pytorch/pull/84675
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84976
Approved by: https://github.com/zou3519

This commit is contained in:
parent 7dcc723d35
commit 7f88934a8f
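The user-visible effect of this change is broader forward-over-reverse coverage in core PyTorch: backward ops such as nll_loss_backward or _log_softmax_backward_data that have no handwritten forward-AD formula can now fall back to a TorchScripted decomposition from inside VariableType. A minimal sketch (not taken from this PR) of the kind of computation that exercises that path:

# Sketch only: forward-mode AD over a reverse-mode gradient. The backward ops
# (nll_loss_backward, _log_softmax_backward_data) see dual-tensor inputs here,
# which is exactly the case the jit decompositions are meant to cover.
import torch
import torch.autograd.forward_ad as fwAD
import torch.nn.functional as F

x = torch.randn(4, 3, requires_grad=True)
target = torch.randint(0, 3, (4,))
tangent = torch.randn(4, 3)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, tangent)
    loss = F.nll_loss(F.log_softmax(dual_x, dim=1), target)
    (grad_x,) = torch.autograd.grad(loss, dual_x, create_graph=True)
    _, jvp_of_grad = fwAD.unpack_dual(grad_x)  # directional derivative of the gradient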
@@ -166,6 +166,7 @@ core_trainer_sources = [
"torch/csrc/autograd/saved_variable.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/utils/warnings.cpp",
"torch/csrc/autograd/jit_decomp_interface.cpp",
"torch/csrc/jit/frontend/name_mangler.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/pickler.cpp",

@@ -133,20 +133,6 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
"please file a bug report instead.");
}
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
IValue tup = stack->back();
stack->pop_back();
for (const auto& elem: tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {

@@ -197,12 +197,6 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
#define RUN_JIT_DECOMPOSITION(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
inline void find_and_unpack_tensors(

@@ -15,6 +15,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
namespace at { namespace functorch {

@@ -510,7 +511,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(chunk, chunk_batching_rule);
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
RUN_JIT_DECOMPOSITION(trace)
m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(repeat, repeat_batch_rule);

@@ -389,43 +389,9 @@ WithoutTop::~WithoutTop() {
pushDynamicLayer(std::move(layer_));
}
// NOTE: [forward-mode AD decompositions hack]
//
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
// jvp transform, AND we have a decomposition for the operation, then run
// the decomposition.
//
// Let's break that down. There are a douple of moving pieces.
//
// 0. How do we know what transform we're dispatching on?
// Easy, check the top of the DynamicLayerStack and read the transform.
//
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
// gets dispatched to.
// - register a special kernel to the DynamicLayerFrontMode key
// (see JVP_DECOMP)
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
// an arg indicating we're going to use a decomp
//
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
// We currently use python decompositions that we torchscript.
// Ideally c10::OperatorHandle would have a field like this
// to identify the operator.
// The stuff here should map 1:1 with the operator name.
// aten::nll_loss_backward -> nll_loss_backward
// aten::add.Tensor -> add_Tensor
static void call_decomposition_for_jvp(
static void dynamicLayerFrontFallback(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
run_jit_decomposition(op, stack);
}
static void dynamicLayerFrontFallbackOperator(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
bool decomp_jvp) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE

@@ -434,13 +400,6 @@ static void dynamicLayerFrontFallbackOperator(
dump_local_tls();
}
#endif
// Hack: if jvp and we have a decomposition registered, then do the decomposition
if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
decomp_jvp) {
return call_decomposition_for_jvp(op, stack);
}
// Save the current LocalDispatchKeySet (to the current DynamicLayer).
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,

@@ -460,16 +419,6 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
return c10::impl::ForceDispatchKeyGuard(key_set);
}
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, false);
}
void dynamicLayerFrontFallBackWithDecomp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, true);
}
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& layer = dynamicLayerStackAccessor().back();
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());

@@ -486,24 +435,5 @@ TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerBackMode, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
#define JVP_DECOMP(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
#define JVP_DECOMP2(op, overload) \
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
TORCH_LIBRARY_IMPL(aten, FuncTorchDynamicLayerFrontMode, m) {
JVP_DECOMP(nll_loss_backward);
JVP_DECOMP(nll_loss2d_backward);
JVP_DECOMP(_log_softmax_backward_data);
JVP_DECOMP(_softmax_backward_data);
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
JVP_DECOMP(cudnn_batch_norm_backward);
}
}
} // namespace at

@@ -1124,9 +1124,6 @@ class TestOperators(TestCase):
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
xfail('normal', ''),
xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data
xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data
xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
xfail('cholesky', ''), # NYI: forward-AD for cholesky
xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp

@@ -1134,10 +1131,7 @@ class TestOperators(TestCase):
xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('symeig', ''), # NYI: forward AD for symeig
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward

@@ -1151,7 +1145,6 @@ class TestOperators(TestCase):
xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce
xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
skip('as_strided_scatter', ''), # seems flaky
xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce

@@ -1207,37 +1200,8 @@ class TestOperators(TestCase):
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
return expected
# HACK: obviously pytorch should also have the same coverage
# For things that do have the same coverage, we test that jvp x vjp
# are the same between PyTorch and functorch. For things that don't,
# we check that jacfwd(vjp) and jacrev(vjp) are the same. This results
# in slower tests.
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
'nn.functional.nll_loss',
'softmax',
'log_softmax',
'nn.functional.cross_entropy',
'nn.functional.layer_norm',
'nn.functional.batch_norm',
}
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
self.assertFalse(op.supports_fwgrad_bwgrad,
f"{op.name} now supports forward over reverse without a decomposition. " +
"Please remove the decomposition version")
def is_differentiable(t):
return isinstance(t, torch.Tensor) and t.dtype == torch.float32
args = (cotangents, *primals)
if op.name == 'nn.functional.binary_cross_entropy':
argnums = (0, 1) # targets is float32 but isn't differentiable
atol_rtol = 1.5e-4, 1.3e-06
else:
argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
atol_rtol = None
self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
else:
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
@skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
# Following operatos take too long, hence skipped
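For context on the comparison described in the FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH block above, here is a rough sketch (assuming the 2022-era functorch functional API; this is not code from the test) of composing jvp with a vjp function:

# Rough sketch: jvp of a vjp function, the "jvp x vjp" composition the test
# compares against PyTorch's own forward-over-reverse result.
import torch
from functorch import jvp, vjp

def f(x):
    return torch.nn.functional.log_softmax(x, dim=-1)

x = torch.randn(3, 4)
cotangent = torch.randn(3, 4)

out, vjp_fn = vjp(f, x)
# Differentiate the vjp in forward mode; under the jvp transform,
# _log_softmax_backward_data needs a forward rule or a decomposition.
_, tangent_out = jvp(lambda ct: vjp_fn(ct)[0], (cotangent,), (torch.randn_like(cotangent),))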
@@ -1957,7 +1957,20 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: auto_element_wise
# HACK: This is just auto_element_wise followed by a view_as. The reason we have
# this is bc forward AD was complaining here about the shapes not being the same:
# the primal/tangent are 0-D/1-D respectively. This started happening after moving the
# jvp decomposition mechanism from functorch to core, possibly due to a batching rule.
# In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual
# formula.
#
# We'd like to avoid keeping the entire jvp decomposition mechanism in functorch,
# just for this single decomposition, but also want to avoid any cases from regressing:
# e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA).
#
# We should either figure out what is going on with vmap or perhaps fwd AD could
# be more tolerant about 0-dim vs 1-dim tensors
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p)
- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())

@@ -31,6 +31,7 @@ from torchgen.api import cpp
from torchgen.api.autograd import (
DifferentiableInput,
dispatch_strategy,
ForwardDerivative,
gen_differentiable_outputs,
is_differentiable,
NativeFunctionWithDifferentiabilityInfo,

@@ -599,8 +600,14 @@ at::redispatch::${api_name}(${unpacked_args})"""
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
"""\
auto ${tmp_var} = ([&]() {
${guard}
return ${base_type_call};
if (${try_jit_decomposition_bool} && ${any_has_forward_grad}) {
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
} else {
${guard}
return ${base_type_call};
}
})();
"""
)

@@ -644,6 +651,12 @@ isFwGradDefined(${req_inp})\
"""
)
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
"""\
isFwGradDefinedTensorList(${req_inp})\
"""
)
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
"""\
auto ${inp}_t_raw = toNonOptFwGrad(${inp});

@@ -976,6 +989,23 @@ def emit_body(
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
)
if requires_derivative and not len(fw_derivatives) == 0:
assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
differentiable_outputs
), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
try_jit_decomposition = (
requires_derivative
and len(fw_derivatives) == 0
and (not modifies_arguments(f))
and (not returns_void)
)
def emit_save_inputs() -> List[str]:
setup: List[str] = []
if info is None or not info.has_derivatives:

@@ -1342,7 +1372,9 @@ def emit_body(
)
return call
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed

@@ -1356,13 +1388,47 @@ def emit_body(
else:
guard = "at::AutoDispatchBelowADInplaceOrView guard;"
try_jit_decomposition_bool = "true" if try_jit_decomposition else "false"
any_has_forward_grad = (
get_any_has_fw_grad_cond(derivative=None)
if requires_derivative
else "false"
)
return_types = ", ".join(
[cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
)
if len(f.func.returns) > 1:
return_types = f"std::tuple<{return_types}>"
arg_names = [
a.name
for a in cpp.arguments(
f.func.arguments,
faithful=True,
symint=True,
method=False,
cpp_no_default_args=set(),
)
]
if not modifies_arguments(f) and not returns_void:
# Just to keep things simple here, we only care about this path
# and always emit the if/else for now
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard
base_type_call=base_type_call,
tmp_var=TMP_VAR,
guard=guard,
try_jit_decomposition_bool=try_jit_decomposition_bool,
any_has_forward_grad=any_has_forward_grad,
op_name=cpp.name(f.func),
op_overload=f.func.name.overload_name,
return_types=return_types,
arg_names=arg_names,
)
call += wrap_output(f, unpacked_bindings, TMP_VAR)
else:
assert not try_jit_decomposition
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
base_type_call=base_type_call, guard=guard
)

@@ -1410,38 +1476,14 @@ def emit_body(
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
for derivative in fw_derivatives:
assert derivative.required_inputs_fw_grad is not None
requires_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
if not requires_fw_grad:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
requires_fw_grad = "true"
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && ({requires_fw_grad})"
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
content.append(
f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
f"(void){get_any_has_forward_grad_name(derivative.var_names)};"
)
return content
def emit_check_inplace() -> List[str]:

@@ -1564,46 +1606,83 @@ def emit_body(
content.append("\n".join(fw_grad_setters))
return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
#
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
if derivative is None:
# (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# To check if there is a decomposition registered for jvp
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
to_check.append(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
req_inp=inp.name
)
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
return f'({" || ".join(to_check)})'
else:
# (2) If derivative is provided, use that information to determine which inputs
# to check fw_grad for
assert derivative.required_inputs_fw_grad is not None
if len(derivative.required_inputs_fw_grad) == 0:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
any_has_fw_grad = "true"
else:
any_has_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
any_has_fw_grad = f"({any_has_fw_grad})"
return any_has_fw_grad
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
def get_msg() -> str:
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
return msg
res = ""
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
arg=inp.name, cond=cond, name=name, msg=get_msg()
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
if len(to_check) > 0:
cond = " || ".join(to_check)
res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(
cond=cond, name=name, msg=get_msg()
)
return res
cond = get_any_has_fw_grad_cond(derivative=None)
return (
FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
if cond != ""
else ""
)
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)

@@ -1617,7 +1696,7 @@ def emit_body(
body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables(f))
body.append(emit_call(f, unpacked_bindings))
body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
if requires_derivative:
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called

@@ -1627,20 +1706,11 @@ def emit_body(
if is_out_fn:
body.append(emit_forbid_fw_derivatives(is_out_fn=True))
else:
if requires_derivative:
body.extend(emit_fw_derivatives())
if len(fw_derivatives) == 0:
body.append(emit_forbid_fw_derivatives())
if requires_derivative and not try_jit_decomposition:
if len(fw_derivatives) > 0:
body.extend(emit_fw_derivatives())
else:
assert sum(
len(derivative.var_names) for derivative in fw_derivatives
) == len(differentiable_outputs), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
body.append(emit_forbid_fw_derivatives())
if requires_derivative:
# Save only after the forward AD has been set up

@@ -893,6 +893,25 @@ def compiled_with_cxx11_abi():
from torch._ops import ops
from torch._classes import classes
# Import from torch._decomp import decompositions_for_jvp to register
# decompositions for jvp to the jit registry
# (decompositions_for_jvp depends on torch.ops, so we place it after)
#
# FIXME: We specify that __debug__ must be True because
# if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
# following error:
#
# Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
# type Tuple[Tensor, Tensor]:
# File ".../torch/_decomp/__init__.py", line 1585
# else:
# buffer = z
# return min - torch.log1p(z), buffer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__ and not torch._C._is_deploy_enabled(): # type: ignore[attr-defined]
from torch._decomp import decompositions_for_jvp
del decompositions_for_jvp
# quantization depends on torch.fx
# Import quantization
from torch import quantization as quantization

@@ -10,7 +10,7 @@ decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten
# NOTE: [forward-mode AD decompositions hack]
# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
# IF any inputs have forward grad

@@ -23,9 +23,15 @@ aten = torch.ops.aten
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: what if jit decompositions exists, should we just use it?
# or do we want to have an explicit white list like functorch had
# using special JVP_DECOMP DynamicLayerFront kernel
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have a explicitly white list control when
# to enable the decomposition.
def maybe_register_decomposition(op):
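A small way to see what this module ends up registering, using only the decomposition_table_for_jvp dict shown in the hunk above (hypothetical usage, not part of the commit):

# List the op overloads that picked up a jvp decomposition after importing
# torch._decomp.decompositions_for_jvp (the import torch/__init__.py now does).
import torch
from torch._decomp import decompositions_for_jvp

for overload in decompositions_for_jvp.decomposition_table_for_jvp:
    print(overload)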
@@ -179,7 +185,7 @@ def native_layer_norm_backward(
if len(outer_dim_indices) > 0:
d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
else:
d_bias = grad_out
d_bias = grad_out.clone()
elif bias is not None:
d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
else:

@@ -880,7 +880,7 @@ def add(
# TODO: add docstring
atan2 = _make_elementwise_binary_reference(
prims.atan2,
prims.atan2, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,

@@ -888,33 +888,33 @@ atan2 = _make_elementwise_binary_reference(
# TODO: add docstring
bitwise_and = _make_elementwise_binary_reference(
prims.bitwise_and,
prims.bitwise_and, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
bitwise_left_shift = _make_elementwise_binary_reference(
prims.shift_left,
prims.shift_left, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.bitwise_left_shift, # prim/aten name mismatch
)
# TODO: add docstring
bitwise_or = _make_elementwise_binary_reference(
prims.bitwise_or,
prims.bitwise_or, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
bitwise_right_shift = _make_elementwise_binary_reference(
prims.shift_right_arithmetic,
prims.shift_right_arithmetic, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.bitwise_right_shift, # prim/aten name mismatch
)
# TODO: add docstring
bitwise_xor = _make_elementwise_binary_reference(
prims.bitwise_xor,
prims.bitwise_xor, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

@@ -971,7 +971,7 @@ def div(
# TODO: add docstring
eq = _make_elementwise_binary_reference(
prims.eq,
prims.eq, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)

@@ -1135,7 +1135,7 @@ floor_divide = _make_elementwise_binary_reference(
# TODO: add docstring
fmax = _make_elementwise_binary_reference(
prims.fmax,
prims.fmax, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmax,
supports_lhs_python_scalar=False,

@@ -1144,7 +1144,7 @@ fmax = _make_elementwise_binary_reference(
# TODO: add docstring
fmin = _make_elementwise_binary_reference(
prims.fmin,
prims.fmin, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmin,
supports_lhs_python_scalar=False,

@@ -1153,7 +1153,7 @@ fmin = _make_elementwise_binary_reference(
# TODO: add docstring
fmod = _make_elementwise_binary_reference(
prims.fmod,
prims.fmod, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.fmod,
supports_lhs_python_scalar=False,

@@ -1162,7 +1162,7 @@ fmod = _make_elementwise_binary_reference(
# TODO: add docstring
gcd = _make_elementwise_binary_reference(
prims.gcd,
prims.gcd, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.gcd,
supports_lhs_python_scalar=False,

@@ -1171,14 +1171,14 @@ gcd = _make_elementwise_binary_reference(
# TODO: add docstring
ge = _make_elementwise_binary_reference(
prims.ge,
prims.ge, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
gt = _make_elementwise_binary_reference(
prims.gt,
prims.gt, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)

@@ -1201,21 +1201,21 @@ heaviside = _make_elementwise_binary_reference(
)
hypot = _make_elementwise_binary_reference(
prims.hypot,
prims.hypot, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
)
igamma = _make_elementwise_binary_reference(
prims.igamma,
prims.igamma, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,
)
igammac = _make_elementwise_binary_reference(
prims.igammac,
prims.igammac, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,

@@ -1320,7 +1320,7 @@ lcm = _make_elementwise_binary_reference(
# TODO: add docstring
le = _make_elementwise_binary_reference(
prims.le,
prims.le, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)

@@ -1383,39 +1383,39 @@ logical_xor = _make_elementwise_binary_reference(
# TODO: add docstring
lt = _make_elementwise_binary_reference(
prims.lt,
prims.lt, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
maximum = _make_elementwise_binary_reference(
prims.maximum,
prims.maximum, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
minimum = _make_elementwise_binary_reference(
prims.minimum,
prims.minimum, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
mul = _make_elementwise_binary_reference(
prims.mul,
prims.mul, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
ne = _make_elementwise_binary_reference(
prims.ne,
prims.ne, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
supports_lhs_python_scalar=False,
)
# TODO: add docstring
nextafter = _make_elementwise_binary_reference(
prims.nextafter,
prims.nextafter, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
supports_lhs_python_scalar=False,
supports_rhs_python_scalar=False,

@@ -1423,7 +1423,7 @@ nextafter = _make_elementwise_binary_reference(
# TODO: add docstring
remainder = _make_elementwise_binary_reference(
prims.remainder,
prims.remainder, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
aten_op=torch.ops.aten.remainder,
)

@@ -1484,7 +1484,7 @@ def sub(
# TODO: add docstring
true_divide = _make_elementwise_binary_reference(
prims.div,
prims.div, # type: ignore[has-type]
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
aten_op=None, # CompositeImplicitAutograd
)

@@ -61,7 +61,7 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
zeta = _make_elementwise_binary_reference(
prims.zeta,
prims.zeta, # type: ignore[has-type]
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
aten_op=torch.ops.aten.special_zeta,
)

@@ -368,7 +368,7 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype
dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
for index_o, d_o in enumerate(dual_outputs):
val, res = fwAD.unpack_dual(d_o)
if check_grad_dtypes and val.is_complex() != res.is_complex():
if check_grad_dtypes and res is not None and val.is_complex() != res.is_complex():
raise GradcheckError('Forward AD gradient has dtype mismatch.')
if res is None:
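The gradcheck change above only adds a `res is not None` guard; the reason is that unpack_dual returns a None tangent for outputs that carry no forward gradient, for example:

# Minimal illustration of why the guard is needed: a tensor with no tangent
# attached unpacks to (primal, None) inside a dual level, so .is_complex()
# cannot be called on the tangent unconditionally.
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    y = torch.ones(3) * 2          # no tangent ever attached
    primal, tangent = fwAD.unpack_dual(y)
    assert tangent is None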
@@ -1285,6 +1285,14 @@ Call this whenever a new thread is created in order to propagate values from
std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
});
py_module.def("_is_deploy_enabled", []() {
#if defined(USE_DEPLOY)
return true;
#else
return false;
#endif
});
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator =
(THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);

@@ -2,6 +2,9 @@
#include <c10/util/irange.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

@@ -11,6 +14,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>
#include <array>

@@ -456,5 +460,58 @@ inline std::vector<c10::ScalarType> to_args_scalartypes(
return args_scalartypes;
}
namespace impl {
namespace {
// If run_jit_decomposition were not a member function, we would be able
// to pass this as a template parameter to c10::Boxedkernel::makeFromFunction.
// However, member functions cannot be passed this way - instead we wrap our
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor
class WrapperFunctor final : public c10::OperatorKernel {
public:
WrapperFunctor(JitDecompInterface* impl) : impl_(impl){};
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet ks,
torch::jit::Stack* stack) {
impl_->run_jit_decomposition(op, stack);
}
JitDecompInterface* impl_;
};
} // namespace
template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
c10::string_view name,
const c10::OperatorHandle& opHandle,
c10::DispatchKeySet dispatchKeySet,
Args&&... args) {
// see NOTE: [Jit Decomposition Interface]
JitDecompInterface* impl = getJitDecompImpl();
TORCH_CHECK_NOT_IMPLEMENTED(
impl && impl->has_jit_decomposition(opHandle.schema()),
"Trying to use forward AD with ",
name,
" that does not support it because it has not been implemented yet.\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation.\n"
"Note that forward AD support for some operators require PyTorch to be built with "
"TorchScript and for JIT to be enabled. "
"If the environment var PYTORCH_JIT=0 is set or if the library is not built with TorchScript, "
"some operators may no longer be used with forward AD.");
return c10::KernelFunction::makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunctor(
std::make_unique<WrapperFunctor>(impl)))
.call<Return, Args...>(
opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
} // namespace impl
} // namespace autograd
} // namespace torch

@@ -100,5 +100,23 @@ inline bool isFwGradDefined(const c10::optional<at::Tensor>& t) {
return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
}
inline bool isFwGradDefinedTensorList(const at::TensorList& variables) {
bool ret = false;
for (auto& variable : variables) {
ret |= isFwGradDefined(variable);
}
return ret;
}
inline bool isFwGradDefinedTensorList(
const c10::List<c10::optional<at::Tensor>> li) {
bool ret = false;
for (auto i : c10::irange(li.size())) {
auto t = li.get(i);
ret |= (t.has_value() && isFwGradDefined(t.value()));
}
return ret;
}
} // namespace autograd
} // namespace torch
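The two isFwGradDefinedTensorList overloads above answer "does any tensor in this list carry a forward gradient?", which is the check the codegen emits for TensorList inputs (the "functions like stack" case called out earlier). A small illustration of the semantics (not code from the PR):

# One element of the list carries a tangent, the other does not; the output of
# stack still gets a forward gradient, which is why the generated check is an
# OR over the whole TensorList.
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    a = fwAD.make_dual(torch.randn(3), torch.randn(3))
    b = torch.randn(3)              # no tangent
    out = torch.stack([a, b])
    _, tangent = fwAD.unpack_dual(out)
    print(tangent)                  # row 0 from a's tangent, row 1 zeros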
torch/csrc/autograd/jit_decomp_interface.cpp (new file, 21 lines)

@@ -0,0 +1,21 @@
#include <torch/csrc/autograd/jit_decomp_interface.h>
namespace torch {
namespace autograd {
namespace impl {
namespace {
JitDecompInterface* impl = nullptr;
}
void setJitDecompImpl(JitDecompInterface* impl_) {
impl = impl_;
}
JitDecompInterface* getJitDecompImpl() {
return impl;
}
} // namespace impl
} // namespace autograd
} // namespace torch

torch/csrc/autograd/jit_decomp_interface.h (new file, 54 lines)

@@ -0,0 +1,54 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/function_schema.h>
#include <c10/macros/Export.h>
// NOTE: [Jit Decomposition Interface]
//
// For some context of why we need this at all, see NOTE: [forward-mode AD
// decompositions mechanism]
//
// Introducing that mechanism from the NOTE is problematic because:
// - it relies on TorchScript, so now VariableTypeX.cpp depends on TorchScript.
// - there exist internal builds like lite_trainer, which depend on VariableType
// but do not depend on TorchScript.
//
// For internal builds like lite_trainer builds to pass, and for OSS builds that
// do depend on TorchScript to still support the forward AD decomp mechanism, we
// implement a PImpl pattern to avoid a static dependency in favor of a dynamic
// one
// - during static initialization time, if the library is built with TorchScript
// setJitDecompImpl is called in decomposition_registry.cpp setting a global
// ptr to the impl
// - when the program is run,if getJitDecompImpl returns a non null ptr, we can
// carry on normally, otherwise we gracefully error out
//
// For extra context, see VariableHooksInterface.h, where a similar technique
// is used
namespace torch {
namespace autograd {
namespace impl {
struct TORCH_API JitDecompInterface {
virtual ~JitDecompInterface() = default;
virtual bool has_jit_decomposition(
const c10::FunctionSchema& schema) const = 0;
virtual void run_jit_decomposition(
const c10::OperatorHandle& op,
jit::Stack* stack) const = 0;
};
TORCH_API void setJitDecompImpl(JitDecompInterface* impl);
TORCH_API JitDecompInterface* getJitDecompImpl();
struct TORCH_API JitDecompRegisterer {
explicit JitDecompRegisterer(JitDecompInterface* impl) {
setJitDecompImpl(impl);
}
};
} // namespace impl
} // namespace autograd
} // namespace torch

@@ -8,6 +8,7 @@
#include <torch/csrc/jit/serialization/import_source.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>

@@ -160,6 +161,47 @@ void RegisterDecomposition(
schema_to_decomposition[&schema] = g;
}
// see NOTE: [Jit Decomposition Interface]
struct JitDecomp final : torch::autograd::impl::JitDecompInterface {
bool has_jit_decomposition(const c10::FunctionSchema& schema) const override;
void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) const override;
};
JitDecomp jitDecomp;
torch::autograd::impl::JitDecompRegisterer registerJitDecomp(&jitDecomp);
void JitDecomp::run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) const {
::torch::jit::run_jit_decomposition(op, stack);
}
bool JitDecomp::has_jit_decomposition(const FunctionSchema& schema) const {
return ::torch::jit::has_jit_decomposition(schema);
}
void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
at::IValue tup = stack->back();
stack->pop_back();
for (const auto& elem : tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
bool has_jit_decomposition(const FunctionSchema& schema) {
return GetDecompositionFunction(schema).has_value();
}
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
auto maybe_func = GetDecompositionFunction(schema);
TORCH_INTERNAL_ASSERT(maybe_func);

@@ -25,5 +25,11 @@ TORCH_API Function* GetDecompositionExecutor(const char* schema_literal);
TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema);
TORCH_API void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
} // namespace jit
} // namespace torch

@@ -9,7 +9,7 @@ try:
except ImportError:
HAS_SYMPY = False
aten = torch.ops.aten
aten = torch.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "is_symbolic_op", "handle_symbolic_op", "PySymInt", "ShapeEnv",

@@ -38,7 +38,6 @@ import torch._refs as refs # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims # noqa: F401
from torch.utils._pytree import tree_flatten

@@ -10233,6 +10232,7 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=True,
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
OpInfo('softmax',
aliases=('special.softmax', 'nn.functional.softmax',),

@@ -10242,6 +10242,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
# `softmin` supports different dtypes based on whether `dtype` argument,
# is passed or not. Hence two OpInfo entries, one with dtype and other without.

@@ -10254,6 +10255,7 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=False,
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo('nn.functional.softmin',
variant_test_name="with_dtype",

@@ -10262,6 +10264,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo(
"nn.functional.cross_entropy",

@@ -10270,6 +10273,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_cross_entropy,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),

@@ -10361,6 +10365,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
assert_jit_shape_analysis=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_native_layer_norm,
error_inputs_func=error_inputs_native_layer_norm,
skips=(

@@ -10732,6 +10737,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient

@@ -10750,6 +10756,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
decorators=[
DecorateInfo(

@@ -11789,6 +11796,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs_batch_norm,
skips=(

@@ -11811,6 +11819,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[onlyCUDA, disablecuDNN],
skips=(
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),

@@ -14772,6 +14781,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_softmax_variant,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
OpInfo(
'log_softmax',

@@ -14781,6 +14791,7 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
UnaryUfuncInfo('logit',
aten_backward_name='logit_backward',

@@ -15659,6 +15670,7 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_nll_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
skips=(
# RuntimeError:

@@ -990,6 +990,7 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(

@@ -1017,6 +1018,7 @@ op_db: List[OpInfo] = [
],
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(

@@ -1037,6 +1039,7 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(