[scan] create fw and bw graphs via partitioning (#162754)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162754
Approved by: https://github.com/zou3519
ghstack dependencies: #161557, #161664, #161808, #162025, #161732
Yidi Wu 2025-09-26 19:28:13 -07:00 committed by PyTorch MergeBot
parent 3413490f53
commit 8f6dbc0ba8
6 changed files with 752 additions and 427 deletions

View File

@ -1830,34 +1830,34 @@ def forward(self, pred_1, x_1):
]
)
init_clone = [i.clone() for i in init]
init_clone2 = [i.clone() for i in init]
elements_clone = [ele.clone() for ele in elements]
elements_clone2 = [ele.clone() for ele in elements]
result = scan_fct(
get_scan_combine_fn("s5_operator", False),
init,
elements,
dim=0,
init_clone,
elements_clone,
reverse=reverse,
)
expected_result = _fake_scan(
get_scan_combine_fn("s5_operator", False),
init=init,
xs=elements,
dim=0,
init_clone2,
elements_clone2,
reverse=reverse,
)
self.assertEqual(result, expected_result)
if autograd:
init_flatten, _ = pytree.tree_flatten(init)
elements_flatten, _ = pytree.tree_flatten(elements)
result_flatten, _ = pytree.tree_flatten(result)
result_exp_flatten, _ = pytree.tree_flatten(expected_result)
grad_out = [torch.ones_like(el) for el in result_exp_flatten]
expected_grads = torch.autograd.grad(
result_exp_flatten, (*init_flatten, *elements_flatten), grad_out
result_exp_flatten, (*init_clone2, *elements_clone2), grad_out
)
grads = torch.autograd.grad(
result_flatten, (*init_flatten, *elements_flatten), grad_out
result_flatten, (*init_clone, *elements_clone), grad_out
)
self.assertEqual(grads, expected_grads)
@ -2757,9 +2757,7 @@ class GraphModule(torch.nn.Module):
l_init_1_ = L_init_1_
l_xs_ = L_xs_
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_, 0, 0); l_xs_ = None
flip: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
flip: "f32[3, 10, 2]" = torch.flip(l_xs_, [0]); l_xs_ = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_0_, l_init_1_], [flip], []); scan_combine_fn_0 = l_init_0_ = l_init_1_ = flip = None
@ -3457,8 +3455,7 @@ class GraphModule(torch.nn.Module):
gm.code.strip(),
"""\
def forward(self, fct_1, init_1, xs_1):
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
flip = torch.ops.aten.flip.default(xs_1, [0])
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
@ -3482,8 +3479,7 @@ def forward(self, fct_1, init_1, xs_1):
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
l_init_ = L_init_
l_xs_ = L_xs_
elem = torch.movedim(l_xs_, 0, 0); l_xs_ = None
flip = torch.flip(elem, [0]); elem = None
flip = torch.flip(l_xs_, [0]); l_xs_ = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [flip], []); scan_combine_fn_0 = l_init_ = flip = None
carry = scan[0]
@ -8069,9 +8065,8 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
l_xs_ = L_xs_
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
r = torch.movedim(l_xs_, 0, 0); l_xs_ = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
carry = scan[0]
out = scan[1]; scan = None
return (carry, out)""", # noqa: B950
@ -8085,9 +8080,8 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
l_xs_ = L_xs_
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
movedim = torch.movedim(l_xs_, 0, 0); l_xs_ = None
scan_combine_fn_0 = self.scan_combine_fn_0
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [movedim], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = movedim = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
carry = scan[0]
out = scan[1]; scan = None
return (carry, out)""", # noqa: B950

View File

@ -1561,6 +1561,8 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
for user in node.users:
if (
must_recompute(user)
and "ac_graph_id" in user.meta
and "ac_graph_id" in node.meta
and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
):
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

View File

@ -0,0 +1,364 @@
import logging
from typing import Any, Callable, Union
import torch
from torch._higher_order_ops.utils import create_bw_fn, materialize_as_graph
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def _find_hop_subgraph_outputs(gm: torch.fx.GraphModule) -> tuple[torch.fx.Node]:
output_node_args = gm.graph.find_nodes(op="output")[0].args
assert isinstance(output_node_args, tuple)
return output_node_args[0]
def is_complex_expr(expr: Any) -> bool:
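# An expression is "complex" when it is neither a lone symbol (e.g. s0) nor a
# constant, e.g. s0 * 2 + 1. `expr` is expected to be a sympy expression taken
# from SymInt.node.expr (see the callers below).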
return not expr.is_symbol and not expr.is_constant()
class HopPartitionedGraph:
def __init__(
self,
fw_gm: torch.fx.GraphModule,
bw_gm: torch.fx.GraphModule,
n_fw_outputs: int,
n_intermediates: int,
no_complex_exprs_at_boundary: bool,
):
self.fw_gm = fw_gm
self.bw_gm = bw_gm
self.n_fw_outputs = n_fw_outputs
self.n_intermediates = n_intermediates
self.no_complex_exprs_at_boundary = no_complex_exprs_at_boundary
self._reorder_fw_output()
self._check_partition_boundary()
def _check_partition_boundary(self) -> None:
"""check partitioned graph is in valid state."""
invalid_reasons = []
fw_outputs = _find_hop_subgraph_outputs(self.fw_gm)
for i, out in enumerate(fw_outputs):
if "val" not in out.meta:
invalid_reasons.append(f"fw_gm output[{i}] doesn't have a 'val' meta.")
elif not isinstance(out.meta["val"], (torch.SymInt, torch.Tensor)):
invalid_reasons.append(
f"fw_gm output[{i}] is of type {type(out.meta['val'])} but only SymInt or Tensor are allowed."
)
elif (
isinstance(out.meta["val"], torch.SymInt)
and is_complex_expr(out.meta["val"].node.expr)
and self.no_complex_exprs_at_boundary
):
invalid_reasons.append(
f"fw_gm output[{i}] must be of type SymInt with basic symbols or "
f"Tensor but got {type(out.meta['val'])} {out.meta['val']}"
)
if len(fw_outputs) != self.n_fw_outputs + self.n_intermediates:
invalid_reasons.append(
f"len(fw_outputs) ({len(fw_outputs)}) != n_fw_outputs ({self.n_fw_outputs}) + n_intermediates ({self.n_intermediates})" # noqa: B950
)
bw_phs = list(self.bw_gm.graph.find_nodes(op="placeholder"))
if len(fw_outputs) != len(bw_phs):
invalid_reasons.append(
f"Expect number of fw_gm's output to be the same as bw_gm's input but "
f"fw_gm has {len(fw_outputs)} outputs, bw_gm takes {len(bw_phs)} inputs."
)
original_forward_outputs = fw_outputs[: self.n_fw_outputs]
fw_intermediates = fw_outputs[self.n_fw_outputs :]
bw_intermediates = bw_phs[: -self.n_fw_outputs]
bw_grads = bw_phs[-self.n_fw_outputs :]
def _match_size_or_expr(
val1: Union[torch.SymInt, torch.Tensor],
val2: Union[torch.SymInt, torch.Tensor],
) -> bool:
if type(val1) != type(val2):
return False
if isinstance(val1, torch.SymInt) and isinstance(val2, torch.SymInt):
return val1.node.expr == val2.node.expr
elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
return val1.size() == val2.size()
return False
for fw, bw in zip(fw_intermediates, bw_intermediates):
if fw.name != bw.name or not _match_size_or_expr(
fw.meta["val"], bw.meta["val"]
):
invalid_reasons.append("fw intermediates don't match bw intermediates")
for fw_out, bw_grad in zip(original_forward_outputs, bw_grads):
if not _match_size_or_expr(fw_out.meta["val"], bw_grad.meta["val"]):
invalid_reasons.append("fw outputs don't match bw gradients")
if len(invalid_reasons) > 0:
newline = "\n"
raise RuntimeError(
"Invalid HopPartitionedGraph. Reasons:\n",
f"{newline.join(invalid_reasons)}",
)
def _reorder_fw_output(self) -> None:
"""
Before the pass, fw_gm returns (*fw_outputs, *intermediates1)
and bw_gm takes (*intermediates2, *grad_fw_outputs) as input.
intermediates1 and intermediates2 share the same node names but
they might be in a different order, e.g. when the inputs contain symints.
To simplify downstream processing, this graph pass normalizes the output of fw_gm
to be consistent with the backward inputs:
fw_gm:
- input: fw_args
- output: (*fw_outputs, *intermediates)
bw_gm:
- input: (*intermediates, *grad_fw_outputs)
- output: grad_fw_args
Example:
def fw_gm(x, y, z):
a, b, c = f(x), g(y), k(z)
return a, b, c, f_tmp, g_tmp, k_tmp
, where a, b, c are fw_outputs and f_tmp, g_tmp, k_tmp are intermediates.
The corresponding bw_gm has the following signature:
def bw_gm(f_tmp, g_tmp, k_tmp, grad_a, grad_b, grad_c):
return grad_x, grad_y, grad_z
"""
fw_gm_output_nodes = _find_hop_subgraph_outputs(self.fw_gm)
fw_outputs_nodes = fw_gm_output_nodes[: self.n_fw_outputs]
fw_intermediates_nodes = fw_gm_output_nodes[self.n_fw_outputs :]
if len(fw_intermediates_nodes) > 0:
fw_intermediates_name_to_node = {n.name: n for n in fw_intermediates_nodes}
# First n_intermediates placeholders
bw_names: list[str] = [
ph.name
for ph in list(self.bw_gm.graph.find_nodes(op="placeholder"))[
: self.n_intermediates
]
]
new_fw_outputs = list(fw_outputs_nodes) + [
fw_intermediates_name_to_node[name] for name in bw_names
]
output_node = self.fw_gm.graph.find_nodes(op="output")[0]
output_node.args = (tuple(new_fw_outputs),)
self.fw_gm.graph.lint()
self.fw_gm.recompile()
class HopJointGraph:
def __init__(
self,
joint_gm: torch.fx.GraphModule,
n_primals: int,
n_fw_outputs: int,
*,
functionalized: bool,
):
self.joint_gm = joint_gm
self.n_primals = n_primals
self.n_fw_outputs = n_fw_outputs
self.functionalized = functionalized
self._rename_phs()
self._remove_redundant_sym_size_ops()
def _rename_phs(self) -> None:
"""
Rename the placeholders for joint_gm so that the partitioner
could recognize which inputs are primals and which are tangents.
"""
self.n_tangents = 0
for i, ph in enumerate(self.joint_gm.graph.find_nodes(op="placeholder")):
if i < self.n_primals:
ph.target = f"primals_{i}"
ph.name = f"primals_{i}"
else:
self.n_tangents += 1
ph.target = f"tangents_{i - self.n_primals}"
ph.name = f"tangents_{i - self.n_primals}"
self.joint_gm.graph.lint()
self.joint_gm.compile()
def _remove_redundant_sym_size_ops(self) -> None:
"""
Deletes torch.ops.aten.sym_size.int operators whose output duplicates a
placeholder that already holds the same symbol, and replaces all uses
of the sym_size node with that placeholder.
This is to make sure all basic symbols come from inputs.
"""
placeholder_exprs = {}
for node in self.joint_gm.graph.nodes:
if (
isinstance(node, torch.fx.Node)
and node.op == "placeholder"
and hasattr(node, "meta")
and "val" in node.meta
):
val = node.meta["val"]
if isinstance(val, torch.SymInt):
placeholder_exprs[val.node.expr] = node
nodes_to_remove = []
for node in self.joint_gm.graph.find_nodes(
op="call_function", target=torch.ops.aten.sym_size.int
):
assert hasattr(node, "meta") and "val" in node.meta, node
val = node.meta["val"]
expr = val.node.expr
if expr in placeholder_exprs:
placeholder_node = placeholder_exprs[expr]
node.replace_all_uses_with(placeholder_node)
nodes_to_remove.append(node)
for node in nodes_to_remove:
self.joint_gm.graph.erase_node(node)
self.joint_gm.graph.lint()
self.joint_gm.recompile()
def _mark_complex_exprs_as_must_recompute(self) -> None:
"""
For control flow operators such as scan, we don't want to have symints in the
partitioning boundary, because otherwise we would need to support stacking
the symints across iterations, which adds unnecessary complexity.
By marking the recompute policy of complex nodes as MUST_RECOMPUTE, the partitioning boundary
no longer contains complex expressions.
Note that this pass doesn't exclude basic symbols from the partitioning boundary;
it's up to downstream code to decide whether to return the basic symbols
or remove them with a separate graph pass.
"""
from torch._functorch.partitioners import CheckpointPolicy
for n in (
node for node in self.joint_gm.graph.nodes if node.op == "call_function"
):
if "val" not in n.meta:
continue
val = n.meta["val"]
if isinstance(val, torch.SymInt) and is_complex_expr(val.node.expr):
assert n.meta.get("recompute", None) is None
n.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
self.joint_gm.graph.lint()
self.joint_gm.recompile()
def partition(
self, partition_fn: Callable, always_recompute_complex_exprs: bool
) -> HopPartitionedGraph:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"before min_cut_partition:\n%s",
self.joint_gm.print_readable(print_output=False),
)
if always_recompute_complex_exprs:
self._mark_complex_exprs_as_must_recompute()
fw_gm, bw_gm = partition_fn(
self.joint_gm, None, num_fwd_outputs=self.n_fw_outputs
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("after partition_fn:")
logger.debug("fw_gm:\n%s", fw_gm.print_readable(print_output=False))
logger.debug("bw_gm:\n%s", bw_gm.print_readable(print_output=False))
n_intermediates = len(_find_hop_subgraph_outputs(fw_gm)) - self.n_fw_outputs
return HopPartitionedGraph(
fw_gm,
bw_gm,
self.n_fw_outputs,
n_intermediates,
always_recompute_complex_exprs,
)
def create_hop_joint_graph(
fw_fn: Callable,
fw_args: tuple[Union[torch.Tensor, torch.SymInt], ...],
functionalize: bool,
) -> HopJointGraph:
fw_gm = materialize_as_graph(fw_fn, fw_args, force_enable_grad=True)
fw_gm_output_nodes = _find_hop_subgraph_outputs(fw_gm)
assert all(
isinstance(n, torch.fx.Node) and "val" in n.meta for n in fw_gm_output_nodes
)
fw_gm_output_vals = tuple(n.meta["val"] for n in fw_gm_output_nodes) # type: ignore[arg-type]
assert all(isinstance(val, torch.Tensor) for val in fw_gm_output_vals)
example_grads = tuple(torch.zeros_like(val) for val in fw_gm_output_vals)
joint_fn = create_bw_fn(fw_fn, fw_args, return_fw_outputs=True)
joint_gm = materialize_as_graph(
joint_fn, fw_args + example_grads, force_enable_grad=True
)
if functionalize:
# Need to first trace out the joint_fn with autograd info on
# then functionalize the graph otherwise the grad information is lost
joint_gm = materialize_as_graph(
torch.func.functionalize(joint_gm, remove="mutations_and_views"),
fw_args + example_grads,
)
return HopJointGraph(
joint_gm,
len(fw_args),
len(fw_gm_output_nodes),
functionalized=functionalize,
)
class HopGraphMinCutPartitioner:
@staticmethod
def create_partitioned_graph(
fw_fn: Callable,
fw_args: tuple[Union[torch.Tensor, torch.SymInt], ...],
*,
always_recompute_complex_exprs: bool = False,
) -> HopPartitionedGraph:
"""
Inputs:
- fw_fn: the forward function that we'll use to create a joint graph and partition
- fw_args: the flat_args to fw_fn
- always_recompute_complex_exprs: when set to True, the bw_gm will recompute
values that are complex expressions, so that the partitioning boundary
only consists of basic symbols and tensors.
Returns a HopPartitionedGraph
"""
from torch._functorch.partitioners import min_cut_rematerialization_partition
joint_graph: HopJointGraph = create_hop_joint_graph(
fw_fn, fw_args, functionalize=True
)
return joint_graph.partition(
min_cut_rematerialization_partition, always_recompute_complex_exprs
)

View File

@ -1,26 +1,30 @@
# mypy: allow-untyped-defs
import enum
import functools
import itertools
import logging
from typing import Any, Callable
import torch
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.partitioner import (
_find_hop_subgraph_outputs,
HopGraphMinCutPartitioner,
HopPartitionedGraph,
)
from torch._higher_order_ops.utils import (
_maybe_compile_and_run_fn,
check_input_alias_and_mutation_return_outputs,
check_meta_consistency,
create_bw_fn,
fill_none_with_masks,
filter_with_masks,
first_slice_copy,
first_slice_copy_with_grad,
get_tensor_mask,
mask_list,
materialize_as_graph,
reenter_make_fx,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
split_into_chunks,
unique_graph_id,
validate_subgraph_args_types,
@ -35,6 +39,7 @@ from torch.fx.experimental.proxy_tensor import (
from torch.utils._python_dispatch import _get_current_dispatch_mode
logger: logging.Logger = logging.getLogger(__name__)
aten = torch._ops.ops.aten
@ -42,7 +47,7 @@ def wrap_combine_fn_flat(
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
):
assert len(args) == (num_init_leaves + num_inp_leaves), (
f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
f"combine_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
)
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
@ -182,7 +187,7 @@ def scan(
# Move scan dim to 0 and always perform scan on dim 0
leaves_xs = []
for elem in leaves_xs_orig:
leaves_xs.append(torch.movedim(elem, dim, 0))
leaves_xs.append(torch.movedim(elem, dim, 0) if dim != 0 else elem)
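# e.g. xs of shape (B, T, D) scanned over dim=1 becomes (T, B, D), so step t reads xs[t]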
if reverse:
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
@ -429,121 +434,6 @@ def scan_op_dense(combine_fn, init, xs, additional_inputs):
class ScanAutogradOp(torch.autograd.Function):
"""
Example ::
def combine_fn(x: torch.Tensor, y: torch.Tensor):
next_carry = y = x * y
return next_carry, y
The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as:
def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor):
return g_y * y + g_carry * y, g_y * x + g_carry * x
Note: In a real use case of scan, there may be additional_inputs that participate in the
forward as well as in the backward of the scan operator. For the sake of readability those inputs
have been omitted in the following example, but are included in the subsequent detailed description below
The forward output of scan is computed as:
carry, ys = scan(combine_fn, init, xs).
This computation can be unpacked as
c_0, ys_0 = combine_fn(init, xs_0)
c_1, ys_1 = combine_fn(carry_0, xs_1)
c_2, ys_2 = combine_fn(carry_1, xs_2)
...
c_T, ys_T = combine_fn(carry_(T-1), xs_T)
We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward,
but we only output (c_T, ys),
where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T].
Given the carries and the ys, the gradients for xs and for init can be computed as follows:
We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys,
where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]
We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a
scan operation reverse over time. For example,
g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T)
g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1))
g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2))
...
g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1)
0 , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0),
where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t),
the gradients of the carry of step t (i.e. g_c_t) and
the upstream gradient of the output of step t (i.e. g_ys_T)
and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1).
Through this procedure we end up with the
gradients for the init -> g_init,
the gradients for the xs -> g_xs.
NOTE: [scan autograd implementation]
The forward of scan can be computed as:
1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``:
To use a scan operation for the backward path as well, we need access to the carries from all steps.
Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry.
In particular, we define ``combine_fn_with_carry_checkpoint``:
def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor):
carry, y = combine_fn(x, y)
return carry, (carry, y)
The scan operator will stack all outputs along the scan dimension.
Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``,
the carries from all steps will be stacked, which gives us the checkpointed carries.
2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``:
c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs),
Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``.
However, carries are checkpointed carries from all steps.
As a result of the forward, only the last carry c_T and the ys are returned,
while all carries are saved for the backward.
The backward of scan can be computed as:
3.) Prepare the backward graph:
We prepare the backward graph to be used in the backward function.
We utilize ``create_bw_fn`` to generate the joint function, i.e.,
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs]
The ctx._combine_fn_bw requires the primals (operands)
followed by the tangents (upstream gradients) from a single step
and produces the gradients of that step, i.e.,
g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T).
4.) Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``:
In the forward, there may be additional inputs that participate in every forward step.
The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps,
which is taken care of in this wrapper. For example:
def combine_fn_bw_grad_accumulation(*args):
carried_g_additional_input = args[:num_additional_inputs]
inputs_bw_fn = args[num_additional_inputs:]
g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn)
new_g_additional_inputs = carried_g_additional_input + g_additional_input_t
# The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
# The ``g_xs_t`` is encoded as the output of the backward scan operator
return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
5.) Perform the backward scan as
g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where
bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s):
initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus:
bwd_init = [*initial_g_additional_inputs, *g_c_T].
bw_xs consists of the combination of the upstream gradients g_ys,
the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and
the fw_xs. In particular,
bwd_xs = [*g_ys, *bw_carries, *fw_xs].
Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input
As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init
and the gradient for the xs -> g_xs
NOTE: [scan partial grad handling]
If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients,
i.e., requires_grad=False, there will be still gradients returned for those elements,
@ -559,300 +449,361 @@ class ScanAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
combine_fn,
num_leaves_init,
num_leaves_xs,
num_additional_inputs,
hop_partitioned_graph,
n_init,
n_xs,
n_additional_inputs,
*operands,
):
ctx._num_leaves_init = num_leaves_init
ctx._num_leaves_xs = num_leaves_xs
ctx._num_additional_inputs = num_additional_inputs
ctx._combine_fn = combine_fn
init, xs, additional_inputs = split_into_chunks(
operands, [num_leaves_init, num_leaves_xs, num_additional_inputs]
operands, [n_init, n_xs, n_additional_inputs]
)
additional_inputs_tensor_mask = get_tensor_mask(additional_inputs)
ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask
# We snapshot the dispatch keys in forward for materializing
# the bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
# 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
# The wrapper of the forward graph returns carries from all iterations,
# not just from the last iteration. These are required in the backward path
def combine_fn_with_carry_checkpoint(*args):
carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
return [
*carry,
# We additionally checkpoint all the intermediate carry outputs for backward.
*[n_c.detach().clone() for n_c in carry],
*y,
]
# Materialize the ``combine_fn_with_carry_checkpoint`` with enable_grad
# we need enable_grad to support torch.func.grad_and_value
# in subgraph.
gm = materialize_as_graph(
combine_fn_with_carry_checkpoint,
(*init, *[x[0] for x in xs], *additional_inputs),
force_enable_grad=True,
ctx._scan_impl = ScanAutogradImpl(
hop_partitioned_graph, init, xs, additional_inputs
)
with torch._C._AutoDispatchBelowAutograd():
# 2.) Compute all carries, the last carry, and all outputs using ``combine_fn_with_carry_checkpoint``
c_T, carries_ys = _extract_carry_and_out(
scan_op(
gm,
init,
xs,
additional_inputs,
),
num_leaves_init,
)
# Collect the carries for each time step from the outs
# and save them for the backward path
carries = list(carries_ys[:num_leaves_init])
ys = list(carries_ys[num_leaves_init:])
save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys)
ctx._num_leaves_ys = len(ys)
return (*c_T, *ys)
return ctx._scan_impl.call_forward()
@staticmethod
def backward(ctx, *flat_grads):
r"""
This function computes the gradients of the scan operation.
It does so by running a scan over all carries and the upstream gradients (see description above).
Args:
flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
"""
from torch._higher_order_ops.utils import fill_none_with_masks
# Collect the saved items from the forward
num_leaves_init = ctx._num_leaves_init
num_leaves_xs = ctx._num_leaves_xs
num_leaves_ys = ctx._num_leaves_ys
num_additional_inputs = ctx._num_additional_inputs
additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask
def prepend_init_to_carries(init, carries):
# Prepare the carries for the backward path.
# This requires to concatenate the init and the carries
return [
torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0)
for i, c in zip(init, carries)
]
def initialize_g_additional_inputs(
additional_inputs,
):
g_additional_inputs = [
torch.zeros_like(ai)
for ai in filter_with_masks(
additional_inputs, additional_inputs_tensor_mask
)
]
return g_additional_inputs
# Retrieve the forward inputs and the forward outputs and dissect them
flat_args = saved_tensors_and_symints(ctx)
fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks(
flat_args,
[
num_leaves_init,
num_leaves_xs,
num_additional_inputs,
num_leaves_init,
num_leaves_ys,
],
)
# 3.) Prepare the backward graph
fw_operands = (
*fw_init,
*[first_slice_copy(xs) for xs in fw_xs],
*additional_inputs,
)
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)
# Initialize the g_additional_inputs with zero-tensors.
# This step is necessary because the gradients of the additional inputs are accumulated in the
# ``combine_fn_bw_grad_accumulation`` and thus need a zero-initialized starting point
initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
# 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
def combine_fn_bw_grad_accumulation(*args):
# Dissect args and re-order them for the ``ctx._combine_fn_bw``
# The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g]
# The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs]
(
carried_g_additional_input,
combine_fn_bw_tangents,
combine_fn_bw_primals,
) = split_into_chunks(
args,
[
len(initial_g_additional_inputs),
num_leaves_init + num_leaves_ys,
num_leaves_init + num_leaves_xs + num_additional_inputs,
],
)
combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)
g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
ctx._combine_fn_bw(*combine_fn_bw_args),
[num_leaves_init, num_leaves_xs, num_additional_inputs],
)
new_g_additional_inputs = [
# If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
carr_g + curr_g
for carr_g, curr_g in zip(
carried_g_additional_input,
filter_with_masks(
g_additional_inputs_t, additional_inputs_tensor_mask
),
)
]
assert all(isinstance(t, torch.Tensor) for t in new_g_additional_inputs)
# The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
# The ``g_xs_t`` is encoded as the output of the backward scan operator
return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
# Materialize the ``combine_fn_bw_grad_accumulation``
def construct_args_single_step_bw():
# This function constructs the arguments for a single step of the backward scan.
# In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
# The order of the arguments returned is identical to the order the backward scan
# operation provides
# The following arguments are used for the backward part of the joint graph
# The first argument relates to the gradient accumulation of the additional inputs.
# Because only tensor elements of additional inputs can have requires_grad=True,
# the values for non-tensor elements of additional inputs are None
masked_additional_inputs = [
a.clone()
for a in filter_with_masks(
additional_inputs, additional_inputs_tensor_mask
)
]
# The second argument relates to the gradients of the carries.
# Because the arguments are for a single step only,
# only the first slice of the carries is used.
sliced_carries = [first_slice_copy(c) for c in fw_carries]
# The third argument relates to the gradients of the ys.
# Because the arguments are for a single step only,
# only the first slice of the ys is used.
sliced_ys = [first_slice_copy(o) for o in fw_ys]
# The following arguments are used for the forward part of the joint graph
# The fourth argument relates to the init for the forward.
# I.e., fw_init
# The fifth argument relates to the xs for the forward.
# Because the arguments are for a single step only,
# only the first slice of the xs is used.
# Note: It is important to preserve the requires_grad flag of xs
# and thus we use the wrapper function ``first_slice_copy_with_grad``
fw_xs_slice = first_slice_copy_with_grad(fw_xs)
# The last argument relates to the additional inputs for the forward.
# I.e., additional_inputs
return (
*masked_additional_inputs,
*sliced_carries,
*sliced_ys,
*fw_init,
*fw_xs_slice,
*additional_inputs,
)
args_single_step_bw = construct_args_single_step_bw()
# TODO: we need to materialize the bw graphs because dynamo is unable to
# trace through the joint function when torch.compile torch.autograd.grad.
combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
combine_fn_bw_grad_accumulation,
args_single_step_bw,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
# Decompose the flat_grads into g_c_T, g_ys
g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
assert all(
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_c_T
)
assert all(
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_ys
)
# Prepend the inits to the carries.
# This is needed, because when computing the gradients, the last carry is not needed
# but the first carry, the init, is required.
bw_carries = prepend_init_to_carries(fw_init, fw_carries)
# Prepare the xs for the backward scan.
bwd_xs = [*g_ys, *bw_carries, *fw_xs]
# The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse
bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs]
# Prepare the bwd_init
bwd_init = [*initial_g_additional_inputs, *g_c_T]
# 5.) Perform the backward scan:
# The ``combine_fn_bw_wrapped`` receives the
# initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
# gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
gradients = scan_op(
combine_fn_bw_grad_accumulation_gm,
bwd_init,
bwd_xs,
additional_inputs,
)
# Unpack the computed gradients
g_additional_inputs, g_init, g_xs = split_into_chunks(
gradients,
[len(initial_g_additional_inputs), num_leaves_init, num_leaves_xs],
)
# The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
g_xs = [torch.flip(elem, [0]) for elem in g_xs]
def backward(ctx, *grad_fw_outputs):
return (
*[None] * 4,
*g_init,
*g_xs,
*fill_none_with_masks(g_additional_inputs, additional_inputs_tensor_mask),
None,
None,
None,
None,
*ctx._scan_impl.call_backward(*grad_fw_outputs),
)
class ScanForwardIntermediatesHandlingPolicy(enum.Enum):
"""
The partitioner can add intermediates to the output of the original graph.
These intermediates fall into 4 categories, and we want different policies for handling them by
modifying the graph:
CLONE: we clone the intermediate when it is a carried input (i.e. init). This carry is
replaced with new values at each forward step, so we need to clone it into the per-step outputs (i.e. ys)
to remove the aliasing, so that each step's intermediate is stacked and saved for backward.
REMOVE_XS: we remove the intermediate from the output when it is part of xs. Since xs is read-only,
we can directly save it for backward to use.
REMOVE_ADDITIONAL_INPUTS: we remove the intermediate from the output when it is part of additional_inputs.
additional_inputs are also read-only at each step, so we can directly save them for backward to use.
We differentiate XS and ADDITIONAL_INPUTS so that we can treat them differently in backward: xs intermediates
go into the backward scan's xs, while additional_inputs are passed as the backward scan's additional_inputs.
KEEP: this corresponds to the output of a real intermediate tensor operation. It varies at each forward step,
so we keep it as part of ys.
"""
KEEP = 0
CLONE = 1
REMOVE_XS = 2
REMOVE_ADDITIONAL_INPUTS = 3
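# Illustrative mapping (hypothetical intermediate names): if fw_gm's extra outputs are
# (init_0, xs_1, add_inp_0, mul_3), the policies are [CLONE, REMOVE_XS,
# REMOVE_ADDITIONAL_INPUTS, KEEP]: init_0 is cloned into the per-step outputs, xs_1 and
# add_inp_0 are dropped from the outputs and saved directly, and mul_3 stays in the outputs.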
class ScanAutogradImpl:
"""
Wraps the partitioned graph and encapsulates scan-specific implementation details.
"""
def __init__(
self, hop_partitioned_graph: HopPartitionedGraph, init, xs, additional_inputs
):
self.hop_partitioned_graph = hop_partitioned_graph
self.init = init
self.xs = xs
self.additional_inputs = additional_inputs
self.forward_intermediates_handling_policies: list[
ScanForwardIntermediatesHandlingPolicy
] = []
self.saved_fw_xs: list[Any] = []
self.saved_fw_additional_inputs: list[Any] = []
self.saved_intermediates: list[Any] = []
self.fw_spec = pytree.tree_flatten((init, xs, additional_inputs))[1]
self._optimize_forward_intermediates()
def _insert_clone(
self, need_copy_node: torch.fx.Node, output_node: torch.fx.Node
) -> torch.fx.Node:
graph: torch.fx.Graph = output_node.graph
with graph.inserting_before(output_node):
clone_node = graph.call_function(
torch.ops.aten.clone.default,
args=(need_copy_node,),
)
clone_node.meta = (
need_copy_node.meta.copy() if hasattr(need_copy_node, "meta") else {}
)
return clone_node
def _optimize_forward_intermediates(self):
"""
We optimize the forward intermediates by categorizing them
and constructing a ScanForwardIntermediatesHandlingPolicy for each.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Need remove aliasing in fw_gm:\n%s",
self.hop_partitioned_graph.fw_gm.print_readable(print_output=False),
)
fw_gm = self.hop_partitioned_graph.fw_gm
fw_all_outputs = _find_hop_subgraph_outputs(fw_gm)
phs = list(fw_gm.graph.find_nodes(op="placeholder"))
fw_outputs = fw_all_outputs[: self.hop_partitioned_graph.n_fw_outputs]
fw_intermediates = fw_all_outputs[self.hop_partitioned_graph.n_fw_outputs :]
init_phs, xs_phs, additional_inputs_phs = pytree.tree_unflatten(
phs, self.fw_spec
)
init_node_set, xs_node_set, addi_node_set = (
set(init_phs),
set(xs_phs),
set(additional_inputs_phs),
)
assert len(self.forward_intermediates_handling_policies) == 0
assert len(self.saved_fw_xs) == 0
assert len(self.saved_fw_additional_inputs) == 0
intermediate_idx_to_ph_idx = {}
ph_idx = {ph: i for i, ph in enumerate(phs)}
for i, out in enumerate(fw_intermediates):
if out in init_node_set:
self.forward_intermediates_handling_policies.append(
ScanForwardIntermediatesHandlingPolicy.CLONE
)
intermediate_idx_to_ph_idx[i] = ph_idx[out]
elif out in xs_node_set:
self.forward_intermediates_handling_policies.append(
ScanForwardIntermediatesHandlingPolicy.REMOVE_XS
)
intermediate_idx_to_ph_idx[i] = ph_idx[out]
elif out in addi_node_set:
self.forward_intermediates_handling_policies.append(
ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
)
intermediate_idx_to_ph_idx[i] = ph_idx[out]
else:
self.forward_intermediates_handling_policies.append(
ScanForwardIntermediatesHandlingPolicy.KEEP
)
new_output_node = []
real_graph_inputs = (
list(self.init) + list(self.xs) + list(self.additional_inputs)
)
fw_output_node = next(iter(fw_gm.graph.find_nodes(op="output")))
for intermediate_idx, (node, policy) in enumerate(
zip(fw_intermediates, self.forward_intermediates_handling_policies)
):
if policy == ScanForwardIntermediatesHandlingPolicy.CLONE:
new_output_node.append(self._insert_clone(node, fw_output_node))
elif policy == ScanForwardIntermediatesHandlingPolicy.REMOVE_XS:
assert intermediate_idx in intermediate_idx_to_ph_idx
inp_idx = intermediate_idx_to_ph_idx[intermediate_idx]
self.saved_fw_xs.append(real_graph_inputs[inp_idx])
elif (
policy
== ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
):
assert intermediate_idx in intermediate_idx_to_ph_idx
inp_idx = intermediate_idx_to_ph_idx[intermediate_idx]
self.saved_fw_additional_inputs.append(real_graph_inputs[inp_idx])
else:
new_output_node.append(node)
fw_output_node.args = (tuple(fw_outputs) + tuple(new_output_node),)
fw_gm.graph.lint()
fw_gm.recompile()
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"after removing aliasing:\n%s", fw_gm.print_readable(print_output=False)
)
def call_forward(self):
fw_outputs_and_intermediates: tuple[Any] = scan_op(
self.hop_partitioned_graph.fw_gm, self.init, self.xs, self.additional_inputs
) # type: ignore[return-type]
fw_outs = fw_outputs_and_intermediates[
: self.hop_partitioned_graph.n_fw_outputs
]
saved_intermediates = fw_outputs_and_intermediates[
self.hop_partitioned_graph.n_fw_outputs :
]
assert len(self.saved_intermediates) == 0
self.saved_intermediates.extend(saved_intermediates)
return tuple(fw_outs)
def call_backward(self, *grad_fw_outputs):
"""
Recall that fw_outputs = (*carry, *ys), bw_gm takes in (*fw_intermediates, *grad_carry, *grad_ys)
and returns (*grad_init, *grad_xs, *grad_additional_inputs)
The backward is a reverse scan that can be constructed as follows:
grad_additional_inputs = torch.zeros_like(additional_inputs)
bw_init = (grad_carry, grad_additional_inputs)
bw_xs = (fw_intermediates, grad_ys)
grad_init, grad_additional_inputs, grad_xs = scan(
combine_fn,
bw_init,
bw_xs,
reverse = True
)
, where combine_fn is defined as follows:
def combine_fn(bw_init, bw_xs):
grad_carry, grad_additional_inputs = bw_init
fw_intermediates, grad_y = bw_xs
nxt_grad_carry, grad_x, nxt_grad_additional_inputs = bw_gm(*fw_intermediates, *grad_carry, *grad_y)
return (nxt_grad_carry, grad_additional_inputs + nxt_grad_additional_inputs), grad_x
Note that grad_additional_inputs is accumulated with add, grad_carry is carried over to the next iteration, and
grad_x is the ys output, which will be stacked together after the loop and will have the same shape as xs.
"""
fw_policy = self.forward_intermediates_handling_policies
saved_intermediates = self.saved_intermediates
saved_fw_xs = self.saved_fw_xs
saved_fw_additional_inputs = self.saved_fw_additional_inputs
n_carry = len(self.init)
grad_carry, grad_ys = grad_fw_outputs[:n_carry], grad_fw_outputs[n_carry:]
additional_inputs_tensor_masks = [
True if isinstance(t, torch.Tensor) else False
for t in self.additional_inputs
]
grad_additional_inputs = [
torch.zeros_like(t)
for t in filter_with_masks(
self.additional_inputs, additional_inputs_tensor_masks
)
]
bw_init = [grad_carry, grad_additional_inputs]
bw_xs = [
grad_ys,
saved_fw_xs,
saved_intermediates,
]
bw_additional_inputs = saved_fw_additional_inputs
_, flat_spec = pytree.tree_flatten((bw_init, bw_xs, bw_additional_inputs))
grad_spec = None
def bw_single_step_wrapper(*args):
bw_init, bw_xs, bw_additional_inputs = pytree.tree_unflatten(
args, flat_spec
)
grad_carry, grad_additional_inputs = bw_init
grad_y, saved_fw_xs, saved_intermediates = bw_xs
saved_fw_additional_inputs = bw_additional_inputs
fw_intermediates = []
xs_it = iter(saved_fw_xs)
carry_it = iter(saved_intermediates)
addi_it = iter(saved_fw_additional_inputs)
for policy in fw_policy:
if policy in (
ScanForwardIntermediatesHandlingPolicy.CLONE,
ScanForwardIntermediatesHandlingPolicy.KEEP,
):
fw_intermediates.append(next(carry_it))
elif policy == ScanForwardIntermediatesHandlingPolicy.REMOVE_XS:
fw_intermediates.append(next(xs_it))
elif (
policy
== ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
):
fw_intermediates.append(next(addi_it))
else:
raise RuntimeError(f"Unknown policy: {policy}")
grad_fw_outputs = (*grad_carry, *grad_y)
flat_out = self.hop_partitioned_graph.bw_gm(
*fw_intermediates,
*grad_fw_outputs,
)
next_grad_carry, grad_xs, grad_addi = split_into_chunks(
flat_out, # type: ignore[arg-type]
[len(self.init), len(self.xs), len(self.additional_inputs)],
)
nonlocal grad_spec
flat_grads, grad_spec = pytree.tree_flatten(
(
next_grad_carry,
[
prev + cur
for prev, cur in zip(
grad_additional_inputs,
filter_with_masks(
grad_addi, additional_inputs_tensor_masks
),
)
],
grad_xs,
)
)
return flat_grads
single_step_bw_xs = pytree.tree_map(lambda t: t[0], bw_xs)
bw_single_step_gm = materialize_as_graph(
bw_single_step_wrapper,
tuple(
pytree.tree_flatten((bw_init, single_step_bw_xs, bw_additional_inputs))[
0
]
),
)
flat_grads = scan_op(
bw_single_step_gm,
pytree.tree_flatten(bw_init)[0],
# TODO: torch.flip copies the tensor, we should optimize it away
[torch.flip(x, (0,)) for x in pytree.tree_flatten(bw_xs)[0]],
pytree.tree_flatten(bw_additional_inputs)[0],
)
assert grad_spec is not None
grad_init, grad_additional_inputs, grad_xs = pytree.tree_unflatten(
flat_grads, grad_spec
)
return (
*grad_init,
*[torch.flip(elem, (0,)) for elem in grad_xs],
*fill_none_with_masks(
grad_additional_inputs, additional_inputs_tensor_masks
),
)
@scan_op.py_autograd_impl
def scan_autograd(combine_fn, init, xs, additional_inputs):
num_leaves_init = len(init)
num_leaves_xs = len(xs)
num_additional_inputs = len(additional_inputs)
with disable_proxy_modes_tracing():
hop_partitioned_graph: HopPartitionedGraph = (
HopGraphMinCutPartitioner.create_partitioned_graph(
combine_fn,
(*init, *[x[0] for x in xs], *additional_inputs),
always_recompute_complex_exprs=True,
)
)
flat_out = ScanAutogradOp.apply(
combine_fn,
num_leaves_init,
num_leaves_xs,
num_additional_inputs,
*(tuple(init) + tuple(xs) + additional_inputs),
return ScanAutogradOp.apply(
hop_partitioned_graph,
len(init),
len(xs),
len(additional_inputs),
*init,
*xs,
*additional_inputs,
)
return *flat_out[:num_leaves_init], *flat_out[num_leaves_init:]
@scan_op.py_impl(ProxyTorchDispatchMode)

View File

@ -505,9 +505,13 @@ def prepare_fw_with_masks_all_requires_grad(fn):
lambda x: x.requires_grad_(True) if x.dtype.is_floating_point else x,
fw_out,
)
return fw_out, pytree.tree_map_only(
torch.Tensor, lambda x: x.requires_grad, fw_out
)
def _query_requires_grad(t: torch.Tensor) -> bool:
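# Under functionalization the tensor may be a functional wrapper; unwrap it so
# requires_grad is read from the underlying tensor.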
if torch._is_functional_tensor(t):
t = torch._from_functional_tensor(t)
return t.requires_grad
return fw_out, pytree.tree_map_only(torch.Tensor, _query_requires_grad, fw_out)
return fw_with_masks
@ -759,7 +763,9 @@ def _clone_aliasing_output(inputs: Sequence[Any], outputs: Sequence[Any]):
return final_outputs
def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
def create_bw_fn(
fn: Callable, args: tuple[Any, ...], return_fw_outputs: bool = False
) -> Callable:
"""
For a fn that accepts flat inputs and returns flat outputs:
fw_out = fn(*args),
@ -793,7 +799,7 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
def flat_fn(*args_and_grad_outs):
primals = args_and_grad_outs[:n_primals]
tangents = args_and_grad_outs[n_primals:]
grad_args = bw_fn(primals, tangents)[1]
fw_outs, grad_args = bw_fn(primals, tangents)
assert len(args) == len(grad_args)
# For tensors whose grad is None, create zero tensors as gradients
@ -806,6 +812,8 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
]
final_grads = _clone_aliasing_output(args_and_grad_outs, grad_args)
if return_fw_outputs:
return *fw_outs, *final_grads
return final_grads
return flat_fn
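# Hedged sketch of the new return_fw_outputs flag (fn, x, y, g are made-up examples):
#
#   def fn(x, y):
#       return (x * y,)
#
#   joint = create_bw_fn(fn, (x, y), return_fw_outputs=True)
#   out, gx, gy = joint(x, y, g)   # forward outputs come before the gradients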
@ -1111,7 +1119,7 @@ def call_op(op: Union[OpOverload, HopInstance], args, kwargs):
def materialize_as_graph(
fn: Callable,
args: tuple[Any],
args: tuple[Any, ...],
include_key_set: Optional[torch._C.DispatchKeySet] = None,
exclude_key_set: Optional[torch._C.DispatchKeySet] = None,
force_enable_grad=False,

View File

@ -809,6 +809,12 @@ def scatter_upon_const_tensor(
"""
from torch._inductor import metrics
# Check if inputs are tensors instead of inductor IR nodes
if isinstance(selector, torch.Tensor):
# Return a fake tensor with the proper shape that this operator is intended to return
device = selector.device if hasattr(selector, "device") else torch.device("cpu")
return torch.empty(shape, dtype=dtype, device=device)
metrics.num_matches_for_scatter_upon_const_tensor += 1
selector_loader = selector.make_loader()