[while_loop][autograd] add hop while_loop_stack_output (#160467)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160467
Approved by: https://github.com/zou3519
ghstack dependencies: #160548

This commit is contained in:
parent 5927a70934
commit 48e3be3ab6
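For orientation, the new op's behavior can be sketched in plain Python. This is a rough reference based on the WhileLoopStackOutputOp docstring and the dense implementation further down in this commit; the helper name below is made up:

import torch


def while_loop_stack_output_reference(cond_fn, body_fn, carried_inputs, additional_inputs=()):
    # Run the loop eagerly, keep every iteration's outputs, then stack them
    # along a new leading (iteration) dimension. The real op also handles the
    # zero-iteration case by unsqueezing the carried inputs; omitted here.
    carried = tuple(carried_inputs)
    outs = []
    while cond_fn(*carried, *additional_inputs):
        carried = body_fn(*carried, *additional_inputs)
        outs.append(carried)
    return tuple(torch.stack(vals, dim=0) for vals in zip(*outs))


# Three iterations: the stacked counter has shape (3,), the stacked payload (3, 3, 3).
c, x = torch.tensor(0, dtype=torch.int64), torch.randn(3, 3)
stacked_c, stacked_x = while_loop_stack_output_reference(
    lambda c, x: c < x.size(0), lambda c, x: (c + 1, x.cos()), (c, x)
)
print(stacked_c.shape, stacked_x.shape)  # torch.Size([3]) torch.Size([3, 3, 3])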
@@ -8239,6 +8239,36 @@ class GraphModule(torch.nn.Module):
""",  # noqa: B950
        )

    @parametrize("dynamic", [True, False])
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_while_loop_stack_output(self, dynamic, backend):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                c = torch.tensor(0, dtype=torch.int64)

                def cond_fn(c, x):
                    return c < x.size(0)

                def body_fn(c, x):
                    return c + 1, self.linear(x)

                stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
                    cond_fn, body_fn, (c, x), tuple()
                )
                return stacked_c, stacked_x

        x = torch.randn(3, 3)
        mod = Mod()
        compiled_out = torch.compile(mod, backend=backend, dynamic=dynamic)(x)
        self.assertEqual(len(compiled_out), 2)
        self.assertEqual(compiled_out[0].size(0), 3)
        self.assertEqual(compiled_out[1].size(0), 3)
        self.assertEqual(compiled_out, mod(x))

    def test_input_output_alias(self):
        def fn(f, *args):
            return torch.cond(args[0].sum() > 0, f, f, args)
@@ -1088,6 +1088,23 @@ class WhileLoopModels:
                (c, x),
            )

    class WhileLoopStackOutputSimple(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3, device=device)

        def forward(self, c, x):
            def cond_fn(c, x):
                return c < x.size(0)

            def body_fn(c, x):
                return c + 1, self.linear(x)

            stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
                cond_fn, body_fn, (c, x), tuple()
            )
            return stacked_c, stacked_x


class WhileLoopTests(TestCase):
    def _run_test(
@@ -1407,6 +1424,17 @@ class WhileLoopTests(TestCase):
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    def test_while_loop_stack_output_simple(self, device, dynamic):
        self._run_test(
            model=WhileLoopModels.WhileLoopStackOutputSimple(device),
            inputs=(torch.randn(3, 3, dtype=torch.float32),),
            device=device,
            dynamic=dynamic,
        )


class AssociativeScanTests(TestCase):
    @requires_gpu
@@ -312,6 +312,246 @@ def _check_supported_callable_arg(
        )


def _call_while_loop(
    self: VariableTracker,
    tx: "InstructionTranslator",
    args: list[VariableTracker],
    kwargs: dict[str, VariableTracker],
    stack_output: bool,
) -> VariableTracker:
    from torch._higher_order_ops.while_loop import _create_unbacked_symint

    from . import TensorVariable

    args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
    cond_fn, body_fn, operands, additional_inputs = args

    # Input checks
    for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
        if v := kwargs.pop(k, None):
            assert i == len(args), (
                "did not provide the right number of non-keyword args"
            )
            args.append(v)

    if kwargs:
        unimplemented(f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}")

    if len(args) != 4:
        unimplemented(
            f"Expected 4 arguments but got {len(args)}.\n"
            f"Usage: while_loop(cond_fn, body_fn, operands)",
        )

    # cond_fn and body_fn input check
    _check_supported_callable_arg(tx, cond_fn, "cond_fn")
    _check_supported_callable_arg(tx, body_fn, "body_fn")

    # operands input check
    operands_seq = operands.unpack_var_sequence(tx)

    # additional_inputs input check
    if not isinstance(additional_inputs, (ListVariable, TupleVariable)):
        unimplemented(
            f"Expected additional_inputs to be a list/tuple but got "
            f"{additional_inputs.python_type()}. It seems to be an "
            f"internal error, please report an issue to PyTorch."
        )
    additional_inputs_seq = additional_inputs.unpack_var_sequence(tx)

    with discard_graph_changes(tx):
        # Note: this must be run under discard graph changes.
        def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
            # See NOTE [unspecialize int carry with unbacked symints]
            if (
                isinstance(carry, ConstantVariable) and carry.python_type() is int
            ) or isinstance(carry, SymNodeVariable):
                example_value = _create_unbacked_symint(
                    tx.output.fake_mode, ignore_fresh_unbacked_symbols=True
                )
                proxy = tx.output.current_tracer.create_graph_input(
                    "unbacked_symint", type(example_value), example_value
                )
                return SymNodeVariable.create(tx, proxy, example_value)
            else:
                # See NOTE [unspecialize constant tensor carry]
                assert isinstance(carry, TensorVariable)
                cloned_carry = carry.clone()
                cloned_carry.proxy.node.meta["example_value"].constant = None
                return cloned_carry

        # clone inputs across subgraphs, to avoid unbacked memoization in fake prop
        cond_operands_seq = [
            unspecialize_carried_inputs(
                tx,
                (
                    carry.call_method(tx, "clone", args=(), kwargs={})
                    if isinstance(carry, TensorVariable)
                    else carry
                ),
            )
            for carry in operands_seq
        ]
        body_operands_seq = [
            unspecialize_carried_inputs(
                tx,
                (
                    carry.call_method(tx, "clone", args=(), kwargs={})
                    if isinstance(carry, TensorVariable)
                    else carry
                ),
            )
            for carry in operands_seq
        ]

    # create cond subgraph
    (
        (cond_r, _cond_treespec),
        cond_graph,
        cond_lifted_freevars,
    ) = speculate_subgraph(
        tx,
        cond_fn,
        cond_operands_seq + additional_inputs_seq,
        {},
        "while_loop",
        source_target=self.value,
        # NOTE [why we cannot use "automatic" for while_loop]:
        # The reason is that we want to enforce the ordering of inputs and
        # outputs to be consistent and the ordering of cond_fn and body_fn
        # to be consistent.
        # e.g. suppose we use "automatic" and we have:
        #
        # def body_fn(ph1, ph2):
        #     new_a, new_b = ph2.cos(), ph1.sin()
        #     return new_a, new_b
        #
        # a, b = torch.randn(3), torch.randn(3)
        # new_a, new_b = body_fn(a, b)
        #
        # Using automatic, the ordering of arguments will be the order that they're
        # used. In this example, the captured graph looks like:
        #
        # def captured_body(ph1, ph2):
        #     new_a, new_b = ph1.cos(), ph2.add_(1)
        #     return new_a, new_b
        #
        # This is fine when we change the calling convention of captured_body to be
        # new_a, new_b = captured_body(b, a).
        # But for while_loop, the next iteration's input is the previous iteration's
        # output, so we'll end up feeding captured_body(new_a, new_b) instead.
        # So it's best we always enforce the ordering of carried_inputs to be the
        # same as the outputs, with "flatten_manual".
        set_subgraph_inputs="flatten_manual",
        supports_input_mutation=self.supports_input_mutation,
        supports_aliasing=self.supports_aliasing,
        remove_consts_from_outputs=False,
    )
    cond_nn_modules = dict(tx.output.nn_modules)
    validate_subgraph_output_types(cond_r)
    if isinstance(cond_r, TensorVariable):
        cond_r_meta = _extract_tensor_metadata(
            cond_r.proxy.node.meta["example_value"], include_contiguity=False
        )
        if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
            []
        ):
            unimplemented(
                f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
            )
    elif isinstance(cond_r, ConstantVariable):
        # short-circuit while_loop when cond_fn returns a constant such as 0, 1, True or False
        pred = cond_r.as_python_constant()
        if pred:
            unimplemented(
                f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}"
            )
        else:
            return operands

    # create body subgraph
    (
        (body_r, body_treespec),
        body_graph,
        body_lifted_freevars,
    ) = speculate_subgraph(
        tx,
        body_fn,
        body_operands_seq + additional_inputs_seq,
        {},
        "while_loop",
        source_target=self.value,
        set_subgraph_inputs="flatten_manual",
        should_flatten_outputs=True,
        supports_input_mutation=False,
        supports_aliasing=False,
        remove_consts_from_outputs=False,
    )
    validate_subgraph_output_types(body_r)

    # We set include_contiguity=False because we have vmap x HOP tests, where
    # include_contiguity=True will call t.is_contiguous inside of vmap and get an error
    # "querying is_contiguous inside of vmap for memory_format other than
    # torch.contiguous_format is not yet implemented". This is okay because stride
    # is still checked.
    check_meta_consistency_vt(
        body_r.unpack_var_sequence(tx),
        operands_seq,
        "body_fn_output",
        "carried_inputs",
        include_contiguity=False,
    )

    (
        cond_graph,
        body_graph,
        cond_shared,
        _body_shared,
        cond_unique,
        body_unique,
    ) = _merge_graph_inputs(
        cond_graph,
        cond_lifted_freevars,
        "cond_fn",
        body_graph,
        body_lifted_freevars,
        "body_fn",
    )

    # Note: cond_shared and body_shared refer to the same proxy in the parent graph,
    # so using either of them is OK. Use cond_shared as it doesn't matter.
    additional_lifted_inputs = cond_shared + cond_unique + body_unique

    body_nn_modules = dict(tx.output.nn_modules)

    cond_gm = torch.fx.GraphModule(cond_nn_modules, cond_graph)
    body_gm = torch.fx.GraphModule(body_nn_modules, body_graph)
    cond_name = tx.output.install_subgraph("cond_fn", cond_gm)
    body_name = tx.output.install_subgraph("body_fn", body_gm)

    cond_node = make_attr(tx, cond_name)
    body_node = make_attr(tx, body_name)

    operands_proxy = tuple(operand.as_proxy() for operand in operands_seq)
    additional_inputs_proxy = tuple(
        [inp.as_proxy() for inp in additional_inputs_seq] + additional_lifted_inputs
    )
    p_args = (
        cond_node,
        body_node,
        operands_proxy,
        additional_inputs_proxy,
    )
    return _call_function_and_unflatten_output(
        tx,
        self.value,
        p_args,
        {},
        None,
        body_treespec,
    )


def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode):
    from torch._subclasses._fake_tensor_utils import _CacheKeyState
    from torch._subclasses.fake_tensor import extract_tensor_metadata
@@ -1280,243 +1520,23 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        from torch._higher_order_ops.while_loop import _create_unbacked_symint
        return _call_while_loop(self, tx, args, kwargs, stack_output=False)

        from . import TensorVariable

        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
        cond_fn, body_fn, operands, additional_inputs = args

class WhileLoopStackOutputHigherOrderVariable(TorchHigherOrderOperatorVariable):
    supports_input_mutation = False
    supports_aliasing = False

        # Input checks
        for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
            if v := kwargs.pop(k, None):
                assert i == len(args), (
                    "did not provide the right number of non-keyword args"
                )
                args.append(v)

        if kwargs:
            unimplemented(
                f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}"
            )

        if len(args) != 4:
            unimplemented(
                f"Expected 4 arguments but got {len(args)}.\n"
                f"Usage: while_loop(cond_fn, body_fn, operands)",
            )

        # cond_fn and body_fn input check
        _check_supported_callable_arg(tx, cond_fn, "cond_fn")
        _check_supported_callable_arg(tx, body_fn, "body_fn")

        # operands input check
        operands_seq = operands.unpack_var_sequence(tx)

        # additional_inputs input check
        if not isinstance(additional_inputs, (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected additional_inputs to be a list/tuple but got "
                f"{additional_inputs.python_type()}. It seems to be an "
                f"internal error, please report an issue to PyTorch."
            )
        additional_inputs_seq = additional_inputs.unpack_var_sequence(tx)

        with discard_graph_changes(tx):
            # Note: this must be run under discard graph changes.
            def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
                # See NOTE [unspecialize int carry with unbacked symints]
                if (
                    isinstance(carry, ConstantVariable) and carry.python_type() is int
                ) or isinstance(carry, SymNodeVariable):
                    example_value = _create_unbacked_symint(
                        tx.output.fake_mode, ignore_fresh_unbacked_symbols=True
                    )
                    proxy = tx.output.current_tracer.create_graph_input(
                        "unbacked_symint", type(example_value), example_value
                    )
                    return SymNodeVariable.create(tx, proxy, example_value)
                else:
                    # See NOTE [unspecialize constant tensor carry]
                    assert isinstance(carry, TensorVariable)
                    cloned_carry = carry.clone()
                    cloned_carry.proxy.node.meta["example_value"].constant = None
                    return cloned_carry

            # clone inputs across subgraphs, to avoid unbacked memoization in fake prop
            cond_operands_seq = [
                unspecialize_carried_inputs(
                    tx,
                    (
                        carry.call_method(tx, "clone", args=(), kwargs={})
                        if isinstance(carry, TensorVariable)
                        else carry
                    ),
                )
                for carry in operands_seq
            ]
            body_operands_seq = [
                unspecialize_carried_inputs(
                    tx,
                    (
                        carry.call_method(tx, "clone", args=(), kwargs={})
                        if isinstance(carry, TensorVariable)
                        else carry
                    ),
                )
                for carry in operands_seq
            ]

        # create cond subgraph
        (
            (cond_r, _cond_spec),
            cond_graph,
            cond_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            cond_fn,
            cond_operands_seq + additional_inputs_seq,
            {},
            "while_loop",
            source_target=self.value,
            # NOTE [why we cannot use "automatic" for while_loop]:
            # The reason is that we want to enforce the ordering of inputs and
            # outputs to be consistent and the ordering of cond_fn and body_fn
            # to be consistent.
            # e.g. suppose we use "automatic" and we have:
            #
            # def body_fn(ph1, ph2):
            #     new_a, new_b = ph2.cos(), ph1.sin()
            #     return new_a, new_b
            #
            # a, b = torch.randn(3), torch.randn(3)
            # new_a, new_b = body_fn(a, b)
            #
            # Using automatic, the ordering of arguments will be the order that they're
            # used. In this example, the captured graph looks like:
            #
            # def captured_body(ph1, ph2):
            #     new_a, new_b = ph1.cos(), ph2.add_(1)
            #     return new_a, new_b
            #
            # This is fine when we change the calling convention of captured_body to be
            # new_a, new_b = captured_body(b, a).
            # But for while_loop, the next iteration's input is the previous iteration's
            # output, so we'll end up feeding captured_body(new_a, new_b) instead.
            # So it's best we always enforce the ordering of carried_inputs to be the
            # same as the outputs, with "flatten_manual".
            set_subgraph_inputs="flatten_manual",
            supports_input_mutation=self.supports_input_mutation,
            supports_aliasing=self.supports_aliasing,
        )
        cond_nn_modules = dict(tx.output.nn_modules)
        validate_subgraph_output_types(cond_r)
        if isinstance(cond_r, TensorVariable):
            cond_r_meta = _extract_tensor_metadata(
                cond_r.proxy.node.meta["example_value"], include_contiguity=False
            )
            if (
                not cond_r_meta.dtype == torch.bool
                or not cond_r_meta.shape == torch.Size([])
            ):
                unimplemented(
                    f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}"
                )
        elif isinstance(cond_r, ConstantVariable):
            # short-circuit while_loop when cond_fn returns a constant such as 0, 1, True or False
            pred = cond_r.as_python_constant()
            if pred:
                unimplemented(
                    f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}"
                )
            else:
                return operands

        # create body subgraph
        (
            (body_r, body_spec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            body_fn,
            body_operands_seq + additional_inputs_seq,
            {},
            "while_loop",
            source_target=self.value,
            set_subgraph_inputs="flatten_manual",
            should_flatten_outputs=True,
            # TODO - removing consts from control flow ops needs more work
            remove_consts_from_outputs=False,
            supports_input_mutation=False,
            supports_aliasing=False,
        )
        validate_subgraph_output_types(body_r)

        # We set include_contiguity=False because we have vmap x HOP tests, where
        # include_contiguity=True will call t.is_contiguous inside of vmap and get an error
        # "querying is_contiguous inside of vmap for memory_format other than
        # torch.contiguous_format is not yet implemented". This is okay because stride
        # is still checked.
        check_meta_consistency_vt(
            body_r.unpack_var_sequence(tx),
            operands_seq,
            "body_fn_output",
            "carried_inputs",
            include_contiguity=False,
        )

        (
            cond_graph,
            body_graph,
            cond_shared,
            _body_shared,
            cond_unique,
            body_unique,
        ) = _merge_graph_inputs(
            cond_graph,
            cond_lifted_freevars,
            "cond_fn",
            body_graph,
            body_lifted_freevars,
            "body_fn",
        )

        # Note: cond_shared and body_shared refer to the same proxy in the parent graph,
        # so using either of them is OK. Use cond_shared as it doesn't matter.
        additional_lifted_inputs = cond_shared + cond_unique + body_unique

        body_nn_modules = dict(tx.output.nn_modules)

        cond_name = tx.output.install_subgraph(
            "cond_fn",
            torch.fx.GraphModule(cond_nn_modules, cond_graph),
        )
        body_name = tx.output.install_subgraph(
            "body_fn",
            torch.fx.GraphModule(body_nn_modules, body_graph),
        )

        cond_node = make_attr(tx, cond_name)
        body_node = make_attr(tx, body_name)

        p_args = (
            cond_node,
            body_node,
            tuple([operand.as_proxy() for operand in operands_seq]),
            tuple(
                [inp.as_proxy() for inp in additional_inputs_seq]
                + additional_lifted_inputs
            ),
        )
        return _call_function_and_unflatten_output(
            tx,
            torch.ops.higher_order.while_loop,
            p_args,
            {},
            None,
            body_spec,
        )

    @raise_hard_error_if_graph_break(
        reason="while_loop_stack_output doesn't work unless it is captured completely with torch.compile."
    )
    def _call_function(
        self,
        tx: "InstructionTranslator",
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        return _call_while_loop(self, tx, args, kwargs, stack_output=True)


class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
@@ -3481,6 +3501,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
_hop_name_to_variable_class = {
    "cond": CondHigherOrderVariable,
    "while_loop": WhileLoopHigherOrderVariable,
    "while_loop_stack_output": WhileLoopStackOutputHigherOrderVariable,
    "map_impl": MapHigherOrderVariable,
    "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable,
    "out_dtype": OutDtypeHigherOrderVariable,
@@ -27,7 +27,10 @@ from torch._higher_order_ops.run_const_graph import run_const_graph
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.torchbind import call_torchbind
from torch._higher_order_ops.while_loop import while_loop
from torch._higher_order_ops.while_loop import (
    while_loop,
    while_loop_stack_output_op as while_loop_stack_output,
)
from torch._higher_order_ops.wrap import (
    dynamo_bypassing_wrapper,
    tag_activation_checkpoint,
@@ -69,4 +72,5 @@ __all__ = [
    "strict_mode",
    "aoti_call_delegate",
    "map",
    "while_loop_stack_output",
]
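With the export above, the op is importable from the package root; a minimal sketch of how that looks:

# The name added to __all__ above resolves to the WhileLoopStackOutputOp
# instance (aliased from while_loop_stack_output_op).
from torch._higher_order_ops import while_loop_stack_output

print(while_loop_stack_output)  # the operator object registered by this commit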
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Any, Callable, Union

import torch
@@ -260,7 +261,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):


@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_dense(
    cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
    carried_vals = carried_inputs

    def _validate_cond_output(pred):
@@ -285,13 +288,25 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
    _validate_cond_output(should_loop)

    if not should_loop:
        return tuple(
            val.clone() if isinstance(val, torch.Tensor) else val
            for val in carried_vals + additional_inputs
        )
        if stack_output:
            return tuple(
                val.unsqueeze(0).clone() if isinstance(val, torch.Tensor) else val
                for val in carried_vals
            )
        else:
            return tuple(
                val.clone() if isinstance(val, torch.Tensor) else val
                for val in carried_vals
            )

    outputs: list[list[torch.Tensor]] = [[] for _ in carried_vals]

    while should_loop:
        out = body_fn(*carried_vals, *additional_inputs)
        if stack_output:
            for i, o in enumerate(out):
                outputs[i].append(o)

        assert isinstance(out, tuple), (
            f"body_fn should return a tuple but got {type(out)}"
        )
@@ -302,6 +317,12 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):

        should_loop = cond_fn(*carried_vals, *additional_inputs)

    if stack_output:
        outs: list[torch.Tensor] = []
        for i, out in enumerate(outputs):
            outs.append(torch.stack(out, dim=0))
        return tuple(outs)

    return carried_vals
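The dense path above keeps one Python list per carried value and stacks each list after the loop ends; a small standalone illustration of that collect-then-stack pattern (the loop bound here is arbitrary):

import torch

# One accumulation list per carry; each iteration appends that carry's new
# value, and torch.stack then adds a leading "iteration" dimension.
outputs = [[], []]
carried_vals = (torch.tensor(0), torch.ones(3))
for _ in range(2):  # stand-in for the cond_fn-driven while loop
    carried_vals = (carried_vals[0] + 1, carried_vals[1] * 2)
    for i, o in enumerate(carried_vals):
        outputs[i].append(o)
outs = tuple(torch.stack(out, dim=0) for out in outputs)
print(outs[0].shape, outs[1].shape)  # torch.Size([2]) torch.Size([2, 3])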
@@ -336,9 +357,18 @@ def _create_unbacked_symint(


@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_tracing(
    mode,
    cond_fn,
    body_fn,
    carried_inputs,
    additional_inputs,
    stack_output=False,
):
    op = while_loop_stack_output_op if stack_output else while_loop_op

    def _trace_while_loop(
        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
        proxy_mode, op, cond_fn, body_fn, carried_inputs, additional_inputs
    ):
        # NOTE [unspecialize int carry with unbacked symints]
        # When we support int carry, we'll also need to support int output of body_fn because.
@@ -437,10 +467,10 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
            "call_function", op, proxy_args, {}, name=op._name
        )

        out = while_loop_op(
        out = op(
            cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
        )
        return track_tensor_tree(
@@ -448,13 +478,18 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
    )

    return _trace_while_loop(
        mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
        mode,
        op,
        cond_fn,
        body_fn,
        carried_inputs,
        additional_inputs,
    )


@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
    mode, cond_fn, body_fn, carried_inputs, additional_inputs
    mode, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
    with mode:
        # NOTE: [Handling unbacked symints in subgraph of while_loop]
@@ -499,6 +534,26 @@ def while_loop_fake_tensor_mode(
            "body_output",
            include_contiguity=False,
        )

        if stack_output:
            n_iter = _create_unbacked_symint(mode, ignore_fresh_unbacked_symbols=False)
            assert all(isinstance(x, torch.Tensor) for x in carried_inputs)
            fake_outputs = tuple(
                out.clone()
                .unsqueeze(0)
                .repeat((n_iter,) + tuple(1 for _ in range(out.dim())))
                for out in body_outs
            )
            return pytree.tree_map_only(
                (int, torch.SymInt),
                # For while_loop's unbacked symint output, we want them to be bound
                # to the proxy of while_loop's output.
                lambda _: _create_unbacked_symint(
                    mode, ignore_fresh_unbacked_symbols=False
                ),
                fake_outputs,
            )

        # See NOTE [unspecialize int carry with unbacked symints]
        return pytree.tree_map_only(
            (int, torch.SymInt),
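In fake tensor mode the iteration count is unknown, so the branch above fabricates each output by prepending a fresh unbacked symint to the body output's shape via unsqueeze/repeat. The same construction with a concrete, made-up count:

import torch

n_iter = 5  # stands in for the unbacked SymInt returned by _create_unbacked_symint
out = torch.empty(3, 4)  # stands in for one fake body output
fake_output = out.clone().unsqueeze(0).repeat(
    (n_iter,) + tuple(1 for _ in range(out.dim()))
)
print(fake_output.shape)  # torch.Size([5, 3, 4])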
@@ -512,9 +567,13 @@ def while_loop_fake_tensor_mode(


@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop_func(
    ctx, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False
):
    from torch._higher_order_ops.utils import _check_alias_and_mutation

    op = while_loop_stack_output_op if stack_output else while_loop_op

    unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
@@ -527,10 +586,72 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
        (body_fn, "body_fn"),
    ]:
        _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
    ret = while_loop_op(
    ret = op(
        functional_cond_fn,
        functional_body_fn,
        unwrapped_carried_inputs,
        unwrapped_additional_inputs,
    )
    return ctx.wrap_tensors(ret)


class WhileLoopStackOutputOp(HigherOrderOperator):
    """
    while_loop_stack_output is a variant of while_loop that returns a stack of outputs.
    Its semantics can be illustrated with Python code as:

    def while_loop_stack_output(cond_fn, body_fn, carried_inputs, additional_inputs):
        outs = []
        while cond_fn(*carried_inputs, *additional_inputs):
            out = body_fn(*carried_inputs, *additional_inputs)
            outs.append(out)
        return torch.stack(outs)

    It's useful for supporting autograd of while_loop.
    """

    def __init__(self) -> None:
        super().__init__("while_loop_stack_output")

    def __call__(
        self,
        cond_fn: Callable,
        body_fn: Callable,
        carried_inputs: tuple[Union[torch.Tensor, int, float, bool]],
        additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
        /,
    ):
        if not isinstance(carried_inputs, (tuple, list)):
            raise RuntimeError(
                f"carried_inputs must be a tuple or list, got {type(carried_inputs)}"
            )
        if not isinstance(additional_inputs, (tuple, list)):
            raise RuntimeError(
                f"additional_inputs must be a tuple or list, got {type(additional_inputs)}"
            )

        validate_subgraph_args_types(carried_inputs)
        validate_subgraph_args_types(additional_inputs)
        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)


while_loop_stack_output_op = WhileLoopStackOutputOp()

while_loop_stack_output_op.py_impl(DispatchKey.CompositeExplicitAutograd)(
    functools.partial(while_loop_dense, stack_output=True)
)

while_loop_stack_output_op.py_impl(ProxyTorchDispatchMode)(
    functools.partial(while_loop_tracing, stack_output=True)
)

while_loop_stack_output_op.py_impl(FakeTensorMode)(
    functools.partial(while_loop_fake_tensor_mode, stack_output=True)
)

while_loop_stack_output_op.py_functionalize_impl(
    functools.partial(while_loop_func, stack_output=True)
)

while_loop_stack_output_op.py_autograd_impl(
    autograd_not_implemented(while_loop_stack_output_op, deferred_error=True)
)
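A short usage sketch of the op assembled above, mirroring simple_while_loop_stack_output from the hop_db hunk near the end of this commit; in eager mode the call lands in the dense implementation (autograd is registered as not implemented for now):

import torch
from torch._higher_order_ops.while_loop import while_loop_stack_output_op


def cond_fn(it, x):
    return it > 0


def body_fn(it, x):
    return it - 1, x.cos()


it, x = torch.tensor(3), torch.randn(3, 3)
stacked_it, stacked_x = while_loop_stack_output_op(cond_fn, body_fn, (it, x), tuple())
print(stacked_it.shape, stacked_x.shape)  # torch.Size([3]) torch.Size([3, 3, 3])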
@@ -1947,7 +1947,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
        finally:
            self.pop_codegened_graph()

    def codegen_while_loop(self, while_loop):
    def codegen_while_loop(self, while_loop, stack_output=False):
        if stack_output:
            raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output")
        is_bool_pred = isinstance(
            while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
        )
@@ -3370,7 +3370,9 @@ class PythonWrapperCodegen(CodeGen):
        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
        self.writeline(ExitSubgraphLine(self))

    def codegen_while_loop(self, while_loop):
    def codegen_while_loop(self, while_loop, stack_output):
        """while_loop is codegened as a host-side while loop"""

        def codegen_subgraph(subgraph, outer_inputs, outer_outputs):
            """Helper method to deduplicate subgraph codegen logic"""
            if V.graph.aot_mode:
@@ -3388,7 +3390,13 @@ class PythonWrapperCodegen(CodeGen):
            buf.codegen_reference() for buf in while_loop.additional_inputs
        ]

        ckp_offset = len(outer_carried_inputs)
        self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
        if stack_output:
            self.writeline(
                f"{name}.extend([[] for _ in range({len(outer_carried_inputs)})])"
            )

        for i, inp in enumerate(outer_carried_inputs):
            # set the initial state before the loop
            self.writeline(f"{name}[{i}] = {inp}")
@@ -3411,10 +3419,21 @@ class PythonWrapperCodegen(CodeGen):
        )
        self.writeline(f"should_loop = {cond_outer_outputs[0]}")
        self.writeline("if not should_loop:")
        for i, (carried_input, carried_buf) in enumerate(
            zip(outer_carried_inputs, while_loop.carried_inputs)
        ):
            self.writeline(f"    {name}[{i}] = {carried_input}.clone()")
        if stack_output:
            # Handle the case when the loop never executes
            for i, (carried_input, carried_buf) in enumerate(
                zip(outer_carried_inputs, while_loop.carried_inputs)
            ):
                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()")
                self.writeline(ExitSubgraphLine(self))
        else:
            for i, (carried_input, carried_buf) in enumerate(
                zip(outer_carried_inputs, while_loop.carried_inputs)
            ):
                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                self.writeline(f"{name}[{i}] = {carried_input}.clone()")
                self.writeline(ExitSubgraphLine(self))

        self.writeline("while should_loop:")
        # Body execution
@@ -3424,6 +3443,13 @@ class PythonWrapperCodegen(CodeGen):
        )
        self.writeline(ExitSubgraphLine(self))

        # Collect outputs if enabled
        if stack_output:
            self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
            for i in range(len(outer_carried_inputs)):
                self.writeline(f"{name}[{i + ckp_offset}].append({name}[{i}])")
            self.writeline(ExitSubgraphLine(self))

        # Condition check at end of loop
        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
        codegen_subgraph(
@@ -3432,6 +3458,17 @@ class PythonWrapperCodegen(CodeGen):
        self.writeline(ExitSubgraphLine(self))
        self.writeline(f"    should_loop = {cond_outer_outputs[0]}")

        # Stack outputs after loop completion
        if stack_output:
            self.writeline("# Stack outputs after loop completion")
            for i in range(len(outer_carried_inputs)):
                self.writeline(f"if len({name}[{i + ckp_offset}]) > 0:")
                self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
                self.writeline(
                    f"{name}[{i}] = torch.stack({name}[{i + ckp_offset}], dim=0)"
                )
                self.writeline(ExitSubgraphLine(self))

    @staticmethod
    def statically_known_int_or_none(x):
        try:
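Putting the writelines above together, the generated host-side wrapper code for a single carried tensor with stack_output enabled looks roughly like the following sketch; the buffer name and the subgraph stand-ins are illustrative, not the actual generated identifiers:

import torch


def cond_subgraph(c):  # stand-in for the generated cond subgraph call
    return bool(c < 3)


def body_subgraph(c):  # stand-in for the generated body subgraph call
    return c + 1


carried_input0 = torch.tensor(0)
buf = [None] * 1                    # current value of each carry
buf.extend([[] for _ in range(1)])  # one accumulation list per carry (ckp_offset == 1)
buf[0] = carried_input0             # set the initial state before the loop
should_loop = cond_subgraph(buf[0])
if not should_loop:
    buf[0] = carried_input0.unsqueeze(0).clone()
while should_loop:
    buf[0] = body_subgraph(buf[0])  # body execution
    buf[1].append(buf[0])           # collect this iteration's output
    should_loop = cond_subgraph(buf[0])
# Stack outputs after loop completion
if len(buf[1]) > 0:
    buf[0] = torch.stack(buf[1], dim=0)
print(buf[0])  # tensor([1, 2, 3])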
@@ -8431,6 +8431,12 @@ class Conditional(ExternKernel):
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

    @staticmethod
    def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
        if isinstance(s, int):
            return s
        return s.node.expr

    @classmethod
    def create(
        cls,
@@ -8497,18 +8503,15 @@ class Conditional(ExternKernel):
            unbacked_bindings=unbacked_bindings,
        )

        def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
            if isinstance(s, int):
                return s
            return s.node.expr

        outputs = [
            MultiOutput(
                FixedLayout(
                    device=device,
                    dtype=output.get_dtype(),
                    size=[_maybe_expr(sz) for sz in merged_output.size()],
                    stride=[_maybe_expr(sz) for sz in merged_output.stride()],
                    size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
                    stride=[
                        Conditional._maybe_expr(sz) for sz in merged_output.stride()
                    ],
                    offset=output.get_layout().offset,
                    is_pinned=output.get_layout().is_pinned,
                ),
@@ -8558,7 +8561,7 @@ def _split_by_sym_type(


@ir_dataclass(frozen=False)
class WhileLoop(ExternKernel):
    """IR node for while_loop, which supports input mutations"""
    """The IR node for while_loop and while_loop_stack_output. It supports input mutation."""

    carried_inputs: Optional[Sequence[IRNode]] = None
    additional_inputs: Optional[Sequence[IRNode]] = None
@@ -8573,6 +8576,8 @@ class WhileLoop(ExternKernel):
        cond_subgraph: Subgraph,
        body_subgraph: Subgraph,
        layout: MultiOutputLayout,
        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
        stack_output: bool,
    ) -> None:
        self.carried_inputs = carried_inputs
        self.additional_inputs = additional_inputs
@@ -8588,6 +8593,9 @@ class WhileLoop(ExternKernel):
            inputs=tensor_args,
            constant_args=sym_args,
        )
        if unbacked_bindings is not None:
            self.unbacked_bindings = unbacked_bindings
        self.stack_output = stack_output

        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)
@@ -8631,7 +8639,11 @@ class WhileLoop(ExternKernel):
        body_fn: Subgraph,
        carried_inputs: Sequence[IRNode],
        additional_inputs: Sequence[IRNode],
        stack_output: bool,
    ) -> Union[IRNode, Sequence[IRNode]]:
        """Create the while_loop IR node. stack_output controls whether it stacks
        each iteration's output, which is necessary for training.
        """
        from torch._higher_order_ops.utils import check_input_alias_and_mutation

        def _require_exact_strides(
@@ -8740,6 +8752,12 @@ class WhileLoop(ExternKernel):
            assert op.get_dtype() == bo.get_dtype(), (i, op, bo)

        assert device is not None

        unbacked_bindings = resolve_unbacked_bindings(
            V.graph.sizevars.shape_env,
            V.graph.current_node.meta.get("unbacked_bindings", None),
        )

        while_loop = WhileLoop(
            carried_inputs=carried_inputs_,
            additional_inputs=additional_inputs_,
@@ -8747,6 +8765,8 @@ class WhileLoop(ExternKernel):
            body_subgraph=body_fn,
            # asserted above that there is at least one operand
            layout=MultiOutputLayout(device=device),
            unbacked_bindings=unbacked_bindings,
            stack_output=stack_output,
        )

        assert body_fn.graph is not None and isinstance(
@@ -8762,34 +8782,51 @@ class WhileLoop(ExternKernel):

        # Create all outputs first
        mutated_inputs_iter = iter(mutated_inputs)
        all_outputs = []
        all_outputs: list[IRNode] = []
        while_loop.outputs = []
        while_loop.mutation_outputs = []

        for idx, output in enumerate(body_outputs):
            if idx in mutated_idx_set:
                assert idx < len(carried_inputs), "only carries can be mutated."
                # Create MutationOutput for mutated inputs
                mutated_input = next(mutated_inputs_iter)
                while_loop.mutation_outputs.append(
                    MutationOutput(mutated_input.layout, mutated_input, while_loop)  # type: ignore[attr-defined, union-attr]
                )
                all_outputs.append(mutated_input)
            else:
        if stack_output:
            assert len(mutated_idx_set) == 0, (
                "NYI: while_loop_stack_output input mutations."
            )
            for idx, output in enumerate(V.graph.current_node.meta["val"]):
                # Create MultiOutput for regular outputs
                multi_out = MultiOutput(
                    FixedLayout(
                        device=output.get_device(),  # type: ignore[arg-type]
                        dtype=output.get_dtype(),
                        size=output.get_size(),
                        stride=output.get_stride(),
                        offset=output.get_layout().offset,
                        device=output.device,  # type: ignore[arg-type]
                        dtype=output.dtype,
                        size=[Conditional._maybe_expr(sz) for sz in output.size()],
                        stride=[Conditional._maybe_expr(st) for st in output.stride()],
                    ),
                    while_loop,
                    [(list, idx)],
                )
                while_loop.outputs.append(multi_out)
                all_outputs.append(multi_out)
        else:
            for idx, output in enumerate(body_outputs):
                if idx in mutated_idx_set:
                    assert idx < len(carried_inputs), "only carries can be mutated."
                    # Create MutationOutput for mutated inputs
                    mutated_input = next(mutated_inputs_iter)
                    while_loop.mutation_outputs.append(
                        MutationOutput(mutated_input.layout, mutated_input, while_loop)  # type: ignore[attr-defined, union-attr]
                    )
                    all_outputs.append(mutated_input)
                else:
                    multi_out = MultiOutput(
                        FixedLayout(
                            device=output.get_device(),  # type: ignore[arg-type]
                            dtype=output.get_dtype(),
                            size=output.get_size(),
                            stride=output.get_stride(),
                            offset=output.get_layout().offset,
                        ),
                        while_loop,
                        [(list, idx)],
                    )
                    while_loop.outputs.append(multi_out)
                    all_outputs.append(multi_out)

        for inp, out in zip(carried_inputs, all_outputs):
            if inp.get_name() in V.graph.graph_inputs:
@@ -8802,7 +8839,20 @@ class WhileLoop(ExternKernel):
        return all_outputs

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        wrapper.codegen_while_loop(self)
        wrapper.codegen_while_loop(self, self.stack_output)
        wrapper.codegen_unbacked_symbol_defs_for_outputs(
            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
        )

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
            resolved = resolve_unbacked_bindings(
                V.graph.sizevars.shape_env, unbacked_bindings
            )
            assert resolved is not None
            return OrderedSet(resolved.keys())
        else:
            return OrderedSet()


class EffectfulKernel(FallbackKernel):
@@ -7042,7 +7042,7 @@ def cond(pred, true_fn, false_fn, operands):


@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
    if any(
        isinstance(x, IRNode) and is_triton(x)
        for x in carried_inputs + additional_inputs
@@ -7062,11 +7062,18 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
    else:
        raise RuntimeError(f"NYI unsupported output type: {type(out)}")

    result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs)
    result = ir.WhileLoop.create(
        cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
    )
    assert isinstance(result, Sequence)
    return list(map(_map_output, result))


register_lowering(
    torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None
)(functools.partial(while_loop, stack_output=True))


@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
    result = ir.InvokeSubgraph.create(subgraph_fn, *operands)
@@ -202,6 +202,15 @@ def simple_while_loop(iter_t, x):

    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))


def simple_while_loop_stack_output(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())


def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
@@ -374,6 +383,19 @@ hop_db = [
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="while_loop_stack_output",
        variant_test_name="simple",
        op=simple_while_loop_stack_output,
        sample_inputs_func=sample_inputs_while_loop,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="auto_functionalize",
        variant_test_name="simple",
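With the entry above in place, the new OpInfo can be looked up in hop_db like any other; a minimal sketch:

from torch.testing._internal.hop_db import hop_db

# Find the OpInfo added above. supports_autograd is False for now, matching the
# autograd_not_implemented registration on the op.
[info] = [op for op in hop_db if op.name == "while_loop_stack_output"]
print(info.variant_test_name, info.supports_autograd)  # simple False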