mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[invoke_subgraph][partitioner] Add meta val on run_and_save_rng ops (#157319)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157319 Approved by: https://github.com/zou3519
This commit is contained in:
parent
e5f6ffd810
commit
22edb457c9
|
|
@ -1786,6 +1786,43 @@ class GraphModule(torch.nn.Module):
|
|||
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@torch._inductor.config.patch(fallback_random=True)
def test_ac_rng(self):
    """Checkpointed dropout inside a nested compile region matches eager.

    The same checkpointed region is invoked twice. Outputs and gradients
    are compared against an eager run under the same manual seed, and the
    recording backend must capture exactly one Dynamo graph, one forward
    graph and one backward graph.
    """

    def dropout_cos(x):
        dropped = torch.nn.functional.dropout(x, p=0.5)
        return torch.cos(dropped)

    @nested_compile_region
    def checkpointed(x):
        return torch.utils.checkpoint.checkpoint(
            dropout_cos, x, use_reentrant=False
        )

    def twice(x):
        first = checkpointed(x)
        second = checkpointed(x)
        return first + second

    # Eager reference run.
    x = torch.randn(8, requires_grad=True)
    torch.manual_seed(0)
    expected = twice(x)
    expected.sum().backward()

    x_clone = x.clone().detach().requires_grad_(True)
    recorder = AotEagerAndRecordGraphs()

    # Re-seed so the compiled run draws the same random numbers.
    torch.manual_seed(0)
    compiled = torch.compile(twice, backend=recorder, fullgraph=True)
    actual = compiled(x_clone)
    actual.sum().backward()

    self.assertEqual(expected, actual)
    self.assertEqual(x.grad, x_clone.grad)

    # Check that the Dynamo and AOT graphs have just one subgraph module
    for captured in (recorder.graphs, recorder.fw_graphs, recorder.bw_graphs):
        self.assertEqual(len(captured), 1)

    # Same forward comparison through inductor, then exercise its backward.
    torch.manual_seed(0)
    compiled = torch.compile(twice, backend="inductor", fullgraph=True)
    actual = compiled(x_clone)
    self.assertEqual(expected, actual)
    actual.sum().backward()
|
||||
|
||||
def test_fake_tensor_checking(self):
|
||||
@nested_compile_region
|
||||
def gn(x):
|
||||
|
|
|
|||
|
|
@ -1317,9 +1317,14 @@ def functionalize_rng_ops(
|
|||
return torch.device("cpu")
|
||||
|
||||
def get_sample_rng_state(device: Optional[torch.device]):
|
||||
if device is not None and device.type == "cuda":
|
||||
return torch.cuda.get_rng_state()
|
||||
return torch.get_rng_state()
|
||||
from torch._guards import detect_fake_mode # noqa: F401
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
assert fake_mode is not None
|
||||
with fake_mode:
|
||||
if device is not None and device.type == "cuda":
|
||||
return fake_mode.from_tensor(torch.cuda.get_rng_state())
|
||||
return fake_mode.from_tensor(torch.get_rng_state())
|
||||
|
||||
# Step 1 - Construct a mapping from each rng node in the fwd to its counterpart in the bwd.
|
||||
joint_graph_rng_ops = get_rng_ops(joint_module)
|
||||
|
|
@ -1414,6 +1419,8 @@ def functionalize_rng_ops(
|
|||
args=(functional_fw_node, 0),
|
||||
kwargs={},
|
||||
)
|
||||
state.meta["val"] = get_sample_rng_state(device)
|
||||
|
||||
rng_output = fw_graph.create_node(
|
||||
"call_function",
|
||||
operator.getitem,
|
||||
|
|
@ -1423,6 +1430,9 @@ def functionalize_rng_ops(
|
|||
),
|
||||
kwargs={},
|
||||
)
|
||||
# Copy the metadata from the original node
|
||||
rng_output.meta = copy.copy(fw_node.meta)
|
||||
|
||||
fw_node.replace_all_uses_with(rng_output)
|
||||
fw_graph.erase_node(fw_node)
|
||||
fw_rng_state_outputs.append(state)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user