[Dynamo] Don't guard data ptrs by default with mark_static_address (#162208)

Fixes https://github.com/pytorch/pytorch/issues/156377

Since we now re-record cudagraphs when a static input's address changes, guarding on the data_ptr by default (which forces a full recompile) is no longer necessary.
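For illustration, a minimal sketch (not part of this PR) of what the new default means for callers; the compiled function and tensors below are hypothetical, and cudagraphs require a CUDA device with mode="reduce-overhead":

import torch
from torch._dynamo import mark_static_address

@torch.compile(mode="reduce-overhead")  # cudagraphs-enabled mode
def step(x):
    return x * 2

x = torch.ones(4, device="cuda")
mark_static_address(x)   # new default: guard=False, no data_ptr guard installed
step(x)                  # cudagraph recorded against x's current address
# If x's storage is later reallocated, cudagraphs re-records the graph
# instead of failing a guard and triggering a full dynamo recompile.

y = torch.ones(4, device="cuda")
mark_static_address(y, guard=True)  # opt back into the old behavior:
step(y)                             # a data_ptr change forces a full recompile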

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162208
Approved by: https://github.com/anijain2305
Author: Michael Lazos
Date: 2025-09-12 07:15:10 +00:00
Committed by: PyTorch MergeBot
Parent: 6b59a19242
Commit: 75de5b65b4
3 changed files with 11 additions and 10 deletions


@@ -899,7 +899,7 @@ class CompiledOptimizerTests(TestCase):
         compiled = torch.compile(_get_value)
         x = torch.ones(2, 2)
-        mark_static_address(x)
+        mark_static_address(x, guard=True)
         ret_val = compiled(x)


@@ -752,12 +752,13 @@ def mark_static(
 @forbid_in_graph
-def mark_static_address(t: Any, guard: bool = True) -> None:
+def mark_static_address(t: Any, guard: bool = False) -> None:
     """
-    Marks an input tensor whose data_ptr will not change across multiple calls
-    to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
-    is not needed for this input. The data_ptr will be guarded if guard=True. Note:
-    Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
+    Marks an input tensor whose address should be treated as constant across calls to the
+    same dynamo-compiled function. This indicates to cudagraphs that an extra allocation
+    is not needed for this input. The data_ptr will be guarded if guard=True, and cause a full
+    recompile if the data_ptr changes. Note: If this address changes, cudagraphs will re-record
+    if guard=False.
     """
     if not isinstance(t, torch.Tensor):
         raise TypeError(f"mark_static_address expects a tensor but received {type(t)}")


@@ -147,7 +147,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
         for group in self.value.param_groups:
             for p in group["params"]:
-                mark_static_address(p)
+                mark_static_address(p, guard=True)
         self._set_capturable(tx)
@@ -240,7 +240,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
         self.tensor_to_source = {}
         def mark_static(x):
-            mark_static_address(x)
+            mark_static_address(x, guard=True)
         tree_map_only(torch.Tensor, mark_static, self.value.state)
@@ -348,14 +348,14 @@ class OptimizerVariable(UserDefinedObjectVariable):
         if tensor_value in self.tensor_to_source:
             # mark these tensors as static for cudagraphs
-            mark_static_address(tensor_value)
+            mark_static_address(tensor_value, guard=True)
             source = self.tensor_to_source[tensor_value]
             self.static_tensor_names.add(tx.output.module_key_name(source.name()))
         elif tensor_value in self.grad_to_source:
             source = self.grad_to_source[tensor_value]
         else:
             # mark these tensors as static for cudagraphs
-            mark_static_address(tensor_value)
+            mark_static_address(tensor_value, guard=True)
             global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
             source = GlobalWeakRefSource(global_name)