Support NJT chunk() backward on batch dim (#144584)

Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Implements `chunk()` backward on the batch dim, which was previously unsupported. This PR unbinds the NJT components and invokes `copy_()` on them to pass along the appropriate gradients; a rough sketch of the idea follows.
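A rough, illustrative sketch of that idea (not the actual implementation, which lives in the C++ `chunk_backward_nested` change below; the helper name here is hypothetical): build a zero gradient NJT shaped like the forward input, chunk it along the batch dim, and `copy_()` each incoming gradient chunk into the corresponding piece.

import torch

def chunk_backward_batch_dim_sketch(grads, self_njt, chunks):
    # Zero gradient NJT with the same nested structure as the forward input.
    grad_input = torch.zeros_like(self_njt)
    # Chunk along the batch dim so each piece lines up with one forward output.
    grad_pieces = grad_input.chunk(chunks, dim=0)
    for piece, grad in zip(grad_pieces, grads):
        if grad is not None:
            # The nested ints of `piece` and `grad` generally differ, so copy_()
            # falls back to copying the unbound dense components one by one.
            piece.copy_(grad)
    return grad_input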
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
Author: Joel Schlosser (2025-01-17 16:53:45 -05:00), committed by PyTorch MergeBot
Commit: 3ee531f8b9 (parent 8a57234033)

4 changed files with 25 additions and 32 deletions

@@ -4188,17 +4188,13 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
             self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
         self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
-        ):
-            # doesn't support backward for chunk (dim=0) yet
-            loss = (
-                chunks[0].values().sum()
-                + chunks[1].values().sum()
-                + chunks[2].values().sum()
-            )
-            loss.backward()
+        # doesn't support backward for chunk (dim=0) yet
+        loss = (
+            chunks[0].values().sum()
+            + chunks[1].values().sum()
+            + chunks[2].values().sum()
+        )
+        loss.backward()
 
         # chunk on ragged dim not supported
         with self.assertRaisesRegex(
@@ -6232,18 +6228,14 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         c = torch.nested.nested_tensor_from_jagged(
             torch.ones(4, 3, device=device), offsets_2
         )
 
-        # fail when tensors have the same size but not the exact same offset tensor.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
-        ):
-            a.copy_(c)
+        # should work even though the nested ints are different due to unbound-based copy
+        a.copy_(c)
 
         # fail when tensors have different sizes
         a = a.transpose(1, 2)
         with self.assertRaisesRegex(
             RuntimeError,
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
+            "expected compatible input and src shapes, but got",
         ):
             a.copy_(b)
@@ -8343,14 +8335,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
         name="broken_copysign_backward",
     ),
-    # chunk(): backward doesn't work for the batch dim yet
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="Nested Tensor doesn't support chunk backward on dim=0 yet",
-        op_match_fn=lambda device, op: (op.full_name == "chunk"),
-        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
-        name="broken_chunk_backward",
-    ),
     # amin() / amax(): broken in a host of ways I don't think it's a good use of time
     # to try to sift through
     SkipRule(

@@ -2076,8 +2076,6 @@ Tensor chunk_backward_nested(
       self.layout() == c10::kJagged,
       "Nested Strided Tensor doesn't support chunk backward.")
   dim = at::maybe_wrap_dim(dim, self.dim());
-  TORCH_INTERNAL_ASSERT(
-      dim != 0, "Nested Tensor doesn't support chunk backward on dim=0 yet.")
   Tensor ret = at::zeros_like(self);
   std::vector<Tensor> rets = at::chunk(ret, chunks, dim);
   for (const auto j : c10::irange(grads.size())) {
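For reference, a minimal end-to-end sketch of what removing the assert unblocks (the values and offsets below are illustrative, not taken from the test suite): chunking an NJT along the batch dim and backpropagating through the chunks now reaches the underlying values.

import torch

values = torch.randn(6, 3, requires_grad=True)
offsets = torch.tensor([0, 2, 3, 6])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)

chunks = nt.chunk(2, dim=0)  # split along the batch dim
loss = chunks[0].values().sum() + chunks[1].values().sum()
loss.backward()  # previously tripped the dim=0 internal assert; now populates values.grad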

@@ -652,9 +652,20 @@ def copy_default(func, *args, **kwargs):
     inp = new_kwargs.pop("input")
     src = new_kwargs.pop("src")
     if inp._size != src._size:
-        raise RuntimeError(
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor."
-        )
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
     # AOTD allows mutations of inputs only, (not views of the inputs).
     # NJT.values() returns _values.detach() to workaround some issues.
     # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
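A hedged usage sketch of what the relaxed check permits (the offsets values below are illustrative, not from the test): two NJTs whose unbound components have identical shapes but whose offsets are distinct tensors, and hence carry different nested ints, can now be copied component-wise.

import torch

values_dst = torch.zeros(4, 3)
values_src = torch.ones(4, 3)
offsets_dst = torch.tensor([0, 1, 3, 4])
offsets_src = torch.tensor([0, 1, 3, 4])  # equal values, but a distinct tensor object

dst = torch.nested.nested_tensor_from_jagged(values_dst, offsets_dst)
src = torch.nested.nested_tensor_from_jagged(values_src, offsets_src)

# Previously raised because the offsets tensors (and nested ints) differ; with the
# unbind-based fallback the per-component shapes match, so the copy succeeds.
dst.copy_(src)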

@@ -831,7 +831,7 @@ def batchwise_reference_chunk(op, sample):
         start += chunk_size
 
     # rejoin into NJT outputs
-    return [torch.nested.nested_tensor(lst, layout=torch.jagged) for lst in chunks]
+    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]
 
 
 def batchwise_reference_narrow(op, sample):