[invoke_subgraph] Filter out grad_out where fw_out requires_grad is False (#150486)

Filter out the grad_outs/tangents whose corresponding fw_out has requires_grad=False, so the backward of InvokeSubgraphAutogradOp does not expect tangents for those outputs. I am not sure if this is the right way.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150486
Approved by: https://github.com/zou3519
ghstack dependencies: #150082, #150450
Author: Animesh Jain, 2025-04-01 21:57:25 -07:00 (committed by PyTorch MergeBot)
parent 82ceebce58
commit 42c7c7f15f
2 changed files with 57 additions and 29 deletions
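
When a forward output of the compiled region does not require grad, autograd supplies no tangent for it, so the traced backward graph must not expect a grad_out at that index. A minimal standalone sketch of the situation the new test below exercises (plain tensors only, names are illustrative; not part of the diff):

    import torch

    x = torch.randn(8, 8, requires_grad=False)
    w = torch.randn(8, 8, requires_grad=True)

    out0 = torch.cos(x)  # does not require grad, because x does not
    out1 = x * w         # requires grad through w
    print(out0.requires_grad, out1.requires_grad)  # False True

    # Any grad_out produced for out0 is either None or unused, so the joint
    # graph traced for the subgraph should only receive a tangent for out1.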


@@ -559,23 +559,27 @@ class GraphModule(torch.nn.Module):
         @mark_compile_region
         def gn(x):
-            return mod(x)
+            return torch.cos(x), mod(x)

         def fn(x):
-            return gn(x)
+            out = gn(x)
+            return out[0] + out[1]

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)

         # requires_grad is False deliberately to force None the joint_graph
         # outputs
         x = torch.randn(8, 8, requires_grad=False)
+        x_clone = x.detach().clone().requires_grad_(False)

-        ref = mod(x)
-        res = opt_fn(x)
-        self.assertEqual(ref, res)
+        ref = fn(x)
+        res = opt_fn(x_clone)
+        ref.sum().backward()
+        res.sum().backward()
+        self.assertEqual(ref, res)
+        self.assertEqual(x.grad, x_clone.grad)

     def test_fail_with_direct_invoke_subgraph(self):
         from torch._higher_order_ops import invoke_subgraph


@@ -2,6 +2,7 @@
 from contextlib import nullcontext
+from dataclasses import dataclass, field
 from typing import Optional, Union

 import torch
@@ -37,6 +38,15 @@ from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 invoke_subgraph_counter = 0


+# During the tracing of the joint graph, we construct this information. This is
+# used to filter out grad_outs/tangents in the `backward` method of
+# InvokeSubgraphAutogradOp.
+@dataclass
+class FilterTangentInfo:
+    indexes_with_none: set[int] = field(default_factory=set)
+    indexes_with_no_grad: set[int] = field(default_factory=set)
+
+
 class InvokeSubgraphHOP(HigherOrderOperator):
     def __init__(self) -> None:
         super().__init__("invoke_subgraph")
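
A small illustration (not part of the diff) of how FilterTangentInfo gets populated during tracing; collect_filter_info is a hypothetical helper, the real logic is inlined in create_fw_bw_graph below:

    import torch
    from dataclasses import dataclass, field

    @dataclass
    class FilterTangentInfo:
        indexes_with_none: set[int] = field(default_factory=set)
        indexes_with_no_grad: set[int] = field(default_factory=set)

    def collect_filter_info(fw_outs):
        # Record which forward outputs are None and which do not require grad,
        # mirroring the loop over fw_outs in create_fw_bw_graph.
        info = FilterTangentInfo()
        for idx, fw_out in enumerate(fw_outs):
            if fw_out is None:
                info.indexes_with_none.add(idx)
            elif not fw_out.requires_grad:
                info.indexes_with_no_grad.add(idx)
        return info

    fw_outs = [torch.randn(2, requires_grad=True), torch.randn(2), None]
    info = collect_filter_info(fw_outs)
    print(info.indexes_with_none, info.indexes_with_no_grad)  # {2} {1}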
@@ -189,24 +199,34 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
                 else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
             )

-            if grad_outputs is None:
-                # Infer grad_outputs to be the same properties as the fw_outputs
-                # if they're not passed in
-                with context:
-                    grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
+            with context:
+                fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))

-            num_fw_outs = len(grad_outputs)
+            num_fw_outs = len(fw_outs)

             # Collect the indexes of none in the output to check that the grad
             # is None at the corresponding index in the backward. This check is
             # performed in the autograd.Function - InvokeSubgraphAutogradOp.
-            none_indexes_in_fwd_out = set()
+            # Also collect the indexes of no_grad in the output to filter out
+            # the grad_outs in the `backward` method.
+            filter_tangent_info = FilterTangentInfo()

-            for idx, grad in enumerate(grad_outputs):
-                if grad is None:
-                    none_indexes_in_fwd_out.add(idx)
+            for idx, fw_out in enumerate(fw_outs):
+                if fw_out is None:
+                    filter_tangent_info.indexes_with_none.add(idx)
+                elif not fw_out.requires_grad:
+                    filter_tangent_info.indexes_with_no_grad.add(idx)

-            grad_outputs = [grad for grad in grad_outputs if grad is not None]
+            if grad_outputs is None:
+                # Infer grad_outputs to be the same properties as the fw_outputs
+                # if they're not passed in
+                # Although fw_outs are equivalent to grad_outputs for tracing
+                # purposes, we have to carefully handle the None and fw_out that do
+                # not have require_grad. At those indexes, we will have None in the
+                # backward graph.
+                grad_outputs = fw_outs
+            grad_outputs = [grad for grad in grad_outputs if grad is not None]
+            grad_outputs = [grad for grad in grad_outputs if grad.requires_grad]

             if any(
                 not isinstance(out, torch.Tensor)
@@ -227,7 +247,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
                 fw_inputs,
                 grad_outputs,
             )
-            return fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out
+            return fw_graph, bw_graph, num_fw_outs, filter_tangent_info


 class InvokeSubgraphAutogradOp(torch.autograd.Function):
@@ -243,14 +263,14 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         bw_graph,
         identifier,
         num_fw_outs,
-        none_indexes_in_fwd_out,
+        filter_tangent_info,
         *operands,
     ):
         ctx._fw_graph = fw_graph
         ctx._bw_graph = bw_graph
         ctx._identifier = identifier
         ctx._num_fw_outs = num_fw_outs
-        ctx._none_indexes_in_fwd_out = none_indexes_in_fwd_out
+        ctx._filter_tangent_info = filter_tangent_info

         with torch._C._AutoDispatchBelowAutograd():
             out = invoke_subgraph(
@@ -264,7 +284,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         # Check that None is at expected indexes.
         for idx, o in enumerate(out):
             if o is None:
-                assert idx in none_indexes_in_fwd_out
+                assert idx in filter_tangent_info.indexes_with_none

         return out

@@ -274,18 +294,22 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         identifier = ctx._identifier
         primals = saved_tensors_and_symints(ctx)
         num_fw_outs = ctx._num_fw_outs
-        none_indexes_in_fwd_out = ctx._none_indexes_in_fwd_out
+        filter_tangent_info = ctx._filter_tangent_info

         # While tracing we made the assumption that tangents are contiguous. So,
-        # force the grad_outs to be contiguous. Some of the grads can be None,
-        # because the forward outs could be None. Filter them out.
+        # force the grad_outs to be contiguous.
+        # Also filter out grads that are None or do not require_grad. This was
+        # the assumption we made during the tracing of joint_graph.
         contiguous_grad_outs = []
         for idx, o in enumerate(grad_outs):
-            if o is not None:
-                contiguous_grad_outs.append(o.contiguous())
+            if o is None:
+                assert idx in filter_tangent_info.indexes_with_none
+            elif idx in filter_tangent_info.indexes_with_no_grad:
+                # Deliberately skip over the grad_outs which we know should be
+                # None because the corresponding fwd_out does not require_grad.
+                pass
             else:
-                # Check that None is at expected indexes.
-                assert idx in none_indexes_in_fwd_out
+                contiguous_grad_outs.append(o.contiguous())
         contiguous_grad_outs = tuple(contiguous_grad_outs)

         # bw_graph is a joint graph with signature (*primals_and_tangents) and
@@ -331,13 +355,13 @@ def _(subgraph, identifier, operands):
     ):
         return saved_autograd_fn(*operands)

-    fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out = create_fw_bw_graph(
+    fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph(
         subgraph, operands
     )

     def autograd_fn_callable(*args):
         return InvokeSubgraphAutogradOp.apply(
-            fw_graph, bw_graph, identifier, num_fw_outs, none_indexes_in_fwd_out, *args
+            fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args
         )

     # Save the autograd_fn_callable in the dispatch set cache.
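
To summarize the backward-side change, a hypothetical standalone sketch (assuming the FilterTangentInfo illustration above; it mirrors the loop in InvokeSubgraphAutogradOp.backward rather than reproducing it):

    def filter_and_make_contiguous(grad_outs, filter_tangent_info):
        # Keep only the tangents whose forward output actually requires grad,
        # and force them to be contiguous, matching the contiguity assumption
        # made while tracing the joint graph.
        kept = []
        for idx, g in enumerate(grad_outs):
            if g is None:
                assert idx in filter_tangent_info.indexes_with_none
            elif idx in filter_tangent_info.indexes_with_no_grad:
                # The corresponding fw_out does not require grad; its tangent
                # is unused.
                pass
            else:
                kept.append(g.contiguous())
        return tuple(kept)

    # With the example info above (indexes_with_none={2}, indexes_with_no_grad={1}),
    # grad_outs = (torch.ones(2), torch.zeros(2), None) keeps only the first tangent.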