[invoke_subgraph][partitioner] Add meta val on run_and_save_rng ops (#157319)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157319
Approved by: https://github.com/zou3519
Animesh Jain 2025-07-01 10:07:26 -07:00 committed by PyTorch MergeBot
parent e5f6ffd810
commit 22edb457c9
2 changed files with 50 additions and 3 deletions
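
What the change does: when the partitioner functionalizes RNG ops for activation checkpointing (rewriting them through run_and_save_rng_state), the freshly created state and getitem nodes previously carried no meta["val"]. This PR fills that in, using a fake RNG-state tensor for the state output and a copy of the original node's meta for the op output, so later passes over invoke_subgraph/nested_compile_region graphs can keep assuming every call node has a meta["val"]. A hedged sketch of that invariant (hypothetical helper, not part of the PR):

import torch

def assert_call_nodes_have_val(gm: torch.fx.GraphModule) -> None:
    # Sketch only: every call_function node produced by AOTAutograd and the
    # partitioner is expected to carry a (fake) tensor under meta["val"].
    for node in gm.graph.nodes:
        if node.op == "call_function":
            assert "val" in node.meta, f"missing meta['val'] on {node.format_node()}"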

View File

@@ -1786,6 +1786,43 @@ class GraphModule(torch.nn.Module):
         res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone)
         self.assertEqual(ref, res)
 
+    @torch._inductor.config.patch(fallback_random=True)
+    def test_ac_rng(self):
+        def fn1(x):
+            return torch.cos(torch.nn.functional.dropout(x, p=0.5))
+
+        @nested_compile_region
+        def fn1_checkpoint(x):
+            return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False)
+
+        def fn(x):
+            return fn1_checkpoint(x) + fn1_checkpoint(x)
+
+        x = torch.randn(8, requires_grad=True)
+        torch.manual_seed(0)
+        ref = fn(x)
+        ref.sum().backward()
+
+        x_clone = x.clone().detach().requires_grad_(True)
+
+        backend = AotEagerAndRecordGraphs()
+        torch.manual_seed(0)
+        res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone)
+        res.sum().backward()
+
+        self.assertEqual(ref, res)
+        self.assertEqual(x.grad, x_clone.grad)
+
+        # Check that the Dynamo and AOT graphs have just one subgraph module
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        torch.manual_seed(0)
+        res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone)
+        self.assertEqual(ref, res)
+        res.sum().backward()
+
     def test_fake_tensor_checking(self):
         @nested_compile_region
         def gn(x):

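The new test exercises dropout under torch.utils.checkpoint inside a nested_compile_region, with fallback_random=True so Inductor falls back to ATen RNG and a manual seed makes the eager and compiled runs comparable. A standalone sketch of the same user-level pattern (assuming nested_compile_region is exposed as torch.compiler.nested_compile_region, as in the test file):

import torch
import torch._inductor.config
import torch.utils.checkpoint

nested_compile_region = torch.compiler.nested_compile_region  # assumed import

def dropout_cos(x):
    return torch.cos(torch.nn.functional.dropout(x, p=0.5))

@nested_compile_region
def checkpointed(x):
    # Recomputed in backward; the replay must reuse the saved RNG state.
    return torch.utils.checkpoint.checkpoint(dropout_cos, x, use_reentrant=False)

def fn(x):
    return checkpointed(x) + checkpointed(x)

x = torch.randn(8, requires_grad=True)
torch.manual_seed(0)
eager = fn(x)

torch.manual_seed(0)
with torch._inductor.config.patch(fallback_random=True):
    compiled = torch.compile(fn, fullgraph=True)
    result = compiled(x)

print(torch.allclose(eager, result))
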
View File

@@ -1317,9 +1317,14 @@ def functionalize_rng_ops(
         return torch.device("cpu")
 
     def get_sample_rng_state(device: Optional[torch.device]):
-        if device is not None and device.type == "cuda":
-            return torch.cuda.get_rng_state()
-        return torch.get_rng_state()
+        from torch._guards import detect_fake_mode  # noqa: F401
+
+        fake_mode = detect_fake_mode()
+        assert fake_mode is not None
+        with fake_mode:
+            if device is not None and device.type == "cuda":
+                return fake_mode.from_tensor(torch.cuda.get_rng_state())
+            return fake_mode.from_tensor(torch.get_rng_state())
 
     # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
     joint_graph_rng_ops = get_rng_ops(joint_module)
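
The old get_sample_rng_state returned a real RNG-state tensor; the new one wraps it under the ambient fake mode so the value placed in meta["val"] is a FakeTensor, like the vals on the rest of the traced graph. A minimal sketch of that wrapping pattern in isolation (using a fresh FakeTensorMode rather than the detected one):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch only: turn a real tensor into a FakeTensor so it can serve as a
# node's meta["val"] without holding on to real data.
fake_mode = FakeTensorMode()
real_state = torch.get_rng_state()  # real uint8 CPU tensor
with fake_mode:
    fake_state = fake_mode.from_tensor(real_state)

print(type(fake_state).__name__, fake_state.shape, fake_state.dtype)
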
@@ -1414,6 +1419,8 @@ def functionalize_rng_ops(
                     args=(functional_fw_node, 0),
                     kwargs={},
                 )
+                state.meta["val"] = get_sample_rng_state(device)
+
                 rng_output = fw_graph.create_node(
                     "call_function",
                     operator.getitem,
@@ -1423,6 +1430,9 @@
                     ),
                     kwargs={},
                 )
+                # Copy the meta data from the original node
+                rng_output.meta = copy.copy(fw_node.meta)
+
                 fw_node.replace_all_uses_with(rng_output)
                 fw_graph.erase_node(fw_node)
                 fw_rng_state_outputs.append(state)
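
Here the getitem node that takes the original RNG op's place inherits that node's meta via copy.copy, so meta["val"] and the other entries survive the rewrite, while the state output created above gets its meta["val"] from get_sample_rng_state. A small, hypothetical FX sketch of this replace-and-carry-meta pattern (functional_dropout is an illustrative stand-in, not the op the partitioner emits):

import copy
import operator

import torch
import torch.fx as fx

def functional_dropout(x):
    # Hypothetical tuple-returning stand-in: (rng state, op output).
    return torch.get_rng_state(), torch.dropout(x, 0.5, True)

def rewrite_dropout(gm: fx.GraphModule) -> None:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.dropout:
            with gm.graph.inserting_before(node):
                functional_node = gm.graph.call_function(
                    functional_dropout, args=(node.args[0],)
                )
                output = gm.graph.call_function(
                    operator.getitem, args=(functional_node, 1)
                )
            # Carry the replaced node's meta (including meta["val"]) onto the
            # getitem that stands in for it, then drop the old node.
            output.meta = copy.copy(node.meta)
            node.replace_all_uses_with(output)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()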