Support NJT chunk() backward on batch dim (#144584)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. Implements `chunk()` backward on the batch dim, which was previously left unsupported. This PR unbinds the components and invokes `copy_()` on them to pass along the appropriate gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
This commit is contained in:
parent 8a57234033
commit 3ee531f8b9
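Per the description (and the `chunk_backward_nested` change further down), the backward strategy is: materialize a zeros buffer shaped like the forward input, re-chunk it the same way as the forward pass, and `copy_()` each incoming gradient into its chunk. A minimal Python sketch of the idea, purely illustrative (the function name is ours; the real implementation is the C++ and NJT `copy_` changes below):

import torch

# Illustrative sketch (not the PyTorch source): route each output gradient
# back into a zeros buffer shaped like the forward input.
def chunk_backward_sketch(inp, grads, chunks, dim):
    ret = torch.zeros_like(inp)          # gradient buffer for grad_input
    pieces = ret.chunk(chunks, dim=dim)  # same partition as the forward chunk()
    for piece, grad in zip(pieces, grads):
        if grad is not None:             # undefined grads contribute zeros
            piece.copy_(grad)            # on NJT dim=0 this is the unbind-based copy_
    return ret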
@@ -4188,17 +4188,13 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
                 self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
             self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))
 
-            with self.assertRaisesRegex(
-                RuntimeError,
-                "dim != 0 INTERNAL ASSERT FAILED .* Nested Tensor doesn't support chunk backward on dim=0 yet.",
-            ):
-                # doesn't support backward for chunk (dim=0) yet
-                loss = (
-                    chunks[0].values().sum()
-                    + chunks[1].values().sum()
-                    + chunks[2].values().sum()
-                )
-                loss.backward()
+            # doesn't support backward for chunk (dim=0) yet
+            loss = (
+                chunks[0].values().sum()
+                + chunks[1].values().sum()
+                + chunks[2].values().sum()
+            )
+            loss.backward()
 
             # chunk on ragged dim not supported
             with self.assertRaisesRegex(
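After this change, the batch-dim path works end to end. A small self-contained example (shapes and chunk count are ours, mirroring the test above):

import torch

# Build an NJT with four variable-length batch items, chunk along the batch
# dim, and backprop through all chunks -- the path this PR enables.
nt = torch.nested.nested_tensor(
    [torch.randn(n, 5) for n in (2, 3, 4, 6)],
    layout=torch.jagged,
    requires_grad=True,
)
chunks = nt.chunk(2, dim=0)  # two NJTs of two batch items each
loss = sum(c.values().sum() for c in chunks)
loss.backward()
assert nt.grad is not None  # gradients now reach the original NJT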
@@ -6232,18 +6228,14 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
         c = torch.nested.nested_tensor_from_jagged(
             torch.ones(4, 3, device=device), offsets_2
         )
-        # fail when tensors have the same size but not the exact same offset tensor.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
-        ):
-            a.copy_(c)
+        # should work even though the nested ints are different due to unbound-based copy
+        a.copy_(c)
 
         # fail when tensors have different sizes
         a = a.transpose(1, 2)
         with self.assertRaisesRegex(
             RuntimeError,
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor.",
+            "expected compatible input and src shapes, but got",
         ):
             a.copy_(b)
 
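The new contract: `copy_()` no longer requires the exact same offsets tensor, only component-wise compatible shapes; mismatched shapes now raise the clearer `copy_(): expected compatible input and src shapes` error. An illustrative snippet (tensor values and names are ours):

import torch

values_a = torch.zeros(4, 3)
values_b = torch.ones(4, 3)
offsets_a = torch.tensor([0, 1, 4])
offsets_b = torch.tensor([0, 1, 4])  # equal values, distinct tensor object

a = torch.nested.nested_tensor_from_jagged(values_a, offsets_a)
b = torch.nested.nested_tensor_from_jagged(values_b, offsets_b)

# Different nested ints (different offsets objects), same component shapes:
a.copy_(b)  # previously raised; now succeeds via the unbind-based copy

# Incompatible component shapes still raise:
#     a.transpose(1, 2).copy_(b)
# RuntimeError: copy_(): expected compatible input and src shapes, but got: ...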
@@ -8343,14 +8335,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
         name="broken_copysign_backward",
     ),
-    # chunk(): backward doesn't work for the batch dim yet
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="Nested Tensor doesn't support chunk backward on dim=0 yet",
-        op_match_fn=lambda device, op: (op.full_name == "chunk"),
-        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
-        name="broken_chunk_backward",
-    ),
     # amin() / amax(): broken in a host of ways I don't think it's a good use of time
     # to try to sift through
     SkipRule(
@@ -2076,8 +2076,6 @@ Tensor chunk_backward_nested(
       self.layout() == c10::kJagged,
       "Nested Strided Tensor doesn't support chunk backward.")
   dim = at::maybe_wrap_dim(dim, self.dim());
-  TORCH_INTERNAL_ASSERT(
-      dim != 0, "Nested Tensor doesn't support chunk backward on dim=0 yet.")
   Tensor ret = at::zeros_like(self);
   std::vector<Tensor> rets = at::chunk(ret, chunks, dim);
   for (const auto j : c10::irange(grads.size())) {
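Dropping the assert is enough because the rest of the body is already dim-agnostic: `at::chunk` partitions the `zeros_like` buffer on dim=0, and each per-chunk `copy_()` is handled by the updated NJT `copy_` kernel in the next hunk. The catch, and the reason that kernel needs an unbind-based fallback, is that independently constructed NJTs carry distinct nested ints even for identical ragged structure, as this small illustration (ours) shows:

import torch

# Two NJTs with identical data still report different symbolic ragged sizes,
# because each one is keyed to its own offsets tensor.
data = [torch.ones(2, 3), torch.ones(4, 3)]
x = torch.nested.nested_tensor(data, layout=torch.jagged)
y = torch.nested.nested_tensor(data, layout=torch.jagged)
print(x.shape)  # e.g. torch.Size([2, j1, 3])
print(y.shape)  # e.g. torch.Size([2, j2, 3]) -- a different nested int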
@@ -652,9 +652,20 @@ def copy_default(func, *args, **kwargs):
     inp = new_kwargs.pop("input")
     src = new_kwargs.pop("src")
     if inp._size != src._size:
-        raise RuntimeError(
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor."
-        )
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
+
     # AOTD allows mutations of inputs only, (not views of the inputs).
     # NJT.values() returns _values.detach() to workaround some issues.
     # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
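A note on the trade-off in the fallback above: it replaces a single fused copy with a Python-level loop of per-component `copy_()` calls, one per batch item, which is what the TODO about an eventual direct copy refers to. When `inp._size == src._size` (i.e. the same nested ints, which requires the very same offsets tensor), this branch is skipped entirely and the pre-existing direct-copy behavior below the hunk applies unchanged.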
@@ -831,7 +831,7 @@ def batchwise_reference_chunk(op, sample):
         start += chunk_size
 
     # rejoin into NJT outputs
-    return [torch.nested.nested_tensor(lst, layout=torch.jagged) for lst in chunks]
+    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]
 
 
 def batchwise_reference_narrow(op, sample):
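The one-line fix above matters because `torch.nested.nested_tensor()` copies its inputs and starts a fresh autograd leaf, whereas `torch.nested.as_nested_tensor()` preserves the inputs' autograd history, which a backward reference implementation needs so gradients can flow to the original tensors. A small illustration (names are ours):

import torch

x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(4, 3, requires_grad=True)

# nested_tensor(): copies data, severs history -- a brand-new leaf.
detached = torch.nested.nested_tensor([x, y], layout=torch.jagged)
print(detached.requires_grad)  # False

# as_nested_tensor(): stays connected to x and y in the autograd graph.
connected = torch.nested.as_nested_tensor([x, y], layout=torch.jagged)
connected.values().sum().backward()
print(x.grad is not None)  # True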