Fix autograd.Function + NJT when an output grad is None (#136875)

For `autograd.Function`, the autograd engine tries to allocate correctly-shaped zeros for `None` grads (i.e., when an output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output's shape. When the forward output is an NJT (jagged-layout `NestedTensor`), the stored size info contains a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```

This PR fixes the issue by storing the full tensor in `VariableInfo` for the nested case and calling `zeros_like()` on it to allocate correctly-shaped zeros. This is fairly inefficient; ideally we would save just the NJT shape and construct zeros from it, but that requires factory function support for nested ints (WIP). So this is a short-term fix until that support exists.
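The failure is easy to reproduce outside of `autograd.Function`. Below is a minimal sketch (not part of this PR) contrasting the two allocation strategies on an NJT; exact shape display and error text may vary by version:

```python
import torch

# A jagged NestedTensor (NJT): a batch of 3 ragged-length sequences of width 4.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4), torch.randn(5, 4)],
    layout=torch.jagged,
)
print(nt.shape)  # e.g. torch.Size([3, j1, 4]) -- the ragged dim is a nested int

# Pre-fix behavior: the engine built zeros from the stored symbolic sizes.
# The nested int in the size makes this throw.
try:
    torch.zeros(nt.shape, dtype=nt.dtype, device=nt.device)
except RuntimeError as e:
    print(e)  # "... SymIntArrayRef expected to contain only concrete integers"

# Post-fix behavior: zeros_like() dispatches to the NJT implementation and
# returns zeros with the same ragged structure as the original tensor.
zeros = torch.zeros_like(nt)
```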
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer, https://github.com/huydhn

`test/test_nestedtensor.py`

```diff
@@ -7055,6 +7055,36 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         if nt._lengths is not None:
             self.assertEqual(nt3._lengths.device, other_device)
 
+    @dtypes(torch.float32)
+    def test_autograd_function_with_None_grad(self, device, dtype):
+        class MyFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, inp):
+                ctx.save_for_backward(inp)
+                out1 = inp + 1
+                out2 = inp * 2
+                return out1, out2
+
+            @staticmethod
+            def backward(ctx, grad_out1, grad_out2):
+                (inp,) = ctx.saved_tensors
+                return grad_out1 + grad_out2
+
+        f = MyFunction.apply
+        nt = random_nt_from_dims(
+            [5, None, 10],
+            device=device,
+            dtype=dtype,
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        # Only use one of the autograd.Function outputs downstream so that the grad
+        # for the other output is None. We're testing that the engine can allocate
+        # correctly-shaped (NJT) zeros for the grad of the other output in this case.
+        (out1, _) = f(nt)
+        out1.backward(torch.ones_like(out1))
+
     @dtypes(torch.float64, torch.float32, torch.half)
     def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
         values = torch.randn(10, 5, device=device, dtype=dtype)
```
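For context (not part of this change): an `autograd.Function` can opt out of grad materialization entirely via `ctx.set_materialize_grads(False)`, in which case the engine passes `None` through to `backward()` instead of allocating zeros, sidestepping this code path at the cost of handling `None` manually. A minimal sketch:

```python
import torch

class MyFunctionNoMaterialize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # Opt out of zero materialization: grads for unused outputs will
        # arrive in backward() as None rather than as allocated zeros.
        ctx.set_materialize_grads(False)
        return inp + 1, inp * 2

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        # With materialization disabled, handle None explicitly.
        if grad_out1 is None:
            return grad_out2
        if grad_out2 is None:
            return grad_out1
        return grad_out1 + grad_out2
```

The test above instead exercises the default materialization path, which is exactly where the NJT zeros allocation was failing.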

`torch/csrc/autograd/python_function.cpp`

```diff
@@ -733,8 +733,18 @@ static void _wrap_outputs(
       PyTuple_SetItem(outputs, i, obj);
     } else {
       if (is_executable) {
-        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-        self->output_info.emplace_back(*wrapped_outputs[i]);
+        // If one of the grad outputs is undefined, a correctly-shaped zeros
+        // should be used instead. To construct these for NJT, zeros_like() must
+        // be used until we have factory function support.
+        bool is_differentiable =
+            (non_differentiable.count(
+                 wrapped_outputs[i]->unsafeGetTensorImpl()) == 0 &&
+             isDifferentiableType(wrapped_outputs[i]->scalar_type()));
+        bool use_zeros_like = is_differentiable && num_outputs > 1 &&
+            wrapped_outputs[i]->is_nested();
+        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+        self->output_info.emplace_back(*wrapped_outputs[i], use_zeros_like);
      }
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
```
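The gating above keeps the workaround narrow: the full tensor is saved only when the `zeros_like()` fallback could actually be needed. The output must be differentiable, the function must have more than one output (with a single output, an unused result simply means `backward()` never runs), and the output must be nested. A hedged Python restatement of that predicate (names are illustrative, not PyTorch API):

```python
import torch

def should_use_zeros_like(
    out: torch.Tensor, is_differentiable: bool, num_outputs: int
) -> bool:
    # Mirrors the C++ condition: only nested outputs of multi-output,
    # differentiable functions need the zeros_like()-based fallback.
    return is_differentiable and num_outputs > 1 and out.is_nested
```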

`torch/csrc/autograd/variable_info.cpp`

```diff
@@ -2,6 +2,7 @@
 #include <ATen/Functions.h>
 #else
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
 #endif
 
 #include <torch/csrc/autograd/variable.h>
@@ -9,13 +10,16 @@
 namespace torch::autograd {
 
-VariableInfo::VariableInfo(const Variable& var)
+VariableInfo::VariableInfo(const Variable& var, bool use_zeros_like)
     : layout(var.layout()),
       device(var.device()),
       scalar_type(var.scalar_type()),
       size(var.sym_sizes().vec()),
       requires_grad(var.requires_grad()),
-      is_empty(false) {}
+      is_empty(false),
+      the_var(
+          use_zeros_like ? std::optional<Variable>(var.detach())
+                         : std::nullopt) {}
 
 VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}
@@ -23,6 +27,8 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   if (is_empty) {
     // Return undefined tensor.
     return at::Tensor();
+  } else if (the_var.has_value()) {
+    return at::zeros_like(*the_var);
   } else {
     return at::zeros_symint(
         size, at::TensorOptions(scalar_type).device(device).layout(layout));
```
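Note that the constructor stores `var.detach()`, so the held tensor shares storage with the forward output but does not keep the autograd graph alive. A rough Python mirror of the updated `VariableInfo::zeros()` logic (a sketch only; field names follow the C++ above):

```python
import torch

def variable_info_zeros(info) -> torch.Tensor | None:
    if info.is_empty:
        # The C++ returns an undefined tensor here.
        return None
    if info.the_var is not None:
        # Nested case: only zeros_like() can reproduce the ragged structure.
        return torch.zeros_like(info.the_var)
    # Dense case: the stored sizes are concrete, so plain zeros() works.
    return torch.zeros(
        info.size, dtype=info.scalar_type, device=info.device, layout=info.layout
    )
```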

`torch/csrc/autograd/variable_info.h`

```diff
@@ -6,7 +6,7 @@ namespace torch::autograd {
 struct TORCH_API VariableInfo {
   explicit VariableInfo();
-  explicit VariableInfo(const Variable& var);
+  explicit VariableInfo(const Variable& var, bool use_zeros_like = false);
 
   Variable zeros(at::OptionalDeviceGuard& device_guard) const;
@@ -16,6 +16,8 @@ struct TORCH_API VariableInfo {
   std::vector<c10::SymInt> size;
   bool requires_grad;
   bool is_empty;
+  // needed for e.g. NJTs since they only support zeros_like()
+  std::optional<Variable> the_var;
 };
 
 } // namespace torch::autograd
```