From af4ba78543fef4cd188c387358e4e39a3902c212 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 20 Oct 2025 15:04:39 -0700 Subject: [PATCH] [scan x vmap] support scan in vmap (#165580) This is required by the chunked_with_scan work where two nested vmap(vmap) with chunk sizes > 1 are invoked, which produces a scan -> vmap -> scan -> vmap chain, so we need to handle both the vmap(scan) and scan(vmap) cases. The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea is that the new combine_fn no longer runs the original combine_fn on a single slice; instead, it vmaps over combine_fn and performs multiple combine_fn calls in one step. We need to know how combine_fn propagates the batched tensors and what the batched dims of the outputs are. For this purpose, we use restore_vmap to give us the out_dims information. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580 Approved by: https://github.com/zou3519 ghstack dependencies: #165675 --- test/functorch/test_control_flow.py | 254 ++++++++++++++++++++++++++++ torch/_functorch/predispatch.py | 1 + torch/_higher_order_ops/scan.py | 60 ++++++- 3 files changed, 314 insertions(+), 1 deletion(-) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e47aaa9e9e2..5ed62a638cb 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -8087,6 +8087,260 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ self.assertEqual(eager_out, exp_out) self.assertEqual(compiled_out, exp_out) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_scan_in_vmap_simple(self): + x = torch.randn(3, 4, 4) + y = torch.randn(4, 2) + zeros = torch.zeros(2, 3) + + def combine_fn(init, xs): + return init.clone(), xs @ y + + def fn(scan_op, x, y): + def inner_fn(zeros, x, y): + x = x.view(2, 2, 4) + + return scan_op( + combine_fn, + zeros, + x, + ) + + return torch.vmap(inner_fn, in_dims=(1, 0, None))(zeros, x, y) + + out = 
fn(scan, x, y) + compile_out = torch.compile(fn)(scan, x, y) + exp = fn(_fake_scan, x, y) + self.assertEqual(out, exp) + self.assertEqual(out, compile_out) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_scan_in_vmap_complex_ops(self): + # Test with various operations requiring shape reasoning + x = torch.randn(4, 5, 3, 2) + init = torch.randn(4, 3, 2) + weight = torch.randn(3, 3) + + def combine_fn(carry, xs): + # carry: (3, 2), xs: (3, 2) + intermediate = torch.nn.functional.relu(carry) + xs_t = xs.transpose(0, 1) # (2, 3) + result = xs_t @ weight # (2, 3) + new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2) + output = torch.sin(carry).sum() + torch.cos(xs).mean() + return new_carry, output + + def fn(scan_op, x, init): + def inner_fn(x, init): + return scan_op(combine_fn, init, x) + + return torch.vmap(inner_fn, in_dims=(0, 0))(x, init) + + out = fn(scan, x, init) + compile_out = torch.compile(fn)(scan, x, init) + exp = fn(_fake_scan, x, init) + + self.assertEqual(out, exp) + self.assertEqual(compile_out, exp) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_scan_in_vmap_unbatched_x(self): + # Test with various operations requiring shape reasoning + x = torch.randn(5, 3, 2) + init = torch.randn(4, 3, 2) + weight = torch.randn(3, 3) + + def combine_fn(carry, xs): + # carry: (3, 2), xs: (3, 2) + intermediate = torch.nn.functional.relu(carry) + xs_t = xs.transpose(0, 1) # (2, 3) + result = xs_t @ weight # (2, 3) + new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2) + output = torch.sin(carry).sum() + torch.cos(xs).mean() + return new_carry, output + + def fn(scan_op, x, init): + def inner_fn(x, init): + return scan_op(combine_fn, init, x) + + return torch.vmap(inner_fn, in_dims=(None, 0))(x, init) + + out = fn(scan, x, init) + compile_out = torch.compile(fn)(scan, x, init) + exp = fn(_fake_scan, x, init) + + self.assertEqual(out, exp) + self.assertEqual(compile_out, exp) + + 
@skipIfTorchDynamo("not a dynamo test") + def test_scan_in_vmap_unbatched_init_error(self): + # Test with various operations requiring shape reasoning + x = torch.randn(4, 5, 3, 2) + init = torch.randn(4, 3, 2) + weight = torch.randn(3, 3) + + def combine_fn(carry, xs): + # carry: (3, 2), xs: (3, 2) + intermediate = torch.nn.functional.relu(carry) + xs_t = xs.transpose(0, 1) # (2, 3) + result = xs_t @ weight # (2, 3) + new_carry = intermediate + result.transpose(0, 1) # Back to (3, 2) + output = torch.sin(carry).sum() + torch.cos(xs).mean() + return new_carry, output + + def vmap_fn(x, init): + def fn(x, init): + return scan(combine_fn, init, x) + + return torch.vmap(fn, in_dims=(0, None))(x, init) + + with self.assertRaisesRegex( + RuntimeError, + """The size of tensor a \\(4\\) must match the size of tensor b \\(2\\) at non-singleton dimension 4""", + ): + vmap_fn(x, init) + + @skipIfTorchDynamo("a vmap test, not a dynamo test") + def test_vmap_closure_weight_error(self): + init_batched = torch.randn(7, 2, 3) + xs_batched = torch.randn(7, 5, 4) + weight = torch.randn(7, 4, 3) + + def combine_fn(carry, xs): + # carry: (2, 3), xs: (4,), weight: (4, 3) + new_carry = carry + xs @ weight + output = carry.sum() + return new_carry, output + + def expected_fn(init, xs, weight): + def fn(init, xs, weight): + return _fake_scan(combine_fn, init, xs) + + return torch.vmap(fn, in_dims=(0, 0, 0))(init, xs, weight) + + # Note that even though weight is vmapped, combine_fn accesses + # the closure weight instead of the weight argument passed to fn, thus + # causing a shape mismatch. 
+ with self.assertRaisesRegex( + RuntimeError, + """The size of tensor a \\(2\\) must match the size of tensor b \\(7\\) at non-singleton dimension 1""", + ): + expected_fn(init_batched, xs_batched, weight) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_scan_in_vmap_mixed_batch_dims(self): + init = torch.randn(8, 5, 6) + xs_batched = torch.randn(3, 6, 5, 8) + scale = torch.randn([]) + + def combine_fn(carry, xs): + # carry: 8, 5 + # xs: 5, 8 + # new_carry: 8, 5 + new_carry = carry + (xs * scale).sum() + output = xs @ carry + return new_carry, output + + def fn(scan_op, init, xs): + def inner_fn(init, xs): + return scan_op(combine_fn, init, xs) + + return torch.vmap(inner_fn, in_dims=(2, 1))(init, xs) + + out = fn(scan, init, xs_batched) + compile_out = torch.compile(fn)(scan, init, xs_batched) + exp = fn(_fake_scan, init, xs_batched) + + self.assertEqual(out, exp) + self.assertEqual(compile_out, exp) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_vmap_scan_vmap_scan_nested(self): + # Outer batch: 3, inner batch: 4, outer scan: 5, inner scan: 6 + init = torch.randn(3, 4, 2, 8) + xs_outer = torch.randn(3, 5, 4, 6, 2) + + def fn(scan_op, init, xs): + def inner_combine(carry, xs): + # carry: (2, 8), xs: (2,) + new_carry = carry + xs.unsqueeze(-1) + output = carry.sum(dim=0) # (8,) + return new_carry, output + + def outer_combine(init, xs): + # carry: (4, 2, 8,), xs: (4, 6, 2) + # xs has batch dimension 4 from outer vmap + + def inner_fn(init, xs): + # init: (2, 8) + # xs: (6, 2) + # final_carry: (2, 8) + # outputs: (6, 8) + final_carry, outputs = scan_op(inner_combine, init, xs) + return (final_carry.sum(0, keepdim=True) + outputs).sum( + dim=0 + ) # (8,) + + inner_results = torch.vmap(inner_fn)(init, xs) # (4, 8) + new_carry = init + inner_results.mean(dim=0) # (8,) + output = inner_results.sum(dim=0) # (8,) + return new_carry.expand(*init.size()), output + + def vmap_inner_fn(init, xs): + # init: (4, 2, 8) + # xs: (5, 
4, 6, 2) + return scan_op(outer_combine, init, xs) + + return torch.vmap(vmap_inner_fn)(init, xs) + + out = fn(scan, init, xs_outer) + compile_out = torch.compile(fn)(scan, init, xs_outer) + exp = fn(_fake_scan, init, xs_outer) + + self.assertEqual(out, exp) + self.assertEqual(compile_out, exp) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_scan_vmap_scan_nested(self): + xs_outer = torch.randn(5, 3, 4, 2) + init_outer = torch.randn(3, 8) + + def fn(scan_op, init, xs): + def inner_combine_fake(carry, xs): + # carry: 8 + # xs: 2 + new_carry = carry + xs.sum() + output = carry * 2 + return new_carry, output + + def outer_combine_fake(carry, xs): + # carry: 3, 8 + # xs: 3, 4, 2 + def inner_fn(carry_elem, xs_elem): + # carry_elem: 8 + # xs: 4, 2 + # final_carry: 8 + # outputs.sum(0): 8 + final_carry, outputs = _fake_scan( + inner_combine_fake, carry_elem, xs_elem + ) + return outputs.sum(0), final_carry + + # result: (8,) + # next_carry, (3, 8)) + result, next_carry = torch.vmap(inner_fn, in_dims=(0, 0))(carry, xs) + output = result.sum(dim=0) + return next_carry, output + + return scan_op(outer_combine_fake, init, xs) + + out = fn(scan, init_outer, xs_outer) + compile_out = torch.compile(fn)(scan, init_outer, xs_outer) + exp = fn(_fake_scan, init_outer, xs_outer) + + self.assertEqual(out, exp) + self.assertEqual(compile_out, exp) + @skipIfTorchDynamo("Skip because we're testing export") @parametrize("strict", [True, False]) @parametrize("dynamic", [True, False]) diff --git a/torch/_functorch/predispatch.py b/torch/_functorch/predispatch.py index 44fbd5b632c..aca329be3eb 100644 --- a/torch/_functorch/predispatch.py +++ b/torch/_functorch/predispatch.py @@ -28,6 +28,7 @@ def _add_batch_dim(self, batch_dim, level): from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export mode = _maybe_find_pre_dispatch_tf_mode_for_export() + batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim if mode: return 
torch.overrides.handle_torch_function( diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index e3274991cb2..d9a36ebff14 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -357,7 +357,7 @@ def generic_scan(operator, init, xs, dim=0, additional_inputs=()): # Expand outs with None depending on the tensor mask of the output outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask] - return [*carry, *outs_expanded] + return (*carry, *outs_expanded) scans = _scan(init, xs) return scans @@ -866,6 +866,64 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs): return ctx.wrap_tensors(ret) +@scan_op.py_impl(torch._C._functorch.TransformType.Vmap) +def scan_batch_rule(interpreter, combine_fn, init, xs, additional_inputs): + from torch._functorch.vmap import restore_vmap, unwrap_batched, wrap_batched + + unbatched_args, in_dims = unwrap_batched( + (init, xs, additional_inputs), interpreter.level() + ) + # move to last dim to not interfere with scan's batching + unbatched_init, unbatched_xs, unbatched_additional_inputs = pytree.tree_map( + lambda x, bdim: x.movedim(bdim, -1) if bdim is not None else x, + unbatched_args, + in_dims, + ) + after_move_dims = tuple( + pytree.tree_flatten( + pytree.tree_map(lambda x: -1 if x is not None else None, in_dims) + )[0] + ) + + with interpreter.lower(): + out_dims = None + + def wrapper(*args): + nonlocal out_dims + outputs, per_slice_out_dims = restore_vmap( + combine_fn, + after_move_dims, + interpreter.batch_size(), + interpreter.randomness(), + )(*args) + # Note: outputs are not batched, we just move the batch dim to the end + # this is to avoid it interfering with scan's batching + outputs = tuple( + pytree.tree_map( + lambda out, out_bdim: out.movedim(out_bdim, -1) + if out_bdim is not None + else out, + outputs, + per_slice_out_dims, + ) + ) + out_dims = tuple( + pytree.tree_map( + lambda out_bdim: -1 if out_bdim is not None else 
None, + per_slice_out_dims, + ) + ) + return outputs + + unwrapped_out = scan_op( + wrapper, unbatched_init, unbatched_xs, unbatched_additional_inputs + ) + + assert out_dims is not None + batched_out = wrap_batched(unwrapped_out, out_dims, interpreter.level()) + return batched_out + + # dense implementation for scan. Used for testing only. def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): carry_leaves, carry_spec = pytree.tree_flatten(init)