[scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work, where two nested vmap(vmap) calls with chunk sizes > 1 produce a scan -> vmap -> scan -> vmap chain, so we need to handle both vmap(scan) and scan(vmap). We handle vmap(scan) by turning it into scan(vmap(combine_fn)): instead of running combine_fn on a single slice, each scan step vmaps over combine_fn and performs multiple combines at once. To do this we need to know how combine_fn propagates the batched tensors and what the batched dims of its outputs are; we use restore_vmap to obtain the out_dims information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
This commit is contained in:
parent 282f39a4bc
commit af4ba78543
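The rewrite described in the commit message can be illustrated with a small, self-contained sketch. The snippet below is not the code from this PR; it uses a tiny dense reference scan (in the spirit of _fake_scan) and the public torch.vmap API to show why vmapping over a scan is equivalent to a single scan whose combine_fn is itself vmapped over the batch dimension. The names ref_scan, combine_fn, init, and xs are illustrative only.

import torch

def ref_scan(combine_fn, init, xs):
    # dense reference scan over the leading dim of xs:
    # combine_fn(carry, x) -> (next_carry, y); ys are stacked along dim 0
    carry, ys = init, []
    for x in xs.unbind(0):
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)

def combine_fn(carry, x):
    return carry + x, (carry * x).sum(-1)

init = torch.randn(7, 3)   # batch of 7 carries, each of shape (3,)
xs = torch.randn(7, 5, 3)  # batch of 7 sequences, scan length 5

# vmap over a per-example scan ...
batched = torch.vmap(lambda i, x: ref_scan(combine_fn, i, x))(init, xs)

# ... equals one scan whose combine_fn is vmapped over the batch dim.
# The scan dim (length 5) is moved to the front; the batch dim rides along.
single = ref_scan(torch.vmap(combine_fn), init, xs.movedim(1, 0))

torch.testing.assert_close(batched[0], single[0])                # final carries
torch.testing.assert_close(batched[1], single[1].movedim(0, 1))  # stacked outputs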
@@ -8087,6 +8087,260 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
        self.assertEqual(eager_out, exp_out)
        self.assertEqual(compiled_out, exp_out)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_simple(self):
        x = torch.randn(3, 4, 4)
        y = torch.randn(4, 2)
        zeros = torch.zeros(2, 3)

        def combine_fn(init, xs):
            return init.clone(), xs @ y

        def fn(scan_op, x, y):
            def inner_fn(zeros, x, y):
                x = x.view(2, 2, 4)

                return scan_op(
                    combine_fn,
                    zeros,
                    x,
                )

            return torch.vmap(inner_fn, in_dims=(1, 0, None))(zeros, x, y)

        out = fn(scan, x, y)
        compile_out = torch.compile(fn)(scan, x, y)
        exp = fn(_fake_scan, x, y)
        self.assertEqual(out, exp)
        self.assertEqual(out, compile_out)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_complex_ops(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(4, 5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def fn(scan_op, x, init):
            def inner_fn(x, init):
                return scan_op(combine_fn, init, x)

            return torch.vmap(inner_fn, in_dims=(0, 0))(x, init)

        out = fn(scan, x, init)
        compile_out = torch.compile(fn)(scan, x, init)
        exp = fn(_fake_scan, x, init)

        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_unbatched_x(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def fn(scan_op, x, init):
            def inner_fn(x, init):
                return scan_op(combine_fn, init, x)

            return torch.vmap(inner_fn, in_dims=(None, 0))(x, init)

        out = fn(scan, x, init)
        compile_out = torch.compile(fn)(scan, x, init)
        exp = fn(_fake_scan, x, init)

        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)

    @skipIfTorchDynamo("not a dynamo test")
    def test_scan_in_vmap_unbatched_init_error(self):
        # Test with various operations requiring shape reasoning
        x = torch.randn(4, 5, 3, 2)
        init = torch.randn(4, 3, 2)
        weight = torch.randn(3, 3)

        def combine_fn(carry, xs):
            # carry: (3, 2), xs: (3, 2)
            intermediate = torch.nn.functional.relu(carry)
            xs_t = xs.transpose(0, 1)  # (2, 3)
            result = xs_t @ weight  # (2, 3)
            new_carry = intermediate + result.transpose(0, 1)  # Back to (3, 2)
            output = torch.sin(carry).sum() + torch.cos(xs).mean()
            return new_carry, output

        def vmap_fn(x, init):
            def fn(x, init):
                return scan(combine_fn, init, x)

            return torch.vmap(fn, in_dims=(0, None))(x, init)

        with self.assertRaisesRegex(
            RuntimeError,
            """The size of tensor a \\(4\\) must match the size of tensor b \\(2\\) at non-singleton dimension 4""",
        ):
            vmap_fn(x, init)

    @skipIfTorchDynamo("a vmap test, not a dynamo test")
    def test_vmap_closure_weight_error(self):
        init_batched = torch.randn(7, 2, 3)
        xs_batched = torch.randn(7, 5, 4)
        weight = torch.randn(7, 4, 3)

        def combine_fn(carry, xs):
            # carry: (2, 3), xs: (4,), weight: (4, 3)
            new_carry = carry + xs @ weight
            output = carry.sum()
            return new_carry, output

        def expected_fn(init, xs, weight):
            def fn(init, xs, weight):
                return _fake_scan(combine_fn, init, xs)

            return torch.vmap(fn, in_dims=(0, 0, 0))(init, xs, weight)

        # Note that even though weight is vmapped, combine_fn accesses the
        # closure weight instead of the per-slice weight passed to fn, thus
        # causing a shape mismatch.
        with self.assertRaisesRegex(
            RuntimeError,
            """The size of tensor a \\(2\\) must match the size of tensor b \\(7\\) at non-singleton dimension 1""",
        ):
            expected_fn(init_batched, xs_batched, weight)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_in_vmap_mixed_batch_dims(self):
        init = torch.randn(8, 5, 6)
        xs_batched = torch.randn(3, 6, 5, 8)
        scale = torch.randn([])

        def combine_fn(carry, xs):
            # carry: 8, 5
            # xs: 5, 8
            # new_carry: 8, 5
            new_carry = carry + (xs * scale).sum()
            output = xs @ carry
            return new_carry, output

        def fn(scan_op, init, xs):
            def inner_fn(init, xs):
                return scan_op(combine_fn, init, xs)

            return torch.vmap(inner_fn, in_dims=(2, 1))(init, xs)

        out = fn(scan, init, xs_batched)
        compile_out = torch.compile(fn)(scan, init, xs_batched)
        exp = fn(_fake_scan, init, xs_batched)

        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_vmap_scan_vmap_scan_nested(self):
        # Outer batch: 3, inner batch: 4, outer scan: 5, inner scan: 6
        init = torch.randn(3, 4, 2, 8)
        xs_outer = torch.randn(3, 5, 4, 6, 2)

        def fn(scan_op, init, xs):
            def inner_combine(carry, xs):
                # carry: (2, 8), xs: (2,)
                new_carry = carry + xs.unsqueeze(-1)
                output = carry.sum(dim=0)  # (8,)
                return new_carry, output

            def outer_combine(init, xs):
                # carry: (4, 2, 8), xs: (4, 6, 2)
                # xs has batch dimension 4 from outer vmap

                def inner_fn(init, xs):
                    # init: (2, 8)
                    # xs: (6, 2)
                    # final_carry: (2, 8)
                    # outputs: (6, 8)
                    final_carry, outputs = scan_op(inner_combine, init, xs)
                    return (final_carry.sum(0, keepdim=True) + outputs).sum(
                        dim=0
                    )  # (8,)

                inner_results = torch.vmap(inner_fn)(init, xs)  # (4, 8)
                new_carry = init + inner_results.mean(dim=0)  # (8,)
                output = inner_results.sum(dim=0)  # (8,)
                return new_carry.expand(*init.size()), output

            def vmap_inner_fn(init, xs):
                # init: (4, 2, 8)
                # xs: (5, 4, 6, 2)
                return scan_op(outer_combine, init, xs)

            return torch.vmap(vmap_inner_fn)(init, xs)

        out = fn(scan, init, xs_outer)
        compile_out = torch.compile(fn)(scan, init, xs_outer)
        exp = fn(_fake_scan, init, xs_outer)

        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_scan_vmap_scan_nested(self):
        xs_outer = torch.randn(5, 3, 4, 2)
        init_outer = torch.randn(3, 8)

        def fn(scan_op, init, xs):
            def inner_combine_fake(carry, xs):
                # carry: 8
                # xs: 2
                new_carry = carry + xs.sum()
                output = carry * 2
                return new_carry, output

            def outer_combine_fake(carry, xs):
                # carry: 3, 8
                # xs: 3, 4, 2
                def inner_fn(carry_elem, xs_elem):
                    # carry_elem: 8
                    # xs: 4, 2
                    # final_carry: 8
                    # outputs.sum(0): 8
                    final_carry, outputs = _fake_scan(
                        inner_combine_fake, carry_elem, xs_elem
                    )
                    return outputs.sum(0), final_carry

                # result: (8,)
                # next_carry: (3, 8)
                result, next_carry = torch.vmap(inner_fn, in_dims=(0, 0))(carry, xs)
                output = result.sum(dim=0)
                return next_carry, output

            return scan_op(outer_combine_fake, init, xs)

        out = fn(scan, init_outer, xs_outer)
        compile_out = torch.compile(fn)(scan, init_outer, xs_outer)
        exp = fn(_fake_scan, init_outer, xs_outer)

        self.assertEqual(out, exp)
        self.assertEqual(compile_out, exp)

    @skipIfTorchDynamo("Skip because we're testing export")
    @parametrize("strict", [True, False])
    @parametrize("dynamic", [True, False])

@@ -28,6 +28,7 @@ def _add_batch_dim(self, batch_dim, level):
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()
    batch_dim = self.ndim + batch_dim if batch_dim < 0 else batch_dim

    if mode:
        return torch.overrides.handle_torch_function(

@@ -357,7 +357,7 @@ def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
        # Expand outs with None depending on the tensor mask of the output
        outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask]

-        return [*carry, *outs_expanded]
+        return (*carry, *outs_expanded)

    scans = _scan(init, xs)
    return scans

@@ -866,6 +866,64 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
    return ctx.wrap_tensors(ret)


@scan_op.py_impl(torch._C._functorch.TransformType.Vmap)
def scan_batch_rule(interpreter, combine_fn, init, xs, additional_inputs):
    from torch._functorch.vmap import restore_vmap, unwrap_batched, wrap_batched

    unbatched_args, in_dims = unwrap_batched(
        (init, xs, additional_inputs), interpreter.level()
    )
    # move the batch dim to the last dim to not interfere with scan's batching
    unbatched_init, unbatched_xs, unbatched_additional_inputs = pytree.tree_map(
        lambda x, bdim: x.movedim(bdim, -1) if bdim is not None else x,
        unbatched_args,
        in_dims,
    )
    after_move_dims = tuple(
        pytree.tree_flatten(
            pytree.tree_map(lambda x: -1 if x is not None else None, in_dims)
        )[0]
    )

    with interpreter.lower():
        out_dims = None

        def wrapper(*args):
            nonlocal out_dims
            outputs, per_slice_out_dims = restore_vmap(
                combine_fn,
                after_move_dims,
                interpreter.batch_size(),
                interpreter.randomness(),
            )(*args)
            # Note: outputs are not batched, we just move the batch dim to the end;
            # this is to avoid it interfering with scan's batching
            outputs = tuple(
                pytree.tree_map(
                    lambda out, out_bdim: out.movedim(out_bdim, -1)
                    if out_bdim is not None
                    else out,
                    outputs,
                    per_slice_out_dims,
                )
            )
            out_dims = tuple(
                pytree.tree_map(
                    lambda out_bdim: -1 if out_bdim is not None else None,
                    per_slice_out_dims,
                )
            )
            return outputs

        unwrapped_out = scan_op(
            wrapper, unbatched_init, unbatched_xs, unbatched_additional_inputs
        )

    assert out_dims is not None
    batched_out = wrap_batched(unwrapped_out, out_dims, interpreter.level())
    return batched_out


# dense implementation for scan. Used for testing only.
def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False):
    carry_leaves, carry_spec = pytree.tree_flatten(init)