[while_loop][autograd] add hop while_loop_stack_output (#160467)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160467
Approved by: https://github.com/zou3519
ghstack dependencies: #160548
Yidi Wu 2025-09-05 15:39:00 -07:00 committed by PyTorch MergeBot
parent 5927a70934
commit 48e3be3ab6
10 changed files with 604 additions and 282 deletions
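
For quick orientation, here is a minimal, hedged usage sketch of the new HOP, adapted from the Dynamo test added in this commit (test_compile_while_loop_stack_output); the module, shapes, and backend choice are illustrative only.

import torch

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        c = torch.tensor(0, dtype=torch.int64)

        def cond_fn(c, x):
            return c < x.size(0)

        def body_fn(c, x):
            return c + 1, self.linear(x)

        # Returns one stacked tensor per carried input, with a new leading
        # dimension equal to the number of iterations (3 here).
        return torch.ops.higher_order.while_loop_stack_output(
            cond_fn, body_fn, (c, x), tuple()
        )

stacked_c, stacked_x = torch.compile(Mod(), backend="eager")(torch.randn(3, 3))
assert stacked_c.shape == (3,) and stacked_x.shape == (3, 3, 3)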

View File

@ -8239,6 +8239,36 @@ class GraphModule(torch.nn.Module):
""", # noqa: B950
)
@parametrize("dynamic", [True, False])
@parametrize("backend", ["eager", "aot_eager"])
def test_compile_while_loop_stack_output(self, dynamic, backend):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
c = torch.tensor(0, dtype=torch.int64)
def cond_fn(c, x):
return c < x.size(0)
def body_fn(c, x):
return c + 1, self.linear(x)
stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
cond_fn, body_fn, (c, x), tuple()
)
return stacked_c, stacked_x
x = torch.randn(3, 3)
mod = Mod()
compiled_out = torch.compile(mod, backend=backend, dynamic=dynamic)(x)
self.assertEqual(len(compiled_out), 2)
self.assertEqual(compiled_out[0].size(0), 3)
self.assertEqual(compiled_out[1].size(0), 3)
self.assertEqual(compiled_out, mod(x))
def test_input_output_alias(self):
def fn(f, *args):
return torch.cond(args[0].sum() > 0, f, f, args)

View File

@ -1088,6 +1088,23 @@ class WhileLoopModels:
(c, x),
)
class WhileLoopStackOutputSimple(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.linear = torch.nn.Linear(3, 3, device=device)
def forward(self, c, x):
def cond_fn(c, x):
return c < x.size(0)
def body_fn(c, x):
return c + 1, self.linear(x)
stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
cond_fn, body_fn, (c, x), tuple()
)
return stacked_c, stacked_x
class WhileLoopTests(TestCase):
def _run_test(
@ -1407,6 +1424,17 @@ class WhileLoopTests(TestCase):
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
def test_while_loop_stack_output_simple(self, device, dynamic):
self._run_test(
model=WhileLoopModels.WhileLoopStackOutputSimple(device),
inputs=(torch.randn(3, 3, dtype=torch.float32),),
device=device,
dynamic=dynamic,
)
class AssociativeScanTests(TestCase):
@requires_gpu

View File

@ -312,6 +312,246 @@ def _check_supported_callable_arg(
)
def _call_while_loop(
self: VariableTracker,
tx: "InstructionTranslator",
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
stack_output: bool,
) -> VariableTracker:
from torch._higher_order_ops.while_loop import _create_unbacked_symint
from . import TensorVariable
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
cond_fn, body_fn, operands, additional_inputs = args
# Input checks
for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
if v := kwargs.pop(k, None):
assert i == len(args), (
"did not provide the right number of non-keyword args"
)
args.append(v)
if kwargs:
unimplemented(f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}")
if len(args) != 4:
unimplemented(
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: while_loop(cond_fn, body_fn, operands)",
)
# cond_fn and body_fn input check
_check_supported_callable_arg(tx, cond_fn, "cond_fn")
_check_supported_callable_arg(tx, body_fn, "body_fn")
# operands input check
operands_seq = operands.unpack_var_sequence(tx)
# additional_inputs input check
if not isinstance(additional_inputs, (ListVariable, TupleVariable)):
unimplemented(
f"Expected additional_inputs to be a list/tuple but got "
f"{additional_inputs.python_type()}. It seems to be an "
f"internal error, please report an issue to PyTorch."
)
additional_inputs_seq = additional_inputs.unpack_var_sequence(tx)
with discard_graph_changes(tx):
# Note: this must be run under discard graph changes.
def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
# See NOTE [unspecialize int carry with unbacked symints]
if (
isinstance(carry, ConstantVariable) and carry.python_type() is int
) or isinstance(carry, SymNodeVariable):
example_value = _create_unbacked_symint(
tx.output.fake_mode, ignore_fresh_unbacked_symbols=True
)
proxy = tx.output.current_tracer.create_graph_input(
"unbacked_symint", type(example_value), example_value
)
return SymNodeVariable.create(tx, proxy, example_value)
else:
# See NOTE [unspecialize constant tensor carry]
assert isinstance(carry, TensorVariable)
cloned_carry = carry.clone()
cloned_carry.proxy.node.meta["example_value"].constant = None
return cloned_carry
# clone inputs across subgraphs, to avoid unbacked memoization in fake prop
cond_operands_seq = [
unspecialize_carried_inputs(
tx,
(
carry.call_method(tx, "clone", args=(), kwargs={})
if isinstance(carry, TensorVariable)
else carry
),
)
for carry in operands_seq
]
body_operands_seq = [
unspecialize_carried_inputs(
tx,
(
carry.call_method(tx, "clone", args=(), kwargs={})
if isinstance(carry, TensorVariable)
else carry
),
)
for carry in operands_seq
]
# create cond subgraph
(
(cond_r, _cond_treespec),
cond_graph,
cond_lifted_freevars,
) = speculate_subgraph(
tx,
cond_fn,
cond_operands_seq + additional_inputs_seq,
{},
"while_loop",
source_target=self.value,
# NOTE [why we cannot use "automatic" for while_loop]:
# The reason is that we want to enforce the ordering of inputs and outputs
# to be consistent, and the ordering of cond_fn's and body_fn's inputs to be
# consistent.
# e.g. suppose we use "automatic" and we have:
#
# def body_fn(ph1, ph2):
# new_a, new_b = ph2.cos(), ph1.sin()
# return new_a, new_b
#
# a, b = torch.randn(3), torch.randn(3)
# new_a, new_b = body_fn(a, b)
#
# Using automatic, the ordering of arguments will be the order that they're
# used. In this example, the capture graph looks like:
#
# def captured_body(ph1, ph2):
# new_a, new_b = ph1.cos(), ph2.sin()
# return new_a, new_b
#
# This is fine when we change the calling convention of captured_body to be
# new_a, new_b = captured_body(b, a).
# But for while_loop, the next iteration's input is the previous iteration's
# output, so we'd end up feeding captured_body(new_a, new_b) instead.
# So it's best we always enforce the ordering of carried_inputs the same as outputs
# with "flatten_manual".
set_subgraph_inputs="flatten_manual",
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
remove_consts_from_outputs=False,
)
cond_nn_modules = dict(tx.output.nn_modules)
validate_subgraph_output_types(cond_r)
if isinstance(cond_r, TensorVariable):
cond_r_meta = _extract_tensor_metadata(
cond_r.proxy.node.meta["example_value"], include_contiguity=False
)
if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
[]
):
unimplemented(
f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
)
elif isinstance(cond_r, ConstantVariable):
# short-circuiting while_loop when cond_fn returns a constant such as 0, 1, True or False
pred = cond_r.as_python_constant()
if pred:
unimplemented(
f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}"
)
else:
return operands
# create body subgraph
(
(body_r, body_treespec),
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
body_fn,
body_operands_seq + additional_inputs_seq,
{},
"while_loop",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
supports_input_mutation=False,
supports_aliasing=False,
remove_consts_from_outputs=False,
)
validate_subgraph_output_types(body_r)
# We set include_contiguity=False because we have vmap x HOP tests, where
# include_contiguity=True would call t.is_contiguous inside of vmap and hit the error
# "querying is_contiguous inside of vmap for memory_format other than
# torch.contiguous_format is not yet implemented". This is okay because stride
# is still checked.
check_meta_consistency_vt(
body_r.unpack_var_sequence(tx),
operands_seq,
"body_fn_output",
"carried_inputs",
include_contiguity=False,
)
(
cond_graph,
body_graph,
cond_shared,
_body_shared,
cond_unique,
body_unique,
) = _merge_graph_inputs(
cond_graph,
cond_lifted_freevars,
"cond_fn",
body_graph,
body_lifted_freevars,
"body_fn",
)
# Note: cond_shared and body_shared refer to the same proxy in parent graph
# so using either of them is OK. Use cond_shared as it doesn't matter.
additional_lifted_inputs = cond_shared + cond_unique + body_unique
body_nn_modules = dict(tx.output.nn_modules)
cond_gm = torch.fx.GraphModule(cond_nn_modules, cond_graph)
body_gm = torch.fx.GraphModule(body_nn_modules, body_graph)
cond_name = tx.output.install_subgraph("cond_fn", cond_gm)
body_name = tx.output.install_subgraph("body_fn", body_gm)
cond_node = make_attr(tx, cond_name)
body_node = make_attr(tx, body_name)
operands_proxy = tuple(operand.as_proxy() for operand in operands_seq)
additional_inputs_proxy = tuple(
[inp.as_proxy() for inp in additional_inputs_seq] + additional_lifted_inputs
)
p_args = (
cond_node,
body_node,
operands_proxy,
additional_inputs_proxy,
)
return _call_function_and_unflatten_output(
tx,
self.value,
p_args,
{},
None,
body_treespec,
)
def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode):
from torch._subclasses._fake_tensor_utils import _CacheKeyState
from torch._subclasses.fake_tensor import extract_tensor_metadata
@ -1280,243 +1520,23 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from torch._higher_order_ops.while_loop import _create_unbacked_symint
return _call_while_loop(self, tx, args, kwargs, stack_output=False)
from . import TensorVariable
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
cond_fn, body_fn, operands, additional_inputs = args
class WhileLoopStackOutputHigherOrderVariable(TorchHigherOrderOperatorVariable):
supports_input_mutation = False
supports_aliasing = False
# Input checks
for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
if v := kwargs.pop(k, None):
assert i == len(args), (
"did not provide the right number of non-keyword args"
)
args.append(v)
if kwargs:
unimplemented(
f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}"
)
if len(args) != 4:
unimplemented(
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: while_loop(cond_fn, body_fn, operands)",
)
# cond_fn and body_fn input check
_check_supported_callable_arg(tx, cond_fn, "cond_fn")
_check_supported_callable_arg(tx, body_fn, "body_fn")
# operands input check
operands_seq = operands.unpack_var_sequence(tx)
# additional_inputs input check
if not isinstance(additional_inputs, (ListVariable, TupleVariable)):
unimplemented(
f"Expected additional_inputs to be a list/tuple but got "
f"{additional_inputs.python_type()}. It seems to be an "
f"internal error, please report an issue to PyTorch."
)
additional_inputs_seq = additional_inputs.unpack_var_sequence(tx)
with discard_graph_changes(tx):
# Note: this must be run under discard graph changes.
def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
# See NOTE [unspecialize int carry with unbacked symints]
if (
isinstance(carry, ConstantVariable) and carry.python_type() is int
) or isinstance(carry, SymNodeVariable):
example_value = _create_unbacked_symint(
tx.output.fake_mode, ignore_fresh_unbacked_symbols=True
)
proxy = tx.output.current_tracer.create_graph_input(
"unbacked_symint", type(example_value), example_value
)
return SymNodeVariable.create(tx, proxy, example_value)
else:
# See NOTE [unspecialize constant tensor carry]
assert isinstance(carry, TensorVariable)
cloned_carry = carry.clone()
cloned_carry.proxy.node.meta["example_value"].constant = None
return cloned_carry
# clone inputs across subgraphs, to avoid unbacked memoization in fake prop
cond_operands_seq = [
unspecialize_carried_inputs(
tx,
(
carry.call_method(tx, "clone", args=(), kwargs={})
if isinstance(carry, TensorVariable)
else carry
),
)
for carry in operands_seq
]
body_operands_seq = [
unspecialize_carried_inputs(
tx,
(
carry.call_method(tx, "clone", args=(), kwargs={})
if isinstance(carry, TensorVariable)
else carry
),
)
for carry in operands_seq
]
# create cond subgraph
(
(cond_r, _cond_spec),
cond_graph,
cond_lifted_freevars,
) = speculate_subgraph(
tx,
cond_fn,
cond_operands_seq + additional_inputs_seq,
{},
"while_loop",
source_target=self.value,
# NOTE [why we cannot use "automatic" for while_loop]:
# The reason is that we want to enforce the ordering of inputs and outputs
# to be consistent, and the ordering of cond_fn's and body_fn's inputs to be
# consistent.
# e.g. suppose we use "automatic" and we have:
#
# def body_fn(ph1, ph2):
# new_a, new_b = ph2.cos(), ph1.sin()
# return new_a, new_b
#
# a, b = torch.randn(3), torch.randn(3)
# new_a, new_b = body_fn(a, b)
#
# Using automatic, the ordering of arguments will be the order that they're
# used. In this example, the capture graph looks like:
#
# def captured_body(ph1, ph2):
# new_a, new_b = ph1.cos(), ph2.sin()
# return new_a, new_b
#
# This is fine when we change the calling convention of captured_body to be
# new_a, new_b = captured_body(b, a).
# But for while_loop, the next iteration's input is the previous iteration's
# output, so we'd end up feeding captured_body(new_a, new_b) instead.
# So it's best we always enforce the ordering of carried_inputs the same as outputs
# with "flatten_manual".
set_subgraph_inputs="flatten_manual",
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
cond_nn_modules = dict(tx.output.nn_modules)
validate_subgraph_output_types(cond_r)
if isinstance(cond_r, TensorVariable):
cond_r_meta = _extract_tensor_metadata(
cond_r.proxy.node.meta["example_value"], include_contiguity=False
)
if (
not cond_r_meta.dtype == torch.bool
or not cond_r_meta.shape == torch.Size([])
):
unimplemented(
f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
)
elif isinstance(cond_r, ConstantVariable):
# short-circuiting while_loop when cond_fn returns a constant such as 0, 1, True or False
pred = cond_r.as_python_constant()
if pred:
unimplemented(
f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}"
)
else:
return operands
# create body subgraph
(
(body_r, body_spec),
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
body_fn,
body_operands_seq + additional_inputs_seq,
{},
"while_loop",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=False,
supports_aliasing=False,
)
validate_subgraph_output_types(body_r)
# We set include_contiguity=False because we have vmap x HOP tests, where
# include_contiguity=True would call t.is_contiguous inside of vmap and hit the error
# "querying is_contiguous inside of vmap for memory_format other than
# torch.contiguous_format is not yet implemented". This is okay because stride
# is still checked.
check_meta_consistency_vt(
body_r.unpack_var_sequence(tx),
operands_seq,
"body_fn_output",
"carried_inputs",
include_contiguity=False,
)
(
cond_graph,
body_graph,
cond_shared,
_body_shared,
cond_unique,
body_unique,
) = _merge_graph_inputs(
cond_graph,
cond_lifted_freevars,
"cond_fn",
body_graph,
body_lifted_freevars,
"body_fn",
)
# Note: cond_shared and body_shared refer to the same proxy in parent graph
# so using either of them is OK. Use cond_shared as it doesn't matter.
additional_lifted_inputs = cond_shared + cond_unique + body_unique
body_nn_modules = dict(tx.output.nn_modules)
cond_name = tx.output.install_subgraph(
"cond_fn",
torch.fx.GraphModule(cond_nn_modules, cond_graph),
)
body_name = tx.output.install_subgraph(
"body_fn",
torch.fx.GraphModule(body_nn_modules, body_graph),
)
cond_node = make_attr(tx, cond_name)
body_node = make_attr(tx, body_name)
p_args = (
cond_node,
body_node,
tuple([operand.as_proxy() for operand in operands_seq]),
tuple(
[inp.as_proxy() for inp in additional_inputs_seq]
+ additional_lifted_inputs
),
)
return _call_function_and_unflatten_output(
tx,
torch.ops.higher_order.while_loop,
p_args,
{},
None,
body_spec,
)
@raise_hard_error_if_graph_break(
reason="while_loop_stack_output doesn't work unless it is captured completely with torch.compile."
)
def _call_function(
self,
tx: "InstructionTranslator",
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
return _call_while_loop(self, tx, args, kwargs, stack_output=True)
class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
@ -3481,6 +3501,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
_hop_name_to_variable_class = {
"cond": CondHigherOrderVariable,
"while_loop": WhileLoopHigherOrderVariable,
"while_loop_stack_output": WhileLoopStackOutputHigherOrderVariable,
"map_impl": MapHigherOrderVariable,
"executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable,
"out_dtype": OutDtypeHigherOrderVariable,

View File

@ -27,7 +27,10 @@ from torch._higher_order_ops.run_const_graph import run_const_graph
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.torchbind import call_torchbind
from torch._higher_order_ops.while_loop import while_loop
from torch._higher_order_ops.while_loop import (
while_loop,
while_loop_stack_output_op as while_loop_stack_output,
)
from torch._higher_order_ops.wrap import (
dynamo_bypassing_wrapper,
tag_activation_checkpoint,
@ -69,4 +72,5 @@ __all__ = [
"strict_mode",
"aoti_call_delegate",
"map",
"while_loop_stack_output",
]

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Any, Callable, Union
import torch
@ -260,7 +261,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_dense(
cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
carried_vals = carried_inputs
def _validate_cond_output(pred):
@ -285,13 +288,25 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
_validate_cond_output(should_loop)
if not should_loop:
return tuple(
val.clone() if isinstance(val, torch.Tensor) else val
for val in carried_vals + additional_inputs
)
if stack_output:
return tuple(
val.unsqueeze(0).clone() if isinstance(val, torch.Tensor) else val
for val in carried_vals
)
else:
return tuple(
val.clone() if isinstance(val, torch.Tensor) else val
for val in carried_vals
)
outputs: list[list[torch.Tensor]] = [[] for _ in carried_vals]
while should_loop:
out = body_fn(*carried_vals, *additional_inputs)
if stack_output:
for i, o in enumerate(out):
outputs[i].append(o)
assert isinstance(out, tuple), (
f"body_fn should return a tuple but got {type(out)}"
)
@ -302,6 +317,12 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
should_loop = cond_fn(*carried_vals, *additional_inputs)
if stack_output:
outs: list[torch.Tensor] = []
for i, out in enumerate(outputs):
outs.append(torch.stack(out, dim=0))
return tuple(outs)
return carried_vals
@ -336,9 +357,18 @@ def _create_unbacked_symint(
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_tracing(
mode,
cond_fn,
body_fn,
carried_inputs,
additional_inputs,
stack_output=False,
):
op = while_loop_stack_output_op if stack_output else while_loop_op
def _trace_while_loop(
proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
proxy_mode, op, cond_fn, body_fn, carried_inputs, additional_inputs
):
# NOTE [unspecialize int carry with unbacked symints]
# When we support int carry, we'll also need to support int output of body_fn because.
@ -437,10 +467,10 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", while_loop_op, proxy_args, {}, name="while_loop"
"call_function", op, proxy_args, {}, name=op._name
)
out = while_loop_op(
out = op(
cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
)
return track_tensor_tree(
@ -448,13 +478,18 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
)
return _trace_while_loop(
mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
mode,
op,
cond_fn,
body_fn,
carried_inputs,
additional_inputs,
)
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
mode, cond_fn, body_fn, carried_inputs, additional_inputs
mode, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
with mode:
# NOTE: [Handling unback symints in subgraph of while_loop]
@ -499,6 +534,26 @@ def while_loop_fake_tensor_mode(
"body_output",
include_contiguity=False,
)
if stack_output:
n_iter = _create_unbacked_symint(mode, ignore_fresh_unbacked_symbols=False)
assert all(isinstance(x, torch.Tensor) for x in carried_inputs)
fake_outputs = tuple(
out.clone()
.unsqueeze(0)
.repeat((n_iter,) + tuple(1 for _ in range(out.dim())))
for out in body_outs
)
return pytree.tree_map_only(
(int, torch.SymInt),
# For while_loop's unbacked symint output, we want them to be bound
# to the proxy of while_loop's output.
lambda _: _create_unbacked_symint(
mode, ignore_fresh_unbacked_symbols=False
),
fake_outputs,
)
# See NOTE [unspecialize int carry with unbacked symints]
return pytree.tree_map_only(
(int, torch.SymInt),
@ -512,9 +567,13 @@ def while_loop_fake_tensor_mode(
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_func(
ctx, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
from torch._higher_order_ops.utils import _check_alias_and_mutation
op = while_loop_stack_output_op if stack_output else while_loop_op
unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
@ -527,10 +586,72 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
(body_fn, "body_fn"),
]:
_check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
ret = while_loop_op(
ret = op(
functional_cond_fn,
functional_body_fn,
unwrapped_carried_inputs,
unwrapped_additional_inputs,
)
return ctx.wrap_tensors(ret)
class WhileLoopStackOutputOp(HigherOrderOperator):
"""
while_loop_stack_output is a variant of while_loop that returns a stack of outputs.
Its semantics can be illustrated with Python code as:
def while_loop_stack_output(cond_fn, body_fn, carried_inputs, additional_inputs):
outs = []
while cond_fn(*carried_inputs, *additional_inputs):
carried_inputs = body_fn(*carried_inputs, *additional_inputs)
outs.append(carried_inputs)
return tuple(torch.stack(out) for out in zip(*outs))
It's useful for supporting autograd of while_loop.
"""
def __init__(self) -> None:
super().__init__("while_loop_stack_output")
def __call__(
self,
cond_fn: Callable,
body_fn: Callable,
carried_inputs: tuple[Union[torch.Tensor, int, float, bool]],
additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
/,
):
if not isinstance(carried_inputs, (tuple, list)):
raise RuntimeError(
f"carried_inputs must be a tuple or list, got {type(carried_inputs)}"
)
if not isinstance(additional_inputs, (tuple, list)):
raise RuntimeError(
f"additional_inputs must be a tuple or list, got {type(additional_inputs)}"
)
validate_subgraph_args_types(carried_inputs)
validate_subgraph_args_types(additional_inputs)
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
while_loop_stack_output_op = WhileLoopStackOutputOp()
while_loop_stack_output_op.py_impl(DispatchKey.CompositeExplicitAutograd)(
functools.partial(while_loop_dense, stack_output=True)
)
while_loop_stack_output_op.py_impl(ProxyTorchDispatchMode)(
functools.partial(while_loop_tracing, stack_output=True)
)
while_loop_stack_output_op.py_impl(FakeTensorMode)(
functools.partial(while_loop_fake_tensor_mode, stack_output=True)
)
while_loop_stack_output_op.py_functionalize_impl(
functools.partial(while_loop_func, stack_output=True)
)
while_loop_stack_output_op.py_autograd_impl(
autograd_not_implemented(while_loop_stack_output_op, deferred_error=True)
)
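
As a sanity check on the dense (eager) semantics registered above, the following sketch (not part of this commit) compares the op against a hand-written reference loop; it relies only on the torch._higher_order_ops.while_loop_stack_output export added in this PR, and the tensor values are arbitrary.

import torch
from torch._higher_order_ops import while_loop_stack_output

def reference(cond_fn, body_fn, carried, additional):
    # Reference loop mirroring the class docstring: run body_fn until cond_fn
    # is false, then stack each carried output along a new leading dimension.
    outs = []
    while cond_fn(*carried, *additional):
        carried = body_fn(*carried, *additional)
        outs.append(carried)
    return tuple(torch.stack(o) for o in zip(*outs))

def cond_fn(it, x):
    return it > 0

def body_fn(it, x):
    return it - 1, x.cos()

args = (torch.tensor(3), torch.randn(4))
expected = reference(cond_fn, body_fn, args, ())
actual = while_loop_stack_output(cond_fn, body_fn, args, ())
for e, a in zip(expected, actual):
    torch.testing.assert_close(e, a)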

View File

@ -1947,7 +1947,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
finally:
self.pop_codegened_graph()
def codegen_while_loop(self, while_loop):
def codegen_while_loop(self, while_loop, stack_output=False):
if stack_output:
raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output")
is_bool_pred = isinstance(
while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
)

View File

@ -3370,7 +3370,9 @@ class PythonWrapperCodegen(CodeGen):
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
self.writeline(ExitSubgraphLine(self))
def codegen_while_loop(self, while_loop):
def codegen_while_loop(self, while_loop, stack_output):
"""while_loop is codegened as a host side while_loop"""
def codegen_subgraph(subgraph, outer_inputs, outer_outputs):
"""Helper method to deduplicate subgraph codegen logic"""
if V.graph.aot_mode:
@ -3388,7 +3390,13 @@ class PythonWrapperCodegen(CodeGen):
buf.codegen_reference() for buf in while_loop.additional_inputs
]
ckp_offset = len(outer_carried_inputs)
self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
if stack_output:
self.writeline(
f"{name}.extend([[] for _ in range({len(outer_carried_inputs)})])"
)
for i, inp in enumerate(outer_carried_inputs):
# set the initial state before the loop
self.writeline(f"{name}[{i}] = {inp}")
@ -3411,10 +3419,21 @@ class PythonWrapperCodegen(CodeGen):
)
self.writeline(f"should_loop = {cond_outer_outputs[0]}")
self.writeline("if not should_loop:")
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(f" {name}[{i}] = {carried_input}.clone()")
if stack_output:
# Handle the case when loop never executes
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
self.writeline(ExitSubgraphLine(self))
else:
for i, (carried_input, carried_buf) in enumerate(
zip(outer_carried_inputs, while_loop.carried_inputs)
):
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(f"{name}[{i}] = {carried_input}.clone()")
self.writeline(ExitSubgraphLine(self))
self.writeline("while should_loop:")
# Body execution
@ -3424,6 +3443,13 @@ class PythonWrapperCodegen(CodeGen):
)
self.writeline(ExitSubgraphLine(self))
# Collect outputs if enabled
if stack_output:
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
for i in range(len(outer_carried_inputs)):
self.writeline(f"{name}[{i + ckp_offset}].append({name}[{i}])")
self.writeline(ExitSubgraphLine(self))
# Condition check at end of loop
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
codegen_subgraph(
@ -3432,6 +3458,17 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(ExitSubgraphLine(self))
self.writeline(f" should_loop = {cond_outer_outputs[0]}")
# Stack outputs after loop completion
if stack_output:
self.writeline("# Stack outputs after loop completion")
for i in range(len(outer_carried_inputs)):
self.writeline(f"if len({name}[{i + ckp_offset}]) > 0:")
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
self.writeline(
f"{name}[{i}] = torch.stack({name}[{i + ckp_offset}], dim=0)"
)
self.writeline(ExitSubgraphLine(self))
@staticmethod
def statically_known_int_or_none(x):
try:

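To make the stack_output writelines above easier to follow, here is a rough, hand-assembled approximation of the host-side Python the wrapper would emit for a single carried tensor; run_cond_subgraph, run_body_subgraph, buf0, and carried_arg0 are placeholders, not real generated names.

import torch

# Placeholders standing in for the compiled cond/body subgraphs that the real
# generated wrapper would call; only the loop/stacking structure mirrors codegen.
def run_cond_subgraph(x):
    return bool(x.sum() < 9)

def run_body_subgraph(x):
    return x + 1

carried_arg0 = torch.zeros(3)

buf0 = [None] * 1                      # one slot per carried input
buf0.extend([[] for _ in range(1)])    # plus one accumulation list per output
buf0[0] = carried_arg0                 # set the initial state before the loop

should_loop = run_cond_subgraph(buf0[0])
if not should_loop:
    # Loop never executes: return a single-iteration stack of the input.
    buf0[0] = carried_arg0.unsqueeze(0).clone()

while should_loop:
    buf0[0] = run_body_subgraph(buf0[0])
    buf0[1].append(buf0[0])            # collect this iteration's output
    should_loop = run_cond_subgraph(buf0[0])

# Stack outputs after loop completion
if len(buf0[1]) > 0:
    buf0[0] = torch.stack(buf0[1], dim=0)
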
View File

@ -8431,6 +8431,12 @@ class Conditional(ExternKernel):
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@staticmethod
def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
if isinstance(s, int):
return s
return s.node.expr
@classmethod
def create(
cls,
@ -8497,18 +8503,15 @@ class Conditional(ExternKernel):
unbacked_bindings=unbacked_bindings,
)
def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
if isinstance(s, int):
return s
return s.node.expr
outputs = [
MultiOutput(
FixedLayout(
device=device,
dtype=output.get_dtype(),
size=[_maybe_expr(sz) for sz in merged_output.size()],
stride=[_maybe_expr(sz) for sz in merged_output.stride()],
size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
stride=[
Conditional._maybe_expr(sz) for sz in merged_output.stride()
],
offset=output.get_layout().offset,
is_pinned=output.get_layout().is_pinned,
),
@ -8558,7 +8561,7 @@ def _split_by_sym_type(
@ir_dataclass(frozen=False)
class WhileLoop(ExternKernel):
"""IR node for while_loop, which supports input mutations"""
"""The IR node for while_loop and while_loop_stack_output. It supports input mutation."""
carried_inputs: Optional[Sequence[IRNode]] = None
additional_inputs: Optional[Sequence[IRNode]] = None
@ -8573,6 +8576,8 @@ class WhileLoop(ExternKernel):
cond_subgraph: Subgraph,
body_subgraph: Subgraph,
layout: MultiOutputLayout,
unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
stack_output: bool,
) -> None:
self.carried_inputs = carried_inputs
self.additional_inputs = additional_inputs
@ -8588,6 +8593,9 @@ class WhileLoop(ExternKernel):
inputs=tensor_args,
constant_args=sym_args,
)
if unbacked_bindings is not None:
self.unbacked_bindings = unbacked_bindings
self.stack_output = stack_output
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@ -8631,7 +8639,11 @@ class WhileLoop(ExternKernel):
body_fn: Subgraph,
carried_inputs: Sequence[IRNode],
additional_inputs: Sequence[IRNode],
stack_output: bool,
) -> Union[IRNode, Sequence[IRNode]]:
"""create the while_loop IR node. stack_output controls whether it stack
each iterations' output, which is necessary for training.
"""
from torch._higher_order_ops.utils import check_input_alias_and_mutation
def _require_exact_strides(
@ -8740,6 +8752,12 @@ class WhileLoop(ExternKernel):
assert op.get_dtype() == bo.get_dtype(), (i, op, bo)
assert device is not None
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env,
V.graph.current_node.meta.get("unbacked_bindings", None),
)
while_loop = WhileLoop(
carried_inputs=carried_inputs_,
additional_inputs=additional_inputs_,
@ -8747,6 +8765,8 @@ class WhileLoop(ExternKernel):
body_subgraph=body_fn,
# asserted above that there is at least one operand
layout=MultiOutputLayout(device=device),
unbacked_bindings=unbacked_bindings,
stack_output=stack_output,
)
assert body_fn.graph is not None and isinstance(
@ -8762,34 +8782,51 @@ class WhileLoop(ExternKernel):
# Create all outputs first
mutated_inputs_iter = iter(mutated_inputs)
all_outputs = []
all_outputs: list[IRNode] = []
while_loop.outputs = []
while_loop.mutation_outputs = []
for idx, output in enumerate(body_outputs):
if idx in mutated_idx_set:
assert idx < len(carried_inputs), "only carries can be mutated."
# Create MutationOutput for mutated inputs
mutated_input = next(mutated_inputs_iter)
while_loop.mutation_outputs.append(
MutationOutput(mutated_input.layout, mutated_input, while_loop) # type: ignore[attr-defined, union-attr]
)
all_outputs.append(mutated_input)
else:
if stack_output:
assert len(mutated_idx_set) == 0, (
"NYI: while_loop_stack_output input mutations."
)
for idx, output in enumerate(V.graph.current_node.meta["val"]):
# Create MultiOutput for regular outputs
multi_out = MultiOutput(
FixedLayout(
device=output.get_device(), # type: ignore[arg-type]
dtype=output.get_dtype(),
size=output.get_size(),
stride=output.get_stride(),
offset=output.get_layout().offset,
device=output.device, # type: ignore[arg-type]
dtype=output.dtype,
size=[Conditional._maybe_expr(sz) for sz in output.size()],
stride=[Conditional._maybe_expr(st) for st in output.stride()],
),
while_loop,
[(list, idx)],
)
while_loop.outputs.append(multi_out)
all_outputs.append(multi_out)
else:
for idx, output in enumerate(body_outputs):
if idx in mutated_idx_set:
assert idx < len(carried_inputs), "only carries can be mutated."
# Create MutationOutput for mutated inputs
mutated_input = next(mutated_inputs_iter)
while_loop.mutation_outputs.append(
MutationOutput(mutated_input.layout, mutated_input, while_loop) # type: ignore[attr-defined, union-attr]
)
all_outputs.append(mutated_input)
else:
multi_out = MultiOutput(
FixedLayout(
device=output.get_device(), # type: ignore[arg-type]
dtype=output.get_dtype(),
size=output.get_size(),
stride=output.get_stride(),
offset=output.get_layout().offset,
),
while_loop,
[(list, idx)],
)
while_loop.outputs.append(multi_out)
all_outputs.append(multi_out)
for inp, out in zip(carried_inputs, all_outputs):
if inp.get_name() in V.graph.graph_inputs:
@ -8802,7 +8839,20 @@ class WhileLoop(ExternKernel):
return all_outputs
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
wrapper.codegen_while_loop(self)
wrapper.codegen_while_loop(self, self.stack_output)
wrapper.codegen_unbacked_symbol_defs_for_outputs(
self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
)
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
resolved = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, unbacked_bindings
)
assert resolved is not None
return OrderedSet(resolved.keys())
else:
return OrderedSet()
class EffectfulKernel(FallbackKernel):

View File

@ -7042,7 +7042,7 @@ def cond(pred, true_fn, false_fn, operands):
@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
if any(
isinstance(x, IRNode) and is_triton(x)
for x in carried_inputs + additional_inputs
@ -7062,11 +7062,18 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
else:
raise RuntimeError(f"NYI unsupported output type: {type(out)}")
result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs)
result = ir.WhileLoop.create(
cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
)
assert isinstance(result, Sequence)
return list(map(_map_output, result))
register_lowering(
torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None
)(functools.partial(while_loop, stack_output=True))
@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
result = ir.InvokeSubgraph.create(subgraph_fn, *operands)

View File

@ -202,6 +202,15 @@ def simple_while_loop(iter_t, x):
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
def simple_while_loop_stack_output(iter_t, x):
def cond_fn(iter_t, x):
return iter_t > 0
def body_fn(iter_t, x):
return iter_t - 1, x.cos()
return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())
def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
@ -374,6 +383,19 @@ hop_db = [
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="while_loop_stack_output",
variant_test_name="simple",
op=simple_while_loop_stack_output,
sample_inputs_func=sample_inputs_while_loop,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=False,
),
OpInfo(
name="auto_functionalize",
variant_test_name="simple",