[scan x vmap] support scan in vmap (#165580)

This is required by the chunked_with_scan work, where two nested vmaps (vmap(vmap)) with chunk sizes > 1 are invoked. That produces a scan -> vmap -> scan -> vmap chain, so we need to handle both vmap(scan) and scan(vmap).

The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea is that combine_fn no longer performs the combine for a single slice; instead we vmap over combine_fn and perform multiple combines in one step. We also need to know how combine_fn propagates the batched tensors and which dims of its outputs are batched. For this purpose, we use restore_vmap, which gives us the out_dims information.
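
To illustrate the idea, here is a minimal, hand-rolled sketch (not the actual batch rule, which additionally moves batch dims out of the way and recovers out_dims via restore_vmap); `combine_fn`, `batched_scan_sketch`, and `per_example_scan` are made-up names for this example:

```python
import torch

def combine_fn(carry, x):
    # toy combine: next carry and a per-step output
    return carry + x, carry * x

def batched_scan_sketch(init, xs):
    # init: (batch, d); xs: (scan_len, batch, d)
    # scan(vmap(combine_fn)): one scan step advances every batch element at once
    vmapped_combine = torch.vmap(combine_fn, in_dims=(0, 0))
    carry, outs = init, []
    for x in xs.unbind(0):  # iterate over the leading (scan) dim only
        carry, y = vmapped_combine(carry, x)
        outs.append(y)
    return carry, torch.stack(outs, 0)

def per_example_scan(init, xs):
    # reference: an independent scan for a single batch element
    carry, outs = init, []
    for x in xs.unbind(0):
        carry, y = combine_fn(carry, x)
        outs.append(y)
    return carry, torch.stack(outs, 0)

init = torch.randn(4, 3)   # batch of 4 carries
xs = torch.randn(5, 4, 3)  # scan length 5, batch 4
carry_a, ys_a = batched_scan_sketch(init, xs)
carries, ys = zip(*(per_example_scan(init[b], xs[:, b]) for b in range(4)))
assert torch.allclose(carry_a, torch.stack(carries))
assert torch.allclose(ys_a, torch.stack(ys, dim=1))
```

In the real batch rule below, the batch dims are first moved to the last dimension so they do not clash with scan's leading scan dimension.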

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
Yidi Wu 2025-10-20 15:04:39 -07:00 committed by PyTorch MergeBot
parent 282f39a4bc
commit af4ba78543
3 changed files with 314 additions and 1 deletion

@@ -8087,6 +8087,260 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
        self.assertEqual(eager_out, exp_out)
        self.assertEqual(compiled_out, exp_out)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_simple(self):
        x = torch.randn(3, 4, 4)
        y = torch.randn(4, 2)
        zeros = torch.zeros(2, 3)

        def combine_fn(init, xs):
            return init.clone(), xs @ y

        def fn(scan_op, x, y):
            def inner_fn(zeros, x, y):
                x = x.view(2, 2, 4)
                return scan_op(
                    combine_fn,
                    zeros,
                    x,
                )

            return torch.vmap(inner_fn, in_dims=(1, 0, None))(zeros, x, y)

        out = fn(scan, x, y)
        compile_out = torch.compile(fn)(scan, x, y)
        exp = fn(_fake_scan, x, y)
        self.assertEqual(out, exp)
        self.assertEqual(out, compile_out)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_complex_ops(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(4, 5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def fn(scan_op, x, init):
            def inner_fn(x, init):
                return scan_op(combine_fn, init, x)

            return torch.vmap(inner_fn, in_dims=(0, 0))(x, init)

        out = fn(scan, x, init)
        compile_out = torch.compile(fn)(scan, x, init)
        exp = fn(_fake_scan, x, init)
        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_unbatched_x(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def fn(scan_op, x, init):
            def inner_fn(x, init):
                return scan_op(combine_fn, init, x)

            return torch.vmap(inner_fn, in_dims=(None, 0))(x, init)

        out = fn(scan, x, init)
        compile_out = torch.compile(fn)(scan, x, init)
        exp = fn(_fake_scan, x, init)
        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)
    @skipIfTorchDynamo("not a dynamo test")
    def test_scan_in_vmap_unbatched_init_error(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(4, 5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def vmap_fn(x, init):
            def fn(x, init):
                return scan(combine_fn, init, x)

            return torch.vmap(fn, in_dims=(0, None))(x, init)

        with self.assertRaisesRegex(
            RuntimeError,
            """The size of tensor a \\(4\\) must match the size of tensor b \\(2\\) at non-singleton dimension 4""",
        ):
            vmap_fn(x, init)
    @skipIfTorchDynamo("a vmap test, not a dynamo test")
    def test_vmap_closure_weight_error(self):
        init_batched = torch.randn(7, 2, 3)
        xs_batched = torch.randn(7, 5, 4)
        weight = torch.randn(7, 4, 3)

        def combine_fn(carry, xs):
            # carry: (2, 3), xs: (4,), weight: (4, 3)
            new_carry = carry + xs @ weight
            output = carry.sum()
            return new_carry, output

        def expected_fn(init, xs, weight):
            def fn(init, xs, weight):
                return _fake_scan(combine_fn, init, xs)

            return torch.vmap(fn, in_dims=(0, 0, 0))(init, xs, weight)

        # Note that even though weight is vmapped, combine_fn accesses the
        # closure weight instead of the vmap-wrapped weight, causing a
        # shape mismatch.
        with self.assertRaisesRegex(
            RuntimeError,
            """The size of tensor a \\(2\\) must match the size of tensor b \\(7\\) at non-singleton dimension 1""",
        ):
            expected_fn(init_batched, xs_batched, weight)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_mixed_batch_dims(self):
        init = torch.randn(8, 5, 6)
        xs_batched = torch.randn(3, 6, 5, 8)
        scale = torch.randn([])

        def combine_fn(carry, xs):
            # carry: 8, 5
            # xs: 5, 8
            # new_carry: 8, 5
            new_carry = carry + (xs * scale).sum()
            output = xs @ carry
            return new_carry, output

        def fn(scan_op, init, xs):
            def inner_fn(init, xs):
                return scan_op(combine_fn, init, xs)

            return torch.vmap(inner_fn, in_dims=(2, 1))(init, xs)

        out = fn(scan, init, xs_batched)
        compile_out = torch.compile(fn)(scan, init, xs_batched)
        exp = fn(_fake_scan, init, xs_batched)
        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_vmap_scan_vmap_scan_nested(self):
        # Outer batch: 3, inner batch: 4, outer scan: 5, inner scan: 6
        init = torch.randn(3, 4, 2, 8)
        xs_outer = torch.randn(3, 5, 4, 6, 2)

        def fn(scan_op, init, xs):
            def inner_combine(carry, xs):
                # carry: (2, 8), xs: (2,)
                new_carry = carry + xs.unsqueeze(-1)
                output = carry.sum(dim=0)  # (8,)
                return new_carry, output

            def outer_combine(init, xs):
                # carry: (4, 2, 8,), xs: (4, 6, 2)
                # xs has batch dimension 4 from outer vmap
                def inner_fn(init, xs):
                    # init: (2, 8)
                    # xs: (6, 2)
                    # final_carry: (2, 8)
                    # outputs: (6, 8)
                    final_carry, outputs = scan_op(inner_combine, init, xs)
                    return (final_carry.sum(0, keepdim=True) + outputs).sum(
                        dim=0
                    )  # (8,)

                inner_results = torch.vmap(inner_fn)(init, xs)  # (4, 8)
                new_carry = init + inner_results.mean(dim=0)  # (8,)
                output = inner_results.sum(dim=0)  # (8,)
                return new_carry.expand(*init.size()), output

            def vmap_inner_fn(init, xs):
                # init: (4, 2, 8)
                # xs: (5, 4, 6, 2)
                return scan_op(outer_combine, init, xs)

            return torch.vmap(vmap_inner_fn)(init, xs)

        out = fn(scan, init, xs_outer)
        compile_out = torch.compile(fn)(scan, init, xs_outer)
        exp = fn(_fake_scan, init, xs_outer)
        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_vmap_scan_nested(self):
        xs_outer = torch.randn(5, 3, 4, 2)
        init_outer = torch.randn(3, 8)

        def fn(scan_op, init, xs):
            def inner_combine_fake(carry, xs):
                # carry: 8
                # xs: 2
                new_carry = carry + xs.sum()
                output = carry * 2
                return new_carry, output

            def outer_combine_fake(carry, xs):
                # carry: 3, 8
                # xs: 3, 4, 2
                def inner_fn(carry_elem, xs_elem):
                    # carry_elem: 8
                    # xs: 4, 2
                    # final_carry: 8
                    # outputs.sum(0): 8
                    final_carry, outputs = _fake_scan(
                        inner_combine_fake, carry_elem, xs_elem
                    )
                    return outputs.sum(0), final_carry

                # result: (8,)
                # next_carry: (3, 8)
                result, next_carry = torch.vmap(inner_fn, in_dims=(0, 0))(carry, xs)
                output = result.sum(dim=0)
                return next_carry, output

            return scan_op(outer_combine_fake, init, xs)

        out = fn(scan, init_outer, xs_outer)
        compile_out = torch.compile(fn)(scan, init_outer, xs_outer)
        exp = fn(_fake_scan, init_outer, xs_outer)
        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)
    @skipIfTorchDynamo("Skip because we're testing export")
    @parametrize("strict", [True, False])
    @parametrize("dynamic", [True, False])

@@ -28,6 +28,7 @@ def _add_batch_dim(self, batch_dim, level):
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()
    batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim
    if mode:
        return torch.overrides.handle_torch_function(

@@ -357,7 +357,7 @@ def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
        # Expand outs with None depending on the tensor mask of the output
        outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask]
-        return [*carry, *outs_expanded]
+        return (*carry, *outs_expanded)

    scans = _scan(init, xs)
    return scans
@@ -866,6 +866,64 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
    return ctx.wrap_tensors(ret)


@scan_op.py_impl(torch._C._functorch.TransformType.Vmap)
def scan_batch_rule(interpreter, combine_fn, init, xs, additional_inputs):
    from torch._functorch.vmap import restore_vmap, unwrap_batched, wrap_batched

    unbatched_args, in_dims = unwrap_batched(
        (init, xs, additional_inputs), interpreter.level()
    )
    # move to last dim to not interfere with scan's batching
    unbatched_init, unbatched_xs, unbatched_additional_inputs = pytree.tree_map(
        lambda x, bdim: x.movedim(bdim, -1) if bdim is not None else x,
        unbatched_args,
        in_dims,
    )
    after_move_dims = tuple(
        pytree.tree_flatten(
            pytree.tree_map(lambda x: -1 if x is not None else None, in_dims)
        )[0]
    )

    with interpreter.lower():
        out_dims = None

        def wrapper(*args):
            nonlocal out_dims
            outputs, per_slice_out_dims = restore_vmap(
                combine_fn,
                after_move_dims,
                interpreter.batch_size(),
                interpreter.randomness(),
            )(*args)
            # Note: outputs are not batched, we just move the batch dim to the end
            # this is to avoid it interfering with scan's batching
            outputs = tuple(
                pytree.tree_map(
                    lambda out, out_bdim: out.movedim(out_bdim, -1)
                    if out_bdim is not None
                    else out,
                    outputs,
                    per_slice_out_dims,
                )
            )
            out_dims = tuple(
                pytree.tree_map(
                    lambda out_bdim: -1 if out_bdim is not None else None,
                    per_slice_out_dims,
                )
            )
            return outputs

        unwrapped_out = scan_op(
            wrapper, unbatched_init, unbatched_xs, unbatched_additional_inputs
        )

    assert out_dims is not None
    batched_out = wrap_batched(unwrapped_out, out_dims, interpreter.level())
    return batched_out


# dense implementation for scan. Used for testing only.
def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False):
    carry_leaves, carry_spec = pytree.tree_flatten(init)