Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[scan] create fw and bw graphs via partitioning (#162754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162754 Approved by: https://github.com/zou3519 ghstack dependencies: #161557, #161664, #161808, #162025, #161732
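At a high level, this change replaces scan's hand-written backward wrapper with a partitioned joint graph: the combine_fn is traced into a joint forward/backward graph, and min-cut partitioning splits it into a forward graph that additionally returns the intermediates worth saving and a backward graph that consumes those intermediates plus the gradients of the forward outputs. The toy sketch below (not code from this PR; the saved values are picked by hand rather than by the partitioner) illustrates the shape of the resulting fw_gm/bw_gm pair for a single scan step:

def combine_fn(c, x):                       # one scan step; works on floats or tensors
    return c * x, c + x                     # (next_carry, y)

def fw_gm(c, x):                            # forward, plus intermediates to save
    next_carry, y = combine_fn(c, x)
    return next_carry, y, c, x              # (*fw_outputs, *intermediates)

def bw_gm(saved_c, saved_x, g_carry, g_y):  # (*intermediates, *grad_fw_outputs)
    g_c = g_carry * saved_x + g_y           # d(next_carry)/dc = x, d(y)/dc = 1
    g_x = g_carry * saved_c + g_y           # d(next_carry)/dx = c, d(y)/dx = 1
    return g_c, g_x                         # gradients w.r.t. the step's inputs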
This commit is contained in:
parent 3413490f53
commit 8f6dbc0ba8
@@ -1830,34 +1830,34 @@ def forward(self, pred_1, x_1):
|
|||
]
|
||||
)
|
||||
|
||||
init_clone = [i.clone() for i in init]
|
||||
init_clone2 = [i.clone() for i in init]
|
||||
elements_clone = [ele.clone() for ele in elements]
|
||||
elements_clone2 = [ele.clone() for ele in elements]
|
||||
result = scan_fct(
|
||||
get_scan_combine_fn("s5_operator", False),
|
||||
init,
|
||||
elements,
|
||||
dim=0,
|
||||
init_clone,
|
||||
elements_clone,
|
||||
reverse=reverse,
|
||||
)
|
||||
expected_result = _fake_scan(
|
||||
get_scan_combine_fn("s5_operator", False),
|
||||
init=init,
|
||||
xs=elements,
|
||||
dim=0,
|
||||
init_clone2,
|
||||
elements_clone2,
|
||||
reverse=reverse,
|
||||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
if autograd:
|
||||
init_flatten, _ = pytree.tree_flatten(init)
|
||||
elements_flatten, _ = pytree.tree_flatten(elements)
|
||||
|
||||
result_flatten, _ = pytree.tree_flatten(result)
|
||||
result_exp_flatten, _ = pytree.tree_flatten(expected_result)
|
||||
|
||||
grad_out = [torch.ones_like(el) for el in result_exp_flatten]
|
||||
expected_grads = torch.autograd.grad(
|
||||
result_exp_flatten, (*init_flatten, *elements_flatten), grad_out
|
||||
result_exp_flatten, (*init_clone2, *elements_clone2), grad_out
|
||||
)
|
||||
grads = torch.autograd.grad(
|
||||
result_flatten, (*init_flatten, *elements_flatten), grad_out
|
||||
result_flatten, (*init_clone, *elements_clone), grad_out
|
||||
)
|
||||
self.assertEqual(grads, expected_grads)
|
||||
|
||||
|
|
@@ -2757,9 +2757,7 @@ class GraphModule(torch.nn.Module):
|
|||
l_init_1_ = L_init_1_
|
||||
l_xs_ = L_xs_
|
||||
|
||||
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
||||
|
||||
flip: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
|
||||
flip: "f32[3, 10, 2]" = torch.flip(l_xs_, [0]); l_xs_ = None
|
||||
|
||||
scan_combine_fn_0 = self.scan_combine_fn_0
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_0_, l_init_1_], [flip], []); scan_combine_fn_0 = l_init_0_ = l_init_1_ = flip = None
|
||||
|
|
@@ -3457,8 +3455,7 @@ class GraphModule(torch.nn.Module):
|
|||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, fct_1, init_1, xs_1):
|
||||
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
|
||||
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
|
||||
flip = torch.ops.aten.flip.default(xs_1, [0])
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
|
||||
|
|
@@ -3482,8 +3479,7 @@ def forward(self, fct_1, init_1, xs_1):
|
|||
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
|
||||
l_init_ = L_init_
|
||||
l_xs_ = L_xs_
|
||||
elem = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
||||
flip = torch.flip(elem, [0]); elem = None
|
||||
flip = torch.flip(l_xs_, [0]); l_xs_ = None
|
||||
scan_combine_fn_0 = self.scan_combine_fn_0
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [flip], []); scan_combine_fn_0 = l_init_ = flip = None
|
||||
carry = scan[0]
|
||||
|
|
@@ -8069,9 +8065,8 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
|
|||
l_xs_ = L_xs_
|
||||
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
||||
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
||||
r = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
||||
scan_combine_fn_0 = self.scan_combine_fn_0
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
||||
carry = scan[0]
|
||||
out = scan[1]; scan = None
|
||||
return (carry, out)""", # noqa: B950
|
||||
|
|
@@ -8085,9 +8080,8 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
|
|||
l_xs_ = L_xs_
|
||||
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
||||
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
||||
movedim = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
||||
scan_combine_fn_0 = self.scan_combine_fn_0
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [movedim], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = movedim = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
||||
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
||||
carry = scan[0]
|
||||
out = scan[1]; scan = None
|
||||
return (carry, out)""", # noqa: B950
|
||||
|
|
|
|||
|
|
@@ -1561,6 +1561,8 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
|||
for user in node.users:
|
||||
if (
|
||||
must_recompute(user)
|
||||
and "ac_graph_id" in user.meta
|
||||
and "ac_graph_id" in node.meta
|
||||
and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
|
||||
):
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
|
|
|
|||
364  torch/_higher_order_ops/partitioner.py  Normal file
|
|
@@ -0,0 +1,364 @@
|
|||
import logging
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.utils import create_bw_fn, materialize_as_graph
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _find_hop_subgraph_outputs(gm: torch.fx.GraphModule) -> tuple[torch.fx.Node]:
|
||||
output_node_args = gm.graph.find_nodes(op="output")[0].args
|
||||
assert isinstance(output_node_args, tuple)
|
||||
return output_node_args[0]
|
||||
|
||||
|
||||
def is_complex_expr(expr: Any) -> bool:
|
||||
return not expr.is_symbol and not expr.is_constant()
|
||||
|
||||
|
||||
class HopPartitionedGraph:
|
||||
def __init__(
|
||||
self,
|
||||
fw_gm: torch.fx.GraphModule,
|
||||
bw_gm: torch.fx.GraphModule,
|
||||
n_fw_outputs: int,
|
||||
n_intermediates: int,
|
||||
no_complex_exprs_at_boundary: bool,
|
||||
):
|
||||
self.fw_gm = fw_gm
|
||||
self.bw_gm = bw_gm
|
||||
self.n_fw_outputs = n_fw_outputs
|
||||
self.n_intermediates = n_intermediates
|
||||
self.no_complex_exprs_at_boundary = no_complex_exprs_at_boundary
|
||||
self._reorder_fw_output()
|
||||
self._check_partition_boundary()
|
||||
|
||||
def _check_partition_boundary(self) -> None:
|
||||
"""check partitioned graph is in valid state."""
|
||||
invalid_reasons = []
|
||||
fw_outputs = _find_hop_subgraph_outputs(self.fw_gm)
|
||||
for i, out in enumerate(fw_outputs):
|
||||
if "val" not in out.meta:
|
||||
invalid_reasons.append(f"fw_gm output[{i}] doesn't have a 'val' meta.")
|
||||
elif not isinstance(out.meta["val"], (torch.SymInt, torch.Tensor)):
|
||||
invalid_reasons.append(
|
||||
f"fw_gm output[{i}] is of type {type(out.meta['val'])} but only SymInt or Tensor are allowed."
|
||||
)
|
||||
|
||||
elif (
|
||||
isinstance(out.meta["val"], torch.SymInt)
|
||||
and is_complex_expr(out.meta["val"].node.expr)
|
||||
and self.no_complex_exprs_at_boundary
|
||||
):
|
||||
invalid_reasons.append(
|
||||
f"fw_gm output[{i}] must be of type SymInt with basic symbols or "
|
||||
f"Tensor but got {type(out.meta['val'])} {out.meta['val']}"
|
||||
)
|
||||
|
||||
if len(fw_outputs) != self.n_fw_outputs + self.n_intermediates:
|
||||
invalid_reasons.append(
|
||||
f"len(fw_outputs) ({len(fw_outputs)}) != n_fw_outputs ({self.n_fw_outputs}) + n_intermediates ({self.n_intermediates})" # noqa: B950
|
||||
)
|
||||
|
||||
bw_phs = list(self.bw_gm.graph.find_nodes(op="placeholder"))
|
||||
|
||||
if len(fw_outputs) != len(bw_phs):
|
||||
invalid_reasons.append(
|
||||
f"Expect number of fw_gm's output to be the same as bw_gm's input but "
|
||||
f"fw_gm has {len(fw_outputs)} outputs, bw_gm takes {len(bw_phs)} inputs."
|
||||
)
|
||||
|
||||
original_forward_outputs = fw_outputs[: self.n_fw_outputs]
|
||||
fw_intermediates = fw_outputs[self.n_fw_outputs :]
|
||||
|
||||
bw_intermediates = bw_phs[: -self.n_fw_outputs]
|
||||
bw_grads = bw_phs[-self.n_fw_outputs :]
|
||||
|
||||
def _match_size_or_expr(
|
||||
val1: Union[torch.SymInt, torch.Tensor],
|
||||
val2: Union[torch.SymInt, torch.Tensor],
|
||||
) -> bool:
|
||||
if type(val1) != type(val2):
|
||||
return False
|
||||
|
||||
if isinstance(val1, torch.SymInt) and isinstance(val2, torch.SymInt):
|
||||
return val1.node.expr == val2.node.expr
|
||||
elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
|
||||
return val1.size() == val2.size()
|
||||
|
||||
return False
|
||||
|
||||
for fw, bw in zip(fw_intermediates, bw_intermediates):
|
||||
if fw.name != bw.name or not _match_size_or_expr(
|
||||
fw.meta["val"], bw.meta["val"]
|
||||
):
|
||||
invalid_reasons.append("fw intermediates don't match bw intermediates")
|
||||
|
||||
for fw_out, bw_grad in zip(original_forward_outputs, bw_grads):
|
||||
if not _match_size_or_expr(fw_out.meta["val"], bw_grad.meta["val"]):
|
||||
invalid_reasons.append("fw outputs don't match bw gradients")
|
||||
|
||||
if len(invalid_reasons) > 0:
|
||||
newline = "\n"
|
||||
raise RuntimeError(
|
||||
"Invalid HopPartitionedGraph. Reasons:\n",
|
||||
f"{newline.join(invalid_reasons)}",
|
||||
)
|
||||
|
||||
def _reorder_fw_output(self) -> None:
|
||||
"""
|
||||
Before the pass, fw_gm returns (*fw_outputs, *intermediates1)
|
||||
and bw_gm takes (*intermediates2, *grad_fw_outputs) as input.
|
||||
intermediates1 and intermediates2 share the same node names but
|
||||
they might be in a different order, e.g. when some inputs contain SymInts.
|
||||
|
||||
To simplify downstream processing, this graph pass normalizes the output of fw_gm
|
||||
to be consistent with the backward inputs:
|
||||
|
||||
fw_gm:
|
||||
- input: fw_args
|
||||
- output: (*fw_outputs, *intermediates)
|
||||
|
||||
bw_gm:
|
||||
- input: (*intermediates, *grad_fw_outputs)
|
||||
- output: grad_fw_args
|
||||
|
||||
Example:
|
||||
|
||||
def fw_gm(x, y, z):
|
||||
a, b, c = f(x), g(y), k(z)
|
||||
return a, b, c, f_tmp, g_tmp, k_tmp
|
||||
|
||||
, where a, b, c are fw_outputs, f_tmp, g_tmp, k_tmp are intermediates
|
||||
|
||||
The corresponding bw_gm has the following signature:
|
||||
|
||||
def bw_gm(f_tmp, g_tmp, k_tmp, grad_a, grad_b, grad_c):
|
||||
return grad_x, grad_y, grad_z
|
||||
"""
|
||||
fw_gm_output_nodes = _find_hop_subgraph_outputs(self.fw_gm)
|
||||
fw_outputs_nodes = fw_gm_output_nodes[: self.n_fw_outputs]
|
||||
fw_intermediates_nodes = fw_gm_output_nodes[self.n_fw_outputs :]
|
||||
if len(fw_intermediates_nodes) > 0:
|
||||
fw_intermediates_name_to_node = {n.name: n for n in fw_intermediates_nodes}
|
||||
|
||||
# First n_intermediates placeholders
|
||||
bw_names: list[str] = [
|
||||
ph.name
|
||||
for ph in list(self.bw_gm.graph.find_nodes(op="placeholder"))[
|
||||
: self.n_intermediates
|
||||
]
|
||||
]
|
||||
new_fw_outputs = list(fw_outputs_nodes) + [
|
||||
fw_intermediates_name_to_node[name] for name in bw_names
|
||||
]
|
||||
|
||||
output_node = self.fw_gm.graph.find_nodes(op="output")[0]
|
||||
output_node.args = (tuple(new_fw_outputs),)
|
||||
|
||||
self.fw_gm.graph.lint()
|
||||
self.fw_gm.recompile()
|
||||
|
||||
|
||||
class HopJointGraph:
|
||||
def __init__(
|
||||
self,
|
||||
joint_gm: torch.fx.GraphModule,
|
||||
n_primals: int,
|
||||
n_fw_outputs: int,
|
||||
*,
|
||||
functionalized: bool,
|
||||
):
|
||||
self.joint_gm = joint_gm
|
||||
self.n_primals = n_primals
|
||||
self.n_fw_outputs = n_fw_outputs
|
||||
self.functionalized = functionalized
|
||||
|
||||
self._rename_phs()
|
||||
self._remove_redundant_sym_size_ops()
|
||||
|
||||
def _rename_phs(self) -> None:
|
||||
"""
|
||||
Rename the placeholders for joint_gm so that the partitioner
|
||||
can recognize which inputs are primals and which are tangents.
|
||||
"""
|
||||
self.n_tangents = 0
|
||||
for i, ph in enumerate(self.joint_gm.graph.find_nodes(op="placeholder")):
|
||||
if i < self.n_primals:
|
||||
ph.target = f"primals_{i}"
|
||||
ph.name = f"primals_{i}"
|
||||
else:
|
||||
self.n_tangents += 1
|
||||
ph.target = f"tangents_{i - self.n_primals}"
|
||||
ph.name = f"tangents_{i - self.n_primals}"
|
||||
|
||||
self.joint_gm.graph.lint()
|
||||
self.joint_gm.compile()
|
||||
|
||||
def _remove_redundant_sym_size_ops(self) -> None:
|
||||
"""
|
||||
Deletes torch.ops.aten.sym_size.int operators whose output symbol is already
carried by a placeholder, and replaces all uses of the sym_size node with
that placeholder directly.
|
||||
|
||||
This is to make sure all basic symbols come from inputs.
|
||||
"""
|
||||
placeholder_exprs = {}
|
||||
for node in self.joint_gm.graph.nodes:
|
||||
if (
|
||||
isinstance(node, torch.fx.Node)
|
||||
and node.op == "placeholder"
|
||||
and hasattr(node, "meta")
|
||||
and "val" in node.meta
|
||||
):
|
||||
val = node.meta["val"]
|
||||
if isinstance(val, torch.SymInt):
|
||||
placeholder_exprs[val.node.expr] = node
|
||||
|
||||
nodes_to_remove = []
|
||||
for node in self.joint_gm.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.sym_size.int
|
||||
):
|
||||
assert hasattr(node, "meta") and "val" in node.meta, node
|
||||
val = node.meta["val"]
|
||||
expr = val.node.expr
|
||||
if expr in placeholder_exprs:
|
||||
placeholder_node = placeholder_exprs[expr]
|
||||
node.replace_all_uses_with(placeholder_node)
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
for node in nodes_to_remove:
|
||||
self.joint_gm.graph.erase_node(node)
|
||||
|
||||
self.joint_gm.graph.lint()
|
||||
self.joint_gm.recompile()
|
||||
|
||||
def _mark_complex_exprs_as_must_recompute(self) -> None:
|
||||
"""
|
||||
For control-flow operators such as scan, we don't want SymInts at the
partitioning boundary, because we would otherwise have to support stacking
the SymInts across iterations, which adds unnecessary complexity.
|
||||
|
||||
By marking the recompute policy for complex nodes as MUST_RECOMPUTE, the partitioning boundary
|
||||
no longer contains complex expressions.
|
||||
|
||||
Note that this pass doesn't exclude basic symbols from the partitioning boundary;
it's up to downstream code to decide whether to return the basic symbols
or remove them with a separate graph pass.
|
||||
"""
|
||||
|
||||
from torch._functorch.partitioners import CheckpointPolicy
|
||||
|
||||
for n in (
|
||||
node for node in self.joint_gm.graph.nodes if node.op == "call_function"
|
||||
):
|
||||
if "val" not in n.meta:
|
||||
continue
|
||||
val = n.meta["val"]
|
||||
if isinstance(val, torch.SymInt) and is_complex_expr(val.node.expr):
|
||||
assert n.meta.get("recompute", None) is None
|
||||
|
||||
n.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
|
||||
|
||||
self.joint_gm.graph.lint()
|
||||
self.joint_gm.recompile()
|
||||
|
||||
def partition(
|
||||
self, partition_fn: Callable, always_recompute_complex_exprs: bool
|
||||
) -> HopPartitionedGraph:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"before min_cut_partition:\n%s",
|
||||
self.joint_gm.print_readable(print_output=False),
|
||||
)
|
||||
|
||||
if always_recompute_complex_exprs:
|
||||
self._mark_complex_exprs_as_must_recompute()
|
||||
|
||||
fw_gm, bw_gm = partition_fn(
|
||||
self.joint_gm, None, num_fwd_outputs=self.n_fw_outputs
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("after partition_fn:")
|
||||
logger.debug("fw_gm:\n%s", fw_gm.print_readable(print_output=False))
|
||||
logger.debug("bw_gm:\n%s", bw_gm.print_readable(print_output=False))
|
||||
|
||||
n_intermediates = len(_find_hop_subgraph_outputs(fw_gm)) - self.n_fw_outputs
|
||||
|
||||
return HopPartitionedGraph(
|
||||
fw_gm,
|
||||
bw_gm,
|
||||
self.n_fw_outputs,
|
||||
n_intermediates,
|
||||
always_recompute_complex_exprs,
|
||||
)
|
||||
|
||||
|
||||
def create_hop_joint_graph(
|
||||
fw_fn: Callable,
|
||||
fw_args: tuple[Union[torch.Tensor, torch.SymInt], ...],
|
||||
functionalize: bool,
|
||||
) -> HopJointGraph:
|
||||
fw_gm = materialize_as_graph(fw_fn, fw_args, force_enable_grad=True)
|
||||
fw_gm_output_nodes = _find_hop_subgraph_outputs(fw_gm)
|
||||
|
||||
assert all(
|
||||
isinstance(n, torch.fx.Node) and "val" in n.meta for n in fw_gm_output_nodes
|
||||
)
|
||||
fw_gm_output_vals = tuple(n.meta["val"] for n in fw_gm_output_nodes) # type: ignore[arg-type]
|
||||
|
||||
assert all(isinstance(val, torch.Tensor) for val in fw_gm_output_vals)
|
||||
example_grads = tuple(torch.zeros_like(val) for val in fw_gm_output_vals)
|
||||
|
||||
joint_fn = create_bw_fn(fw_fn, fw_args, return_fw_outputs=True)
|
||||
joint_gm = materialize_as_graph(
|
||||
joint_fn, fw_args + example_grads, force_enable_grad=True
|
||||
)
|
||||
if functionalize:
|
||||
# We need to trace out joint_fn with autograd information first and
# then functionalize the graph; otherwise the grad information is lost.
|
||||
joint_gm = materialize_as_graph(
|
||||
torch.func.functionalize(joint_gm, remove="mutations_and_views"),
|
||||
fw_args + example_grads,
|
||||
)
|
||||
|
||||
return HopJointGraph(
|
||||
joint_gm,
|
||||
len(fw_args),
|
||||
len(fw_gm_output_nodes),
|
||||
functionalized=functionalize,
|
||||
)
|
||||
|
||||
|
||||
class HopGraphMinCutPartitioner:
|
||||
@staticmethod
|
||||
def create_partitioned_graph(
|
||||
fw_fn: Callable,
|
||||
fw_args: tuple[Union[torch.Tensor, torch.SymInt], ...],
|
||||
*,
|
||||
always_recompute_complex_exprs: bool = False,
|
||||
) -> HopPartitionedGraph:
|
||||
"""
|
||||
Inputs:
|
||||
- fw_fn: the forward function that we'll use to create a joint graph and partition
|
||||
- fw_args: the flat_args to fw_fn
|
||||
- always_recompute_complex_exprs: when set to True, bw_gm will recompute
inputs that are complex expressions, so that the partitioning boundary
consists only of basic symbols and tensors.
|
||||
|
||||
Returns a HopPartitionedGraph
|
||||
"""
|
||||
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
||||
|
||||
joint_graph: HopJointGraph = create_hop_joint_graph(
|
||||
fw_fn, fw_args, functionalize=True
|
||||
)
|
||||
return joint_graph.partition(
|
||||
min_cut_rematerialization_partition, always_recompute_complex_exprs
|
||||
)
|
||||
|
|
@@ -1,26 +1,30 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.partitioner import (
|
||||
_find_hop_subgraph_outputs,
|
||||
HopGraphMinCutPartitioner,
|
||||
HopPartitionedGraph,
|
||||
)
|
||||
from torch._higher_order_ops.utils import (
|
||||
_maybe_compile_and_run_fn,
|
||||
check_input_alias_and_mutation_return_outputs,
|
||||
check_meta_consistency,
|
||||
create_bw_fn,
|
||||
fill_none_with_masks,
|
||||
filter_with_masks,
|
||||
first_slice_copy,
|
||||
first_slice_copy_with_grad,
|
||||
get_tensor_mask,
|
||||
mask_list,
|
||||
materialize_as_graph,
|
||||
reenter_make_fx,
|
||||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
split_into_chunks,
|
||||
unique_graph_id,
|
||||
validate_subgraph_args_types,
|
||||
|
|
@@ -35,6 +39,7 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
aten = torch._ops.ops.aten
|
||||
|
||||
|
||||
|
|
@@ -42,7 +47,7 @@ def wrap_combine_fn_flat(
|
|||
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
|
||||
):
|
||||
assert len(args) == (num_init_leaves + num_inp_leaves), (
|
||||
f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
|
||||
f"combine_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
|
||||
)
|
||||
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
|
||||
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
|
||||
|
|
@@ -182,7 +187,7 @@ def scan(
|
|||
# Move scan dim to 0 and always perform scan on dim 0
|
||||
leaves_xs = []
|
||||
for elem in leaves_xs_orig:
|
||||
leaves_xs.append(torch.movedim(elem, dim, 0))
|
||||
leaves_xs.append(torch.movedim(elem, dim, 0) if dim != 0 else elem)
|
||||
|
||||
if reverse:
|
||||
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
|
||||
|
|
@@ -429,121 +434,6 @@ def scan_op_dense(combine_fn, init, xs, additional_inputs):
|
|||
|
||||
class ScanAutogradOp(torch.autograd.Function):
|
||||
"""
|
||||
Example ::
|
||||
|
||||
def combine_fn(x: torch.Tensor, y: torch.Tensor):
|
||||
next_carry = y = x * y
|
||||
return next_carry, y
|
||||
|
||||
The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as:
|
||||
def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor):
|
||||
return g_y * y + g_carry * y, g_y * x + g_carry * x
|
||||
|
||||
Note: In a real use case of scan, there may be additional_inputs that participate in the
|
||||
forward as well as in the backward of the scan operator. For the sake of readability those inputs
|
||||
have been omitted in the following example, but are included in the subsequent detailed description below
|
||||
|
||||
The forward output of scan is computed as:
|
||||
carry, ys = scan(combine_fn, init, xs).
|
||||
|
||||
This computation can be unpacked as
|
||||
c_0, ys_0 = combine_fn(init, xs_0)
|
||||
c_1, ys_1 = combine_fn(carry_0, xs_1)
|
||||
c_2, ys_2 = combine_fn(carry_1, xs_2)
|
||||
...
|
||||
c_T, ys_T = combine_fn(carry_(T-1), xs_T)
|
||||
|
||||
We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward,
|
||||
but we only output (c_T, ys),
|
||||
where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T].
|
||||
|
||||
Given the carries and the ys, the gradients for xs and for init can be computed as follows:
|
||||
We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys,
|
||||
where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]
|
||||
|
||||
We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a
|
||||
scan operation reverse over time. For example,
|
||||
|
||||
g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T)
|
||||
g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1))
|
||||
g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2))
|
||||
...
|
||||
g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1)
|
||||
0 , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0),
|
||||
|
||||
where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t),
|
||||
the gradients of the carry of step t (i.e. g_c_t) and
|
||||
the upstream gradient of the output of step t (i.e. g_ys_T)
|
||||
and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1).
|
||||
|
||||
Through this procedure we end up with the
|
||||
gradients for the init -> g_init,
|
||||
the gradients for the xs -> g_xs.
|
||||
|
||||
|
||||
NOTE: [scan autograd implementation]
|
||||
|
||||
The forward of scan can be computed as:
|
||||
1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``:
|
||||
To use a scan operation for the backward path as well, we need access to the carries from all steps.
|
||||
Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry.
|
||||
In particular, we define ``combine_fn_with_carry_checkpoint``:
|
||||
def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor):
|
||||
carry, y = combine_fn(x, y)
|
||||
return carry, (carry, y)
|
||||
|
||||
The scan operator will stack all outputs along the scan dimension.
|
||||
Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``,
|
||||
the carries from all steps will be stacked, which gives us the checkpointed carries
|
||||
|
||||
2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``:
|
||||
c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs),
|
||||
Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``.
|
||||
However, carries are checkpointed carries from all steps.
|
||||
As a result of the forward, only the last carry c_T and the ys are returned,
|
||||
while all carries are saved for the backward.
|
||||
|
||||
The backward of scan can be computed as:
|
||||
|
||||
3.) Prepare the backward graph:
|
||||
We prepare the backward graph to be used in the backward function.
|
||||
We utilize ``create_bw_fn`` to generate the joint function, i.e.,
|
||||
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs]
|
||||
|
||||
The ctx._combine_fn_bw requires the primals (operands)
|
||||
followed by the tangents (upstream gradients) from a single step
|
||||
and produces the gradients of that step, i.e.,
|
||||
g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T).
|
||||
|
||||
4.) Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``:
|
||||
In the forward, there may be additional inputs that participate in every forward step.
|
||||
The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps,
|
||||
which is taken care of in this wrapper. For example:
|
||||
def combine_fn_bw_grad_accumulation(*args):
|
||||
carried_g_additional_input = args[:num_additional_inputs]
|
||||
inputs_bw_fn = args[num_additional_inputs:]
|
||||
g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn)
|
||||
new_g_additional_inputs = carried_g_additional_input + g_additional_input_t
|
||||
# The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
|
||||
# The ``g_xs_t`` is encoded as the output of the backward scan operator
|
||||
return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
|
||||
|
||||
5.) Perform the backward scan as
|
||||
g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where
|
||||
bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s):
|
||||
initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus:
|
||||
bwd_init = [*initial_g_additional_inputs, *g_c_T].
|
||||
|
||||
bw_xs consists of the combination of the upstream gradients g_ys,
|
||||
the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and
|
||||
the fw_xs. In particular,
|
||||
bwd_xs = [*g_ys, *bw_carries, *fw_xs].
|
||||
|
||||
Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input
|
||||
|
||||
As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init
|
||||
and the gradient for the xs -> g_xs
|
||||
|
||||
NOTE: [scan partial grad handling]
|
||||
If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients,
|
||||
i.e., requires_grad=False, there will be still gradients returned for those elements,
|
||||
|
|
@@ -559,300 +449,361 @@ class ScanAutogradOp(torch.autograd.Function):
|
|||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
combine_fn,
|
||||
num_leaves_init,
|
||||
num_leaves_xs,
|
||||
num_additional_inputs,
|
||||
hop_partitioned_graph,
|
||||
n_init,
|
||||
n_xs,
|
||||
n_additional_inputs,
|
||||
*operands,
|
||||
):
|
||||
ctx._num_leaves_init = num_leaves_init
|
||||
ctx._num_leaves_xs = num_leaves_xs
|
||||
ctx._num_additional_inputs = num_additional_inputs
|
||||
ctx._combine_fn = combine_fn
|
||||
init, xs, additional_inputs = split_into_chunks(
|
||||
operands, [num_leaves_init, num_leaves_xs, num_additional_inputs]
|
||||
operands, [n_init, n_xs, n_additional_inputs]
|
||||
)
|
||||
additional_inputs_tensor_mask = get_tensor_mask(additional_inputs)
|
||||
ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask
|
||||
|
||||
# We snapshot the dispatch keys in forward for materializing the
|
||||
# the bw_graph in backward.
|
||||
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
|
||||
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
|
||||
|
||||
# 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
|
||||
# The wrapper of the forward graph returns carries from all iterations,
|
||||
# not just from the last iteration. These are required in the backward path
|
||||
def combine_fn_with_carry_checkpoint(*args):
|
||||
carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
|
||||
return [
|
||||
*carry,
|
||||
# We additionally checkpoint all the intermediate carry outputs for backward.
|
||||
*[n_c.detach().clone() for n_c in carry],
|
||||
*y,
|
||||
]
|
||||
|
||||
# Materialize the ``combine_fn_with_carry_checkpoint`` with enable_grad
|
||||
# we need enable_grad to support torch.func.grad_and_value
|
||||
# in subgraph.
|
||||
gm = materialize_as_graph(
|
||||
combine_fn_with_carry_checkpoint,
|
||||
(*init, *[x[0] for x in xs], *additional_inputs),
|
||||
force_enable_grad=True,
|
||||
ctx._scan_impl = ScanAutogradImpl(
|
||||
hop_partitioned_graph, init, xs, additional_inputs
|
||||
)
|
||||
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
# 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``
|
||||
c_T, carries_ys = _extract_carry_and_out(
|
||||
scan_op(
|
||||
gm,
|
||||
init,
|
||||
xs,
|
||||
additional_inputs,
|
||||
),
|
||||
num_leaves_init,
|
||||
)
|
||||
|
||||
# Collect the carries for each time step from the outs
|
||||
# and save them for the backward path
|
||||
carries = list(carries_ys[:num_leaves_init])
|
||||
ys = list(carries_ys[num_leaves_init:])
|
||||
save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys)
|
||||
ctx._num_leaves_ys = len(ys)
|
||||
|
||||
return (*c_T, *ys)
|
||||
return ctx._scan_impl.call_forward()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *flat_grads):
|
||||
r"""
|
||||
This function computes the gradients of the scan operation.
|
||||
It does so by using a scan operator using all carries and the upstream gradients (see description above)
|
||||
|
||||
Args:
|
||||
flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
|
||||
"""
|
||||
from torch._higher_order_ops.utils import fill_none_with_masks
|
||||
|
||||
# Collect the saved items from the forward
|
||||
num_leaves_init = ctx._num_leaves_init
|
||||
num_leaves_xs = ctx._num_leaves_xs
|
||||
num_leaves_ys = ctx._num_leaves_ys
|
||||
num_additional_inputs = ctx._num_additional_inputs
|
||||
additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask
|
||||
|
||||
def prepend_init_to_carries(init, carries):
|
||||
# Prepare the carries for the backward path.
|
||||
# This requires to concatenate the init and the carries
|
||||
return [
|
||||
torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0)
|
||||
for i, c in zip(init, carries)
|
||||
]
|
||||
|
||||
def initialize_g_additional_inputs(
|
||||
additional_inputs,
|
||||
):
|
||||
g_additional_inputs = [
|
||||
torch.zeros_like(ai)
|
||||
for ai in filter_with_masks(
|
||||
additional_inputs, additional_inputs_tensor_mask
|
||||
)
|
||||
]
|
||||
return g_additional_inputs
|
||||
|
||||
# Retrieve the forward inputs and the forward outputs and dissect them
|
||||
flat_args = saved_tensors_and_symints(ctx)
|
||||
fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks(
|
||||
flat_args,
|
||||
[
|
||||
num_leaves_init,
|
||||
num_leaves_xs,
|
||||
num_additional_inputs,
|
||||
num_leaves_init,
|
||||
num_leaves_ys,
|
||||
],
|
||||
)
|
||||
|
||||
# 3.) Prepare the backward graph
|
||||
fw_operands = (
|
||||
*fw_init,
|
||||
*[first_slice_copy(xs) for xs in fw_xs],
|
||||
*additional_inputs,
|
||||
)
|
||||
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)
|
||||
|
||||
# Initialize the g_additional_inputs with zero-tensors.
|
||||
# This step is necessary because the gradients of the additional inputs are accumulated in the
|
||||
# ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
|
||||
initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
|
||||
|
||||
# 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
|
||||
def combine_fn_bw_grad_accumulation(*args):
|
||||
# Dissect args and re-order them for the ``ctx._combine_fn_bw``
|
||||
# The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g]
|
||||
# The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs]
|
||||
(
|
||||
carried_g_additional_input,
|
||||
combine_fn_bw_tangents,
|
||||
combine_fn_bw_primals,
|
||||
) = split_into_chunks(
|
||||
args,
|
||||
[
|
||||
len(initial_g_additional_inputs),
|
||||
num_leaves_init + num_leaves_ys,
|
||||
num_leaves_init + num_leaves_xs + num_additional_inputs,
|
||||
],
|
||||
)
|
||||
combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)
|
||||
|
||||
g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
|
||||
ctx._combine_fn_bw(*combine_fn_bw_args),
|
||||
[num_leaves_init, num_leaves_xs, num_additional_inputs],
|
||||
)
|
||||
|
||||
new_g_additional_inputs = [
|
||||
# If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
|
||||
carr_g + curr_g
|
||||
for carr_g, curr_g in zip(
|
||||
carried_g_additional_input,
|
||||
filter_with_masks(
|
||||
g_additional_inputs_t, additional_inputs_tensor_mask
|
||||
),
|
||||
)
|
||||
]
|
||||
assert all(isinstance(t, torch.Tensor) for t in new_g_additional_inputs)
|
||||
|
||||
# The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
|
||||
# The ``g_xs_t`` is encoded as the output of the backward scan operator
|
||||
return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
|
||||
|
||||
# Materialize the ``combine_fn_bw_grad_accumulation``
|
||||
def construct_args_single_step_bw():
|
||||
# This function constructs the arguments for a single step of the backward scan.
|
||||
# In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
|
||||
# The order of the arguments returned is identical to the order the backward scan
|
||||
# operations provides
|
||||
|
||||
# The following arguments are used for the backward part of the joint graph
|
||||
# The first argument relates to the gradient accumulation of the additional inputs.
|
||||
# Because only tensor elements of additional inputs can have requires_grad=True,
|
||||
# the values for non-tensor elements of additional inputs are None
|
||||
masked_additional_inputs = [
|
||||
a.clone()
|
||||
for a in filter_with_masks(
|
||||
additional_inputs, additional_inputs_tensor_mask
|
||||
)
|
||||
]
|
||||
|
||||
# The second argument relates to the gradients of the carries.
|
||||
# Because the arguments are for a single step only,
|
||||
# only the first slice of the carries is used.
|
||||
sliced_carries = [first_slice_copy(c) for c in fw_carries]
|
||||
|
||||
# The third argument relates to the gradients of the ys.
|
||||
# Because the arguments are for a single step only,
|
||||
# only the first slice of the ys is used.
|
||||
sliced_ys = [first_slice_copy(o) for o in fw_ys]
|
||||
|
||||
# The following arguments are used for the forward part of the joint graph
|
||||
# The fourth argument relates to the init for the forward.
|
||||
# I.e., fw_init
|
||||
|
||||
# The fifth argument relates to the xs for the forward.
|
||||
# Because the arguments are for a single step only,
|
||||
# only the first slice of the xs is used.
|
||||
# Note: It is important to preserve the requires_grad flag of xs
|
||||
# and thus we use the wrapper function ``first_slice_copy_with_grad``
|
||||
fw_xs_slice = first_slice_copy_with_grad(fw_xs)
|
||||
|
||||
# The last argument relates to the additional inputs for the forward.
|
||||
# I.e., additional_inputs
|
||||
|
||||
return (
|
||||
*masked_additional_inputs,
|
||||
*sliced_carries,
|
||||
*sliced_ys,
|
||||
*fw_init,
|
||||
*fw_xs_slice,
|
||||
*additional_inputs,
|
||||
)
|
||||
|
||||
args_single_step_bw = construct_args_single_step_bw()
|
||||
|
||||
# TODO: we need to materialize the bw graphs because dynamo is unable to
|
||||
# trace through the joint function when torch.compile torch.autograd.grad.
|
||||
combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
|
||||
combine_fn_bw_grad_accumulation,
|
||||
args_single_step_bw,
|
||||
ctx._fw_include_key_set,
|
||||
ctx._fw_exclude_key_set,
|
||||
force_enable_grad=True,
|
||||
)
|
||||
|
||||
# Decompose the flat_grads into g_c_T, g_ys
|
||||
g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
|
||||
assert all(
|
||||
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_c_T
|
||||
)
|
||||
assert all(
|
||||
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_ys
|
||||
)
|
||||
|
||||
# Prepend the inits to the carries.
|
||||
# This is needed, because when computing the gradients, the last carry is not needed
|
||||
# but the first carry, the init, is required.
|
||||
bw_carries = prepend_init_to_carries(fw_init, fw_carries)
|
||||
|
||||
# Prepare the xs for the backward scan.
|
||||
bwd_xs = [*g_ys, *bw_carries, *fw_xs]
|
||||
|
||||
# The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse
|
||||
bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs]
|
||||
|
||||
# Prepare the bwd_init
|
||||
bwd_init = [*initial_g_additional_inputs, *g_c_T]
|
||||
|
||||
# 5.) Perform the backward scan:
|
||||
# The ``combine_fn_bw_wrapped`` receives the
|
||||
# initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
|
||||
# gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
|
||||
gradients = scan_op(
|
||||
combine_fn_bw_grad_accumulation_gm,
|
||||
bwd_init,
|
||||
bwd_xs,
|
||||
additional_inputs,
|
||||
)
|
||||
|
||||
# Unpack the computed gradients
|
||||
g_additional_inputs, g_init, g_xs = split_into_chunks(
|
||||
gradients,
|
||||
[len(initial_g_additional_inputs), num_leaves_init, num_leaves_xs],
|
||||
)
|
||||
|
||||
# The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
|
||||
g_xs = [torch.flip(elem, [0]) for elem in g_xs]
|
||||
|
||||
def backward(ctx, *grad_fw_outputs):
|
||||
return (
|
||||
*[None] * 4,
|
||||
*g_init,
|
||||
*g_xs,
|
||||
*fill_none_with_masks(g_additional_inputs, additional_inputs_tensor_mask),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
*ctx._scan_impl.call_backward(*grad_fw_outputs),
|
||||
)
|
||||
|
||||
|
||||
class ScanForwardIntermediatesHandlingPolicy(enum.Enum):
|
||||
"""
|
||||
The partitioner can add intermediates to the output of the original graph.
These intermediates fall into 4 categories, and we handle each with a different policy by
modifying the graph:
|
||||
|
||||
CLONE: we clone the intermediate when it is a carried input (i.e. init). The carry is
overwritten with new values at each forward step, so we clone it as part of the return (i.e. ys)
to remove the aliasing, so that each step's intermediate is stacked and saved for backward.
|
||||
|
||||
REMOVE_XS: we remove the intermediate from output when it is part of xs. Since xs is read-only, in this case,
|
||||
we can directly save them for backward to use.
|
||||
|
||||
REMOVE_ADDITIONAL_INPUTS: we remove the intermediate from the output when it is part of additional_inputs. additional_inputs
are also read-only at each step, so we can directly save them for backward to use. We differentiate XS from ADDITIONAL_INPUTS
so that we can treat them differently in backward: xs intermediates are fed as part of the backward scan's xs, while
additional_inputs are passed as the backward scan's additional_inputs.
|
||||
|
||||
KEEP: this corresponds to the output of a real intermediate tensor operation. It varies at each forward step, so we just keep
it as part of ys. (A concrete mapping of the policies is sketched in the comment after the enum members.)
|
||||
|
||||
"""
|
||||
|
||||
KEEP = 0
|
||||
CLONE = 1
|
||||
REMOVE_XS = 2
|
||||
REMOVE_ADDITIONAL_INPUTS = 3
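# Illustrative mapping (hypothetical combine_fn, not from the PR): for
#     def combine_fn(c, x):   # c: carry, x: slice of xs, w: additional input
#         h = torch.tanh(c @ w)
#         return h, h + x
# a partition that saves (c, x, w, h) would yield the policies
# (CLONE, REMOVE_XS, REMOVE_ADDITIONAL_INPUTS, KEEP): c is the carry and must
# be cloned before being stacked into ys, x and w are read-only inputs that can
# be saved directly without flowing through the forward scan's outputs, and h
# is a genuine per-step intermediate that stays in ys.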
|
||||
|
||||
|
||||
class ScanAutogradImpl:
|
||||
"""
|
||||
Wraps the partitioned graph and encapsulates scan-specific implementation details
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hop_partitioned_graph: HopPartitionedGraph, init, xs, additional_inputs
|
||||
):
|
||||
self.hop_partitioned_graph = hop_partitioned_graph
|
||||
self.init = init
|
||||
self.xs = xs
|
||||
self.additional_inputs = additional_inputs
|
||||
self.forward_intermediates_handling_policies: list[
|
||||
ScanForwardIntermediatesHandlingPolicy
|
||||
] = []
|
||||
self.saved_fw_xs: list[Any] = []
|
||||
self.saved_fw_additional_inputs: list[Any] = []
|
||||
self.saved_intermediates: list[Any] = []
|
||||
self.fw_spec = pytree.tree_flatten((init, xs, additional_inputs))[1]
|
||||
self._optimize_forward_intermediates()
|
||||
|
||||
def _insert_clone(
|
||||
self, need_copy_node: torch.fx.Node, output_node: torch.fx.Node
|
||||
) -> torch.fx.Node:
|
||||
graph: torch.fx.Graph = output_node.graph
|
||||
with graph.inserting_before(output_node):
|
||||
clone_node = graph.call_function(
|
||||
torch.ops.aten.clone.default,
|
||||
args=(need_copy_node,),
|
||||
)
|
||||
clone_node.meta = (
|
||||
need_copy_node.meta.copy() if hasattr(need_copy_node, "meta") else {}
|
||||
)
|
||||
return clone_node
|
||||
|
||||
def _optimize_forward_intermediates(self):
|
||||
"""
|
||||
We optimize the forward intermediates by assigning each of them to a category
and constructing a ScanForwardIntermediatesHandlingPolicy for it.
|
||||
|
||||
"""
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"Need remove aliasing in fw_gm:\n%s",
|
||||
self.hop_partitioned_graph.fw_gm.print_readable(print_output=False),
|
||||
)
|
||||
|
||||
fw_gm = self.hop_partitioned_graph.fw_gm
|
||||
fw_all_outputs = _find_hop_subgraph_outputs(fw_gm)
|
||||
phs = list(fw_gm.graph.find_nodes(op="placeholder"))
|
||||
fw_outputs = fw_all_outputs[: self.hop_partitioned_graph.n_fw_outputs]
|
||||
fw_intermediates = fw_all_outputs[self.hop_partitioned_graph.n_fw_outputs :]
|
||||
|
||||
init_phs, xs_phs, additional_inputs_phs = pytree.tree_unflatten(
|
||||
phs, self.fw_spec
|
||||
)
|
||||
init_node_set, xs_node_set, addi_node_set = (
|
||||
set(init_phs),
|
||||
set(xs_phs),
|
||||
set(additional_inputs_phs),
|
||||
)
|
||||
|
||||
assert len(self.forward_intermediates_handling_policies) == 0
|
||||
assert len(self.saved_fw_xs) == 0
|
||||
assert len(self.saved_fw_additional_inputs) == 0
|
||||
intermediate_idx_to_ph_idx = {}
|
||||
ph_idx = {ph: i for i, ph in enumerate(phs)}
|
||||
for i, out in enumerate(fw_intermediates):
|
||||
if out in init_node_set:
|
||||
self.forward_intermediates_handling_policies.append(
|
||||
ScanForwardIntermediatesHandlingPolicy.CLONE
|
||||
)
|
||||
intermediate_idx_to_ph_idx[i] = ph_idx[out]
|
||||
elif out in xs_node_set:
|
||||
self.forward_intermediates_handling_policies.append(
|
||||
ScanForwardIntermediatesHandlingPolicy.REMOVE_XS
|
||||
)
|
||||
intermediate_idx_to_ph_idx[i] = ph_idx[out]
|
||||
elif out in addi_node_set:
|
||||
self.forward_intermediates_handling_policies.append(
|
||||
ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
|
||||
)
|
||||
intermediate_idx_to_ph_idx[i] = ph_idx[out]
|
||||
else:
|
||||
self.forward_intermediates_handling_policies.append(
|
||||
ScanForwardIntermediatesHandlingPolicy.KEEP
|
||||
)
|
||||
|
||||
new_output_node = []
|
||||
real_graph_inputs = (
|
||||
list(self.init) + list(self.xs) + list(self.additional_inputs)
|
||||
)
|
||||
fw_output_node = next(iter(fw_gm.graph.find_nodes(op="output")))
|
||||
for intermediate_idx, (node, policy) in enumerate(
|
||||
zip(fw_intermediates, self.forward_intermediates_handling_policies)
|
||||
):
|
||||
if policy == ScanForwardIntermediatesHandlingPolicy.CLONE:
|
||||
new_output_node.append(self._insert_clone(node, fw_output_node))
|
||||
elif policy == ScanForwardIntermediatesHandlingPolicy.REMOVE_XS:
|
||||
assert intermediate_idx in intermediate_idx_to_ph_idx
|
||||
inp_idx = intermediate_idx_to_ph_idx[intermediate_idx]
|
||||
self.saved_fw_xs.append(real_graph_inputs[inp_idx])
|
||||
elif (
|
||||
policy
|
||||
== ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
|
||||
):
|
||||
assert intermediate_idx in intermediate_idx_to_ph_idx
|
||||
inp_idx = intermediate_idx_to_ph_idx[intermediate_idx]
|
||||
self.saved_fw_additional_inputs.append(real_graph_inputs[inp_idx])
|
||||
else:
|
||||
new_output_node.append(node)
|
||||
|
||||
fw_output_node.args = (tuple(fw_outputs) + tuple(new_output_node),)
|
||||
fw_gm.graph.lint()
|
||||
fw_gm.recompile()
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"after removing aliasing:\n%s", fw_gm.print_readable(print_output=False)
|
||||
)
|
||||
|
||||
def call_forward(self):
|
||||
fw_outputs_and_intermediates: tuple[Any] = scan_op(
|
||||
self.hop_partitioned_graph.fw_gm, self.init, self.xs, self.additional_inputs
|
||||
) # type: ignore[return-type]
|
||||
fw_outs = fw_outputs_and_intermediates[
|
||||
: self.hop_partitioned_graph.n_fw_outputs
|
||||
]
|
||||
saved_intermediates = fw_outputs_and_intermediates[
|
||||
self.hop_partitioned_graph.n_fw_outputs :
|
||||
]
|
||||
assert len(self.saved_intermediates) == 0
|
||||
self.saved_intermediates.extend(saved_intermediates)
|
||||
return tuple(fw_outs)
|
||||
|
||||
def call_backward(self, *grad_fw_outputs):
|
||||
"""
|
||||
Recall that fw_outputs = (*carry, *ys), bw_gm takes in (*fw_intermediates, *grad_carry, *grad_ys)
|
||||
and returns (*grad_init, *grad_xs, *grad_additional_inputs)
|
||||
The backward is a reversed scan that can be constructed as follows:
|
||||
|
||||
grad_additional_inputs = torch.zeros_like(additional_inputs)
|
||||
bw_init = (grad_carry, grad_additional_inputs)
|
||||
bw_xs = (fw_intermediates, grad_ys)
|
||||
grad_init, grad_additional_inputs, grad_xs = scan(
|
||||
combine_fn,
|
||||
bw_init,
|
||||
bw_xs,
|
||||
reverse = True
|
||||
)
|
||||
, where combine_fn is defined as follows:
|
||||
|
||||
def combine_fn(bw_init, bw_xs):
|
||||
grad_carry, grad_additional_inputs = bw_init
|
||||
fw_intermediates, grad_y = bw_xs
|
||||
nxt_grad_carry, grad_x, nxt_grad_additional_inputs = bw_gm(*fw_intermediates, *grad_carry, *grad_y)
|
||||
return (nxt_grad_carry, grad_additional_inputs + nxt_grad_additional_inputs), grad_x
|
||||
|
||||
Note that grad_additional_inputs is accumulated with add, grad_carry is carried over to the next iteration, and
|
||||
grad_x is the ys output, which will be stacked together after the loop and will have the same shape as xs.
|
||||
"""
|
||||
fw_policy = self.forward_intermediates_handling_policies
|
||||
saved_intermediates = self.saved_intermediates
|
||||
saved_fw_xs = self.saved_fw_xs
|
||||
saved_fw_additional_inputs = self.saved_fw_additional_inputs
|
||||
|
||||
n_carry = len(self.init)
|
||||
|
||||
grad_carry, grad_ys = grad_fw_outputs[:n_carry], grad_fw_outputs[n_carry:]
|
||||
additional_inputs_tensor_masks = [
|
||||
True if isinstance(t, torch.Tensor) else False
|
||||
for t in self.additional_inputs
|
||||
]
|
||||
grad_additional_inputs = [
|
||||
torch.zeros_like(t)
|
||||
for t in filter_with_masks(
|
||||
self.additional_inputs, additional_inputs_tensor_masks
|
||||
)
|
||||
]
|
||||
|
||||
bw_init = [grad_carry, grad_additional_inputs]
|
||||
bw_xs = [
|
||||
grad_ys,
|
||||
saved_fw_xs,
|
||||
saved_intermediates,
|
||||
]
|
||||
bw_additional_inputs = saved_fw_additional_inputs
|
||||
|
||||
_, flat_spec = pytree.tree_flatten((bw_init, bw_xs, bw_additional_inputs))
|
||||
|
||||
grad_spec = None
|
||||
|
||||
def bw_single_step_wrapper(*args):
|
||||
bw_init, bw_xs, bw_additional_inputs = pytree.tree_unflatten(
|
||||
args, flat_spec
|
||||
)
|
||||
grad_carry, grad_additional_inputs = bw_init
|
||||
grad_y, saved_fw_xs, saved_intermediates = bw_xs
|
||||
saved_fw_additional_inputs = bw_additional_inputs
|
||||
|
||||
fw_intermediates = []
|
||||
xs_it = iter(saved_fw_xs)
|
||||
carry_it = iter(saved_intermediates)
|
||||
addi_it = iter(saved_fw_additional_inputs)
|
||||
for policy in fw_policy:
|
||||
if policy in (
|
||||
ScanForwardIntermediatesHandlingPolicy.CLONE,
|
||||
ScanForwardIntermediatesHandlingPolicy.KEEP,
|
||||
):
|
||||
fw_intermediates.append(next(carry_it))
|
||||
elif policy == ScanForwardIntermediatesHandlingPolicy.REMOVE_XS:
|
||||
fw_intermediates.append(next(xs_it))
|
||||
elif (
|
||||
policy
|
||||
== ScanForwardIntermediatesHandlingPolicy.REMOVE_ADDITIONAL_INPUTS
|
||||
):
|
||||
fw_intermediates.append(next(addi_it))
|
||||
else:
|
||||
raise RuntimeError(f"Unknown policy: {policy}")
|
||||
|
||||
grad_fw_outputs = (*grad_carry, *grad_y)
|
||||
|
||||
flat_out = self.hop_partitioned_graph.bw_gm(
|
||||
*fw_intermediates,
|
||||
*grad_fw_outputs,
|
||||
)
|
||||
|
||||
next_grad_carry, grad_xs, grad_addi = split_into_chunks(
|
||||
flat_out, # type: ignore[arg-type]
|
||||
[len(self.init), len(self.xs), len(self.additional_inputs)],
|
||||
)
|
||||
|
||||
nonlocal grad_spec
|
||||
flat_grads, grad_spec = pytree.tree_flatten(
|
||||
(
|
||||
next_grad_carry,
|
||||
[
|
||||
prev + cur
|
||||
for prev, cur in zip(
|
||||
grad_additional_inputs,
|
||||
filter_with_masks(
|
||||
grad_addi, additional_inputs_tensor_masks
|
||||
),
|
||||
)
|
||||
],
|
||||
grad_xs,
|
||||
)
|
||||
)
|
||||
return flat_grads
|
||||
|
||||
single_step_bw_xs = pytree.tree_map(lambda t: t[0], bw_xs)
|
||||
bw_single_step_gm = materialize_as_graph(
|
||||
bw_single_step_wrapper,
|
||||
tuple(
|
||||
pytree.tree_flatten((bw_init, single_step_bw_xs, bw_additional_inputs))[
|
||||
0
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
flat_grads = scan_op(
|
||||
bw_single_step_gm,
|
||||
pytree.tree_flatten(bw_init)[0],
|
||||
# TODO: torch.flip copies the tensor, we should optimize it away
|
||||
[torch.flip(x, (0,)) for x in pytree.tree_flatten(bw_xs)[0]],
|
||||
pytree.tree_flatten(bw_additional_inputs)[0],
|
||||
)
|
||||
assert grad_spec is not None
|
||||
grad_init, grad_additional_inputs, grad_xs = pytree.tree_unflatten(
|
||||
flat_grads, grad_spec
|
||||
)
|
||||
return (
|
||||
*grad_init,
|
||||
*[torch.flip(elem, (0,)) for elem in grad_xs],
|
||||
*fill_none_with_masks(
|
||||
grad_additional_inputs, additional_inputs_tensor_masks
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@scan_op.py_autograd_impl
|
||||
def scan_autograd(combine_fn, init, xs, additional_inputs):
|
||||
num_leaves_init = len(init)
|
||||
num_leaves_xs = len(xs)
|
||||
num_additional_inputs = len(additional_inputs)
|
||||
with disable_proxy_modes_tracing():
|
||||
hop_partitioned_graph: HopPartitionedGraph = (
|
||||
HopGraphMinCutPartitioner.create_partitioned_graph(
|
||||
combine_fn,
|
||||
(*init, *[x[0] for x in xs], *additional_inputs),
|
||||
always_recompute_complex_exprs=True,
|
||||
)
|
||||
)
|
||||
|
||||
flat_out = ScanAutogradOp.apply(
|
||||
combine_fn,
|
||||
num_leaves_init,
|
||||
num_leaves_xs,
|
||||
num_additional_inputs,
|
||||
*(tuple(init) + tuple(xs) + additional_inputs),
|
||||
return ScanAutogradOp.apply(
|
||||
hop_partitioned_graph,
|
||||
len(init),
|
||||
len(xs),
|
||||
len(additional_inputs),
|
||||
*init,
|
||||
*xs,
|
||||
*additional_inputs,
|
||||
)
|
||||
return *flat_out[:num_leaves_init], *flat_out[num_leaves_init:]
|
||||
|
||||
|
||||
@scan_op.py_impl(ProxyTorchDispatchMode)
|
||||
|
|
|
|||
|
|
@@ -505,9 +505,13 @@ def prepare_fw_with_masks_all_requires_grad(fn):
|
|||
lambda x: x.requires_grad_(True) if x.dtype.is_floating_point else x,
|
||||
fw_out,
|
||||
)
|
||||
return fw_out, pytree.tree_map_only(
|
||||
torch.Tensor, lambda x: x.requires_grad, fw_out
|
||||
)
|
||||
|
||||
def _query_requires_grad(t: torch.Tensor) -> bool:
|
||||
if torch._is_functional_tensor(t):
|
||||
t = torch._from_functional_tensor(t)
|
||||
return t.requires_grad
|
||||
|
||||
return fw_out, pytree.tree_map_only(torch.Tensor, _query_requires_grad, fw_out)
|
||||
|
||||
return fw_with_masks
|
||||
|
||||
|
|
@@ -759,7 +763,9 @@ def _clone_aliasing_output(inputs: Sequence[Any], outputs: Sequence[Any]):
|
|||
return final_outputs
|
||||
|
||||
|
||||
def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
|
||||
def create_bw_fn(
|
||||
fn: Callable, args: tuple[Any, ...], return_fw_outputs: bool = False
|
||||
) -> Callable:
|
||||
"""
|
||||
For a fn that accepts flat inputs and returns flat outputs:
|
||||
fw_out = fn(*args),
|
||||
|
|
@@ -793,7 +799,7 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
|
|||
def flat_fn(*args_and_grad_outs):
|
||||
primals = args_and_grad_outs[:n_primals]
|
||||
tangents = args_and_grad_outs[n_primals:]
|
||||
grad_args = bw_fn(primals, tangents)[1]
|
||||
fw_outs, grad_args = bw_fn(primals, tangents)
|
||||
assert len(args) == len(grad_args)
|
||||
|
||||
# For tensors whose grad is None, create zero tensors as gradients
|
||||
|
|
@@ -806,6 +812,8 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
|
|||
]
|
||||
|
||||
final_grads = _clone_aliasing_output(args_and_grad_outs, grad_args)
|
||||
if return_fw_outputs:
|
||||
return *fw_outs, *final_grads
|
||||
return final_grads
|
||||
|
||||
return flat_fn
|
||||
|
|
@@ -1111,7 +1119,7 @@ def call_op(op: Union[OpOverload, HopInstance], args, kwargs):
|
|||
|
||||
def materialize_as_graph(
|
||||
fn: Callable,
|
||||
args: tuple[Any],
|
||||
args: tuple[Any, ...],
|
||||
include_key_set: Optional[torch._C.DispatchKeySet] = None,
|
||||
exclude_key_set: Optional[torch._C.DispatchKeySet] = None,
|
||||
force_enable_grad=False,
|
||||
|
|
|
|||
|
|
@@ -809,6 +809,12 @@ def scatter_upon_const_tensor(
|
|||
"""
|
||||
from torch._inductor import metrics
|
||||
|
||||
# Check if inputs are tensors instead of inductor IR nodes
|
||||
if isinstance(selector, torch.Tensor):
|
||||
# Return a fake tensor with the proper shape that this operator is intended to return
|
||||
device = selector.device if hasattr(selector, "device") else torch.device("cpu")
|
||||
return torch.empty(shape, dtype=dtype, device=device)
|
||||
|
||||
metrics.num_matches_for_scatter_upon_const_tensor += 1
|
||||
|
||||
selector_loader = selector.make_loader()
|
||||
|
|