mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work where two nested vmap(vmap) with chunk sizes > 1 are invoked, which produces a scan-> vmap -> scan -> vmap chain and we need to handle the case of vmap(scan) and scan(vmap). The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea being that the combine_fn no longer do the combine_fn for a single slice, it vmaps over the combine_fn and do multiple combine_fns in one step. We need to need know how combine_fn propagates the batched tensor and what are the batched dims of the output. For this purpose, we use restore_vmap to give us the out_dims information. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580 Approved by: https://github.com/zou3519 ghstack dependencies: #165675
This commit is contained in:
parent
282f39a4bc
commit
af4ba78543
|
|
@ -8087,6 +8087,260 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
|
|||
self.assertEqual(eager_out, exp_out)
|
||||
self.assertEqual(compiled_out, exp_out)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_scan_in_vmap_simple(self):
|
||||
x = torch.randn(3, 4, 4)
|
||||
y = torch.randn(4, 2)
|
||||
zeros = torch.zeros(2, 3)
|
||||
|
||||
def combine_fn(init, xs):
|
||||
return init.clone(), xs @ y
|
||||
|
||||
def fn(scan_op, x, y):
|
||||
def inner_fn(zeros, x, y):
|
||||
x = x.view(2, 2, 4)
|
||||
|
||||
return scan_op(
|
||||
combine_fn,
|
||||
zeros,
|
||||
x,
|
||||
)
|
||||
|
||||
return torch.vmap(inner_fn, in_dims=(1, 0, None))(zeros, x, y)
|
||||
|
||||
out = fn(scan, x, y)
|
||||
compile_out = torch.compile(fn)(scan, x, y)
|
||||
exp = fn(_fake_scan, x, y)
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(out, compile_out)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_scan_in_vmap_complex_ops(self):
|
||||
# Test with various operations requiring shape reasoning
|
||||
x = torch.randn(4, 5, 3, 2)
|
||||
init = torch.randn(4, 3, 2)
|
||||
weight = torch.randn(3, 3)
|
||||
|
||||
def combine_fn(carry, xs):
|
||||
# carry: (3, 2), xs: (3, 2)
|
||||
intermediate = torch.nn.functional.relu(carry)
|
||||
xs_t = xs.transpose(0, 1) # (2, 3)
|
||||
result = xs_t @ weight # (2, 3)
|
||||
new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2)
|
||||
output = torch.sin(carry).sum() + torch.cos(xs).mean()
|
||||
return new_carry, output
|
||||
|
||||
def fn(scan_op, x, init):
|
||||
def inner_fn(x, init):
|
||||
return scan_op(combine_fn, init, x)
|
||||
|
||||
return torch.vmap(inner_fn, in_dims=(0, 0))(x, init)
|
||||
|
||||
out = fn(scan, x, init)
|
||||
compile_out = torch.compile(fn)(scan, x, init)
|
||||
exp = fn(_fake_scan, x, init)
|
||||
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(compile_out, exp)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_scan_in_vmap_unbatched_x(self):
|
||||
# Test with various operations requiring shape reasoning
|
||||
x = torch.randn(5, 3, 2)
|
||||
init = torch.randn(4, 3, 2)
|
||||
weight = torch.randn(3, 3)
|
||||
|
||||
def combine_fn(carry, xs):
|
||||
# carry: (3, 2), xs: (3, 2)
|
||||
intermediate = torch.nn.functional.relu(carry)
|
||||
xs_t = xs.transpose(0, 1) # (2, 3)
|
||||
result = xs_t @ weight # (2, 3)
|
||||
new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2)
|
||||
output = torch.sin(carry).sum() + torch.cos(xs).mean()
|
||||
return new_carry, output
|
||||
|
||||
def fn(scan_op, x, init):
|
||||
def inner_fn(x, init):
|
||||
return scan_op(combine_fn, init, x)
|
||||
|
||||
return torch.vmap(inner_fn, in_dims=(None, 0))(x, init)
|
||||
|
||||
out = fn(scan, x, init)
|
||||
compile_out = torch.compile(fn)(scan, x, init)
|
||||
exp = fn(_fake_scan, x, init)
|
||||
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(compile_out, exp)
|
||||
|
||||
@skipIfTorchDynamo("not a dynamo test")
|
||||
def test_scan_in_vmap_unbatched_init_error(self):
|
||||
# Test with various operations requiring shape reasoning
|
||||
x = torch.randn(4, 5, 3, 2)
|
||||
init = torch.randn(4, 3, 2)
|
||||
weight = torch.randn(3, 3)
|
||||
|
||||
def combine_fn(carry, xs):
|
||||
# carry: (3, 2), xs: (3, 2)
|
||||
intermediate = torch.nn.functional.relu(carry)
|
||||
xs_t = xs.transpose(0, 1) # (2, 3)
|
||||
result = xs_t @ weight # (2, 3)
|
||||
new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2)
|
||||
output = torch.sin(carry).sum() + torch.cos(xs).mean()
|
||||
return new_carry, output
|
||||
|
||||
def vmap_fn(x, init):
|
||||
def fn(x, init):
|
||||
return scan(combine_fn, init, x)
|
||||
|
||||
return torch.vmap(fn, in_dims=(0, None))(x, init)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"""The size of tensor a \\(4\\) must match the size of tensor b \\(2\\) at non-singleton dimension 4""",
|
||||
):
|
||||
vmap_fn(x, init)
|
||||
|
||||
@skipIfTorchDynamo("a vmap test, not a dynamo test")
|
||||
def test_vmap_closure_weight_error(self):
|
||||
init_batched = torch.randn(7, 2, 3)
|
||||
xs_batched = torch.randn(7, 5, 4)
|
||||
weight = torch.randn(7, 4, 3)
|
||||
|
||||
def combine_fn(carry, xs):
|
||||
# carry: (2, 3), xs: (4,), weight: (4, 3)
|
||||
new_carry = carry + xs @ weight
|
||||
output = carry.sum()
|
||||
return new_carry, output
|
||||
|
||||
def expected_fn(init, xs, weight):
|
||||
def fn(init, xs, weight):
|
||||
return _fake_scan(combine_fn, init, xs)
|
||||
|
||||
return torch.vmap(fn, in_dims=(0, 0, 0))(init, xs, weight)
|
||||
|
||||
# Note that even though weight is vampped but combine_fn is accessing
|
||||
# the closure weight instead of the wrapped out weight thus causing
|
||||
# a shape mismatch.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"""The size of tensor a \\(2\\) must match the size of tensor b \\(7\\) at non-singleton dimension 1""",
|
||||
):
|
||||
expected_fn(init_batched, xs_batched, weight)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_scan_in_vmap_mixed_batch_dims(self):
|
||||
init = torch.randn(8, 5, 6)
|
||||
xs_batched = torch.randn(3, 6, 5, 8)
|
||||
scale = torch.randn([])
|
||||
|
||||
def combine_fn(carry, xs):
|
||||
# carry: 8, 5
|
||||
# xs: 5, 8
|
||||
# new_carry: 8, 5
|
||||
new_carry = carry + (xs * scale).sum()
|
||||
output = xs @ carry
|
||||
return new_carry, output
|
||||
|
||||
def fn(scan_op, init, xs):
|
||||
def inner_fn(init, xs):
|
||||
return scan_op(combine_fn, init, xs)
|
||||
|
||||
return torch.vmap(inner_fn, in_dims=(2, 1))(init, xs)
|
||||
|
||||
out = fn(scan, init, xs_batched)
|
||||
compile_out = torch.compile(fn)(scan, init, xs_batched)
|
||||
exp = fn(_fake_scan, init, xs_batched)
|
||||
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(compile_out, exp)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_vmap_scan_vmap_scan_nested(self):
|
||||
# Outer batch: 3, inner batch: 4, outer scan: 5, inner scan: 6
|
||||
init = torch.randn(3, 4, 2, 8)
|
||||
xs_outer = torch.randn(3, 5, 4, 6, 2)
|
||||
|
||||
def fn(scan_op, init, xs):
|
||||
def inner_combine(carry, xs):
|
||||
# carry: (2, 8), xs: (2,)
|
||||
new_carry = carry + xs.unsqueeze(-1)
|
||||
output = carry.sum(dim=0) # (8,)
|
||||
return new_carry, output
|
||||
|
||||
def outer_combine(init, xs):
|
||||
# carry: (4, 2, 8,), xs: (4, 6, 2)
|
||||
# xs has batch dimension 4 from outer vmap
|
||||
|
||||
def inner_fn(init, xs):
|
||||
# init: (2, 8)
|
||||
# xs: (6, 2)
|
||||
# final_carry: (2, 8)
|
||||
# outputs: (6, 8)
|
||||
final_carry, outputs = scan_op(inner_combine, init, xs)
|
||||
return (final_carry.sum(0, keepdim=True) + outputs).sum(
|
||||
dim=0
|
||||
) # (8,)
|
||||
|
||||
inner_results = torch.vmap(inner_fn)(init, xs) # (4, 8)
|
||||
new_carry = init + inner_results.mean(dim=0) # (8,)
|
||||
output = inner_results.sum(dim=0) # (8,)
|
||||
return new_carry.expand(*init.size()), output
|
||||
|
||||
def vmap_inner_fn(init, xs):
|
||||
# init: (4, 2, 8)
|
||||
# xs: (5, 4, 6, 2)
|
||||
return scan_op(outer_combine, init, xs)
|
||||
|
||||
return torch.vmap(vmap_inner_fn)(init, xs)
|
||||
|
||||
out = fn(scan, init, xs_outer)
|
||||
compile_out = torch.compile(fn)(scan, init, xs_outer)
|
||||
exp = fn(_fake_scan, init, xs_outer)
|
||||
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(compile_out, exp)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_scan_vmap_scan_nested(self):
|
||||
xs_outer = torch.randn(5, 3, 4, 2)
|
||||
init_outer = torch.randn(3, 8)
|
||||
|
||||
def fn(scan_op, init, xs):
|
||||
def inner_combine_fake(carry, xs):
|
||||
# carry: 8
|
||||
# xs: 2
|
||||
new_carry = carry + xs.sum()
|
||||
output = carry * 2
|
||||
return new_carry, output
|
||||
|
||||
def outer_combine_fake(carry, xs):
|
||||
# carry: 3, 8
|
||||
# xs: 3, 4, 2
|
||||
def inner_fn(carry_elem, xs_elem):
|
||||
# carry_elem: 8
|
||||
# xs: 4, 2
|
||||
# final_carry: 8
|
||||
# outputs.sum(0): 8
|
||||
final_carry, outputs = _fake_scan(
|
||||
inner_combine_fake, carry_elem, xs_elem
|
||||
)
|
||||
return outputs.sum(0), final_carry
|
||||
|
||||
# result: (8,)
|
||||
# next_carry, (3, 8))
|
||||
result, next_carry = torch.vmap(inner_fn, in_dims=(0, 0))(carry, xs)
|
||||
output = result.sum(dim=0)
|
||||
return next_carry, output
|
||||
|
||||
return scan_op(outer_combine_fake, init, xs)
|
||||
|
||||
out = fn(scan, init_outer, xs_outer)
|
||||
compile_out = torch.compile(fn)(scan, init_outer, xs_outer)
|
||||
exp = fn(_fake_scan, init_outer, xs_outer)
|
||||
|
||||
self.assertEqual(out, exp)
|
||||
self.assertEqual(compile_out, exp)
|
||||
|
||||
@skipIfTorchDynamo("Skip because we're testing export")
|
||||
@parametrize("strict", [True, False])
|
||||
@parametrize("dynamic", [True, False])
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ def _add_batch_dim(self, batch_dim, level):
|
|||
from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
|
||||
|
||||
mode = _maybe_find_pre_dispatch_tf_mode_for_export()
|
||||
batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim
|
||||
|
||||
if mode:
|
||||
return torch.overrides.handle_torch_function(
|
||||
|
|
|
|||
|
|
@ -357,7 +357,7 @@ def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
|
|||
# Expand outs with None depending on the tensor mask of the output
|
||||
outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask]
|
||||
|
||||
return [*carry, *outs_expanded]
|
||||
return (*carry, *outs_expanded)
|
||||
|
||||
scans = _scan(init, xs)
|
||||
return scans
|
||||
|
|
@ -866,6 +866,64 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
|
|||
return ctx.wrap_tensors(ret)
|
||||
|
||||
|
||||
@scan_op.py_impl(torch._C._functorch.TransformType.Vmap)
|
||||
def scan_batch_rule(interpreter, combine_fn, init, xs, additional_inputs):
|
||||
from torch._functorch.vmap import restore_vmap, unwrap_batched, wrap_batched
|
||||
|
||||
unbatched_args, in_dims = unwrap_batched(
|
||||
(init, xs, additional_inputs), interpreter.level()
|
||||
)
|
||||
# move to last dim to not interfere with scan's batching
|
||||
unbatched_init, unbatched_xs, unbatched_additional_inputs = pytree.tree_map(
|
||||
lambda x, bdim: x.movedim(bdim, -1) if bdim is not None else x,
|
||||
unbatched_args,
|
||||
in_dims,
|
||||
)
|
||||
after_move_dims = tuple(
|
||||
pytree.tree_flatten(
|
||||
pytree.tree_map(lambda x: -1 if x is not None else None, in_dims)
|
||||
)[0]
|
||||
)
|
||||
|
||||
with interpreter.lower():
|
||||
out_dims = None
|
||||
|
||||
def wrapper(*args):
|
||||
nonlocal out_dims
|
||||
outputs, per_slice_out_dims = restore_vmap(
|
||||
combine_fn,
|
||||
after_move_dims,
|
||||
interpreter.batch_size(),
|
||||
interpreter.randomness(),
|
||||
)(*args)
|
||||
# Note: outputs are not batched, we just move the batch dim to the end
|
||||
# this is to avoid it interfering with scan's batching
|
||||
outputs = tuple(
|
||||
pytree.tree_map(
|
||||
lambda out, out_bdim: out.movedim(out_bdim, -1)
|
||||
if out_bdim is not None
|
||||
else out,
|
||||
outputs,
|
||||
per_slice_out_dims,
|
||||
)
|
||||
)
|
||||
out_dims = tuple(
|
||||
pytree.tree_map(
|
||||
lambda out_bdim: -1 if out_bdim is not None else None,
|
||||
per_slice_out_dims,
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
|
||||
unwrapped_out = scan_op(
|
||||
wrapper, unbatched_init, unbatched_xs, unbatched_additional_inputs
|
||||
)
|
||||
|
||||
assert out_dims is not None
|
||||
batched_out = wrap_batched(unwrapped_out, out_dims, interpreter.level())
|
||||
return batched_out
|
||||
|
||||
|
||||
# dense implementation for scan. Used for testing only.
|
||||
def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False):
|
||||
carry_leaves, carry_spec = pytree.tree_flatten(init)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user