[invoke_subgraph] Filter out grad_out where fw_out requires_grad is False (#150486)
I am not sure if this is the right way.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150486
Approved by: https://github.com/zou3519
ghstack dependencies: #150082, #150450
parent 82ceebce58
commit 42c7c7f15f
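Note on the idea behind the change, before the diff: in the HOP's backward, grad_outs arrive with one slot per forward output, but only outputs that actually require grad get a tangent in the traced joint graph, so the other slots have to be skipped. The sketch below is not the invoke_subgraph code path, just a minimal eager-mode illustration (the class name TwoOuts and the math are made up) of how per-output grads can be missing in a custom autograd.Function's backward and must be filtered before use.

import torch

class TwoOuts(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Without this, unused grads are materialized as zero tensors instead of None.
        ctx.set_materialize_grads(False)
        ctx.save_for_backward(x)
        return x.cos(), x.sin()

    @staticmethod
    def backward(ctx, g_cos, g_sin):
        (x,) = ctx.saved_tensors
        # Only the output that fed the loss carries a real tangent; the other slot is None.
        grad = torch.zeros_like(x)
        if g_cos is not None:
            grad = grad - g_cos * x.sin()
        if g_sin is not None:
            grad = grad + g_sin * x.cos()
        return grad

x = torch.randn(4, requires_grad=True)
_, b = TwoOuts.apply(x)
b.sum().backward()  # g_cos arrives as None in backward; x.grad == cos(x)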
@@ -559,23 +559,27 @@ class GraphModule(torch.nn.Module):

         @mark_compile_region
         def gn(x):
-            return mod(x)
+            return torch.cos(x), mod(x)

         def fn(x):
-            return gn(x)
+            out = gn(x)
+            return out[0] + out[1]

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        # requires_grad is False deliberately to force None in the joint_graph
+        # outputs
         x = torch.randn(8, 8, requires_grad=False)
+        x_clone = x.detach().clone().requires_grad_(False)

-        ref = mod(x)
-        res = opt_fn(x)
-        self.assertEqual(ref, res)
+        ref = fn(x)
+        res = opt_fn(x_clone)
+
+        ref.sum().backward()
+        res.sum().backward()
+
+        self.assertEqual(ref, res)
+        self.assertEqual(x.grad, x_clone.grad)

     def test_fail_with_direct_invoke_subgraph(self):
         from torch._higher_order_ops import invoke_subgraph
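Note on the test hunk above: the property it exercises is visible in plain eager mode. In the sketch below, torch.nn.Linear is only a hypothetical stand-in for the mod used by the test. Because the input does not require grad, the cos branch has requires_grad False and never gets a tangent, while the module branch still requires grad through its parameters.

import torch

mod = torch.nn.Linear(8, 8)  # hypothetical stand-in for the module used by the test
x = torch.randn(8, 8, requires_grad=False)

a = torch.cos(x)  # requires_grad is False: no tangent will exist for this output
b = mod(x)        # requires_grad is True via the Linear's parameters
print(a.requires_grad, b.requires_grad)  # False True

(a + b).sum().backward()
print(mod.weight.grad is not None)  # True: grads still flow through the module branch
print(x.grad)                       # None: the input never required grad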
@@ -2,6 +2,7 @@
 from contextlib import nullcontext
+from dataclasses import dataclass, field
 from typing import Optional, Union

 import torch
@@ -37,6 +38,15 @@ from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 invoke_subgraph_counter = 0


+# During the tracing of the joint graph, we construct this information. This is
+# used to filter out grad_outs/tangents in the `backward` method of
+# InvokeSubgraphAutogradOp.
+@dataclass
+class FilterTangentInfo:
+    indexes_with_none: set[int] = field(default_factory=set)
+    indexes_with_no_grad: set[int] = field(default_factory=set)
+
+
 class InvokeSubgraphHOP(HigherOrderOperator):
     def __init__(self) -> None:
         super().__init__("invoke_subgraph")
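A small side note on the new dataclass, and why `field` was added to the import above: dataclasses reject a bare mutable default such as `= set()`, so each index set uses default_factory to give every FilterTangentInfo instance its own sets. A minimal sketch with a hypothetical Info class:

from dataclasses import dataclass, field

@dataclass
class Info:
    # `idxs: set[int] = set()` would raise ValueError (mutable default);
    # default_factory gives each instance its own set.
    idxs: set[int] = field(default_factory=set)

a, b = Info(), Info()
a.idxs.add(0)
print(b.idxs)  # set(): instances do not share state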
@@ -189,24 +199,34 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
                 else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
             )

-            if grad_outputs is None:
-                # Infer grad_outputs to be the same properties as the fw_outputs
-                # if they're not passed in
-                with context:
-                    grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
+            with context:
+                fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))

-            num_fw_outs = len(grad_outputs)
+            num_fw_outs = len(fw_outs)

             # Collect the indexes of none in the output to check that the grad
             # is None at the corresponding index in the backward. This check is
             # performed in the autograd.Function - InvokeSubgraphAutogradOp.
-            none_indexes_in_fwd_out = set()
+            # Also collect the indexes of no_grad in the output to filter out
+            # the grad_outs in the `backward` method.
+            filter_tangent_info = FilterTangentInfo()

-            for idx, grad in enumerate(grad_outputs):
-                if grad is None:
-                    none_indexes_in_fwd_out.add(idx)
+            for idx, fw_out in enumerate(fw_outs):
+                if fw_out is None:
+                    filter_tangent_info.indexes_with_none.add(idx)
+                elif not fw_out.requires_grad:
+                    filter_tangent_info.indexes_with_no_grad.add(idx)

-            grad_outputs = [grad for grad in grad_outputs if grad is not None]
+            if grad_outputs is None:
+                # Infer grad_outputs to be the same properties as the fw_outputs
+                # if they're not passed in
+                # Although fw_outs are equivalent to grad_outputs for tracing
+                # purposes, we have to carefully handle the None and the fw_outs
+                # that do not require grad. At those indexes, we will have None
+                # in the backward graph.
+                grad_outputs = fw_outs
+            grad_outputs = [grad for grad in grad_outputs if grad is not None]
+            grad_outputs = [grad for grad in grad_outputs if grad.requires_grad]

             if any(
                 not isinstance(out, torch.Tensor)
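Note on the hunk above: before tracing the joint graph it does two things, it records which forward-output indexes are None or do not require grad, and it keeps only the remaining outputs as stand-in tangents. Below is a standalone distillation of that bookkeeping with hypothetical names (TangentFilter, split_tangents are not the real helpers in the file), for readers who want to poke at the logic outside the HOP machinery.

from dataclasses import dataclass, field

import torch

@dataclass
class TangentFilter:  # mirrors FilterTangentInfo, but a local toy
    indexes_with_none: set[int] = field(default_factory=set)
    indexes_with_no_grad: set[int] = field(default_factory=set)

def split_tangents(fw_outs):
    # Keep only outputs that can carry a tangent; remember why the rest were dropped.
    info = TangentFilter()
    tangents = []
    for idx, out in enumerate(fw_outs):
        if out is None:
            info.indexes_with_none.add(idx)
        elif not out.requires_grad:
            info.indexes_with_no_grad.add(idx)
        else:
            tangents.append(out)
    return tangents, info

outs = [torch.randn(2, requires_grad=True), None, torch.randn(2)]
tangents, info = split_tangents(outs)
print(len(tangents), info.indexes_with_none, info.indexes_with_no_grad)  # 1 {1} {2}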
@@ -227,7 +247,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
                 fw_inputs,
                 grad_outputs,
             )
-            return fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out
+            return fw_graph, bw_graph, num_fw_outs, filter_tangent_info


 class InvokeSubgraphAutogradOp(torch.autograd.Function):
@@ -243,14 +263,14 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         bw_graph,
         identifier,
         num_fw_outs,
-        none_indexes_in_fwd_out,
+        filter_tangent_info,
         *operands,
     ):
         ctx._fw_graph = fw_graph
         ctx._bw_graph = bw_graph
         ctx._identifier = identifier
         ctx._num_fw_outs = num_fw_outs
-        ctx._none_indexes_in_fwd_out = none_indexes_in_fwd_out
+        ctx._filter_tangent_info = filter_tangent_info

         with torch._C._AutoDispatchBelowAutograd():
             out = invoke_subgraph(
@@ -264,7 +284,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         # Check that None is at expected indexes.
         for idx, o in enumerate(out):
             if o is None:
-                assert idx in none_indexes_in_fwd_out
+                assert idx in filter_tangent_info.indexes_with_none

         return out

@@ -274,18 +294,22 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         identifier = ctx._identifier
         primals = saved_tensors_and_symints(ctx)
         num_fw_outs = ctx._num_fw_outs
-        none_indexes_in_fwd_out = ctx._none_indexes_in_fwd_out
+        filter_tangent_info = ctx._filter_tangent_info

         # While tracing we made the assumption that tangents are contiguous. So,
-        # force the grad_outs to be contiguous. Some of the grads can be None,
-        # because the forward outs could be None. Filter them out.
+        # force the grad_outs to be contiguous.
+        # Also filter out grads that are None or do not require_grad. This was
+        # the assumption we made during the tracing of joint_graph.
         contiguous_grad_outs = []
         for idx, o in enumerate(grad_outs):
-            if o is not None:
-                contiguous_grad_outs.append(o.contiguous())
+            if o is None:
+                assert idx in filter_tangent_info.indexes_with_none
+            elif idx in filter_tangent_info.indexes_with_no_grad:
+                # Deliberately skip over the grad_outs which we know should be
+                # None because the corresponding fwd_out does not require_grad.
+                pass
             else:
-                # Check that None is at expected indexes.
-                assert idx in none_indexes_in_fwd_out
+                contiguous_grad_outs.append(o.contiguous())
         contiguous_grad_outs = tuple(contiguous_grad_outs)

         # bw_graph is a joint graph with signature (*primals_and_tangents) and
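Note on the backward hunk above: it is the runtime counterpart of the tracing-time bookkeeping. grad_outs arrive with one slot per forward output, but the traced joint graph only has tangent inputs for outputs that required grad, so the None and no-grad slots are dropped and the survivors made contiguous. A standalone sketch of that filtering with a hypothetical helper (filter_grad_outs is not the HOP's actual backward):

import torch

def filter_grad_outs(grad_outs, indexes_with_none, indexes_with_no_grad):
    kept = []
    for idx, g in enumerate(grad_outs):
        if g is None:
            # A None grad is only expected where the forward output itself was None.
            assert idx in indexes_with_none
        elif idx in indexes_with_no_grad:
            # The forward output did not require grad, so the joint graph has no slot for it.
            continue
        else:
            kept.append(g.contiguous())
    return tuple(kept)

grads = [torch.ones(2), None, torch.ones(2)]
print(len(filter_grad_outs(grads, {1}, {2})))  # 1: only the first grad feeds the bw graph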
@@ -331,13 +355,13 @@ def _(subgraph, identifier, operands):
     ):
         return saved_autograd_fn(*operands)

-    fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out = create_fw_bw_graph(
+    fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph(
         subgraph, operands
     )

     def autograd_fn_callable(*args):
         return InvokeSubgraphAutogradOp.apply(
-            fw_graph, bw_graph, identifier, num_fw_outs, none_indexes_in_fwd_out, *args
+            fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args
         )

     # Save the autograd_fn_callable in the dispatch set cache.