[Dynamo] Don't guard data ptrs by default with mark_static_address (#162208)
Fixes https://github.com/pytorch/pytorch/issues/156377

Since cudagraphs are now re-recorded when a static input's address changes, it is no longer necessary to guard on the data_ptr by default and induce a full recompile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162208
Approved by: https://github.com/anijain2305
Commit: 75de5b65b4 (parent: 6b59a19242)
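For context, a minimal sketch of what the new default means in practice. This is illustrative only: `fn` is a hypothetical function, and the re-record behavior applies under cudagraphs (mode="reduce-overhead" on CUDA):

    import torch
    import torch._dynamo

    @torch.compile(mode="reduce-overhead")
    def fn(t):
        return t * 2

    x = torch.ones(2, 2, device="cuda")
    # New default guard=False: no guard is installed on x.data_ptr(), so if a
    # marked input later shows up at a different address, cudagraphs re-record
    # the graph instead of forcing a full recompile.
    torch._dynamo.mark_static_address(x)
    fn(x)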
@@ -899,7 +899,7 @@ class CompiledOptimizerTests(TestCase):
         compiled = torch.compile(_get_value)

         x = torch.ones(2, 2)
-        mark_static_address(x)
+        mark_static_address(x, guard=True)

         ret_val = compiled(x)
@@ -752,12 +752,13 @@ def mark_static(


 @forbid_in_graph
-def mark_static_address(t: Any, guard: bool = True) -> None:
+def mark_static_address(t: Any, guard: bool = False) -> None:
     """
-    Marks an input tensor whose data_ptr will not change across multiple calls
-    to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
-    is not needed for this input. The data_ptr will be guarded if guard=True. Note:
-    Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
+    Marks an input tensor whose address should be treated as constant across calls to the
+    same dynamo-compiled function. This indicates to cudagraphs that an extra allocation
+    is not needed for this input. The data_ptr will be guarded if guard=True, and cause a full
+    recompile if the data_ptr changes. Note: If this address changes, cudagraphs will re-record
+    if guard=False.
     """
     if not isinstance(t, torch.Tensor):
         raise TypeError(f"mark_static_address expects a tensor but received {type(t)}")
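A hedged usage sketch of the two modes described in the updated docstring (CPU tensors for brevity; the cudagraph effects themselves require CUDA):

    import torch
    import torch._dynamo

    t = torch.ones(4)
    # Explicit guard=True: Dynamo guards on t.data_ptr(), so an address change
    # at this input induces a full recompile (the pre-change default behavior).
    torch._dynamo.mark_static_address(t, guard=True)

    u = torch.ones(4)
    # Default is now guard=False: no data_ptr guard is installed.
    torch._dynamo.mark_static_address(u)

    # Non-tensor inputs are rejected up front, per the isinstance check above:
    try:
        torch._dynamo.mark_static_address([1, 2, 3])
    except TypeError as e:
        print(e)  # mark_static_address expects a tensor but received <class 'list'>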
@@ -147,7 +147,7 @@ class OptimizerVariable(UserDefinedObjectVariable):

         for group in self.value.param_groups:
             for p in group["params"]:
-                mark_static_address(p)
+                mark_static_address(p, guard=True)

         self._set_capturable(tx)
@@ -240,7 +240,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
         self.tensor_to_source = {}

         def mark_static(x):
-            mark_static_address(x)
+            mark_static_address(x, guard=True)

         tree_map_only(torch.Tensor, mark_static, self.value.state)
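The `tree_map_only` call above applies the marking to every tensor leaf of the optimizer's nested state. A self-contained sketch of that pattern, using a toy stand-in for `self.value.state` (the `state` dict here is hypothetical):

    import torch
    import torch._dynamo
    from torch.utils._pytree import tree_map_only

    # Toy stand-in for optimizer state: a nested container of tensors.
    state = {"exp_avg": torch.zeros(3), "step": torch.tensor(1.0)}

    # Called for its side effect on each tensor leaf; the mapped tree it
    # returns is discarded, mirroring the usage in the hunk above.
    tree_map_only(
        torch.Tensor,
        lambda x: torch._dynamo.mark_static_address(x, guard=True),
        state,
    )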
@@ -348,14 +348,14 @@ class OptimizerVariable(UserDefinedObjectVariable):

         if tensor_value in self.tensor_to_source:
             # mark these tensors as static for cudagraphs
-            mark_static_address(tensor_value)
+            mark_static_address(tensor_value, guard=True)
             source = self.tensor_to_source[tensor_value]
             self.static_tensor_names.add(tx.output.module_key_name(source.name()))
         elif tensor_value in self.grad_to_source:
             source = self.grad_to_source[tensor_value]
         else:
             # mark these tensors as static for cudagraphs
-            mark_static_address(tensor_value)
+            mark_static_address(tensor_value, guard=True)

             global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
             source = GlobalWeakRefSource(global_name)
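Taken together, the optimizer hunks pass guard=True explicitly, so compiled optimizers keep the recompile-on-address-change behavior for parameters and state tensors even though the global default is now guard=False. A minimal end-to-end sketch of the compiled-optimizer path that reaches this code (the model and optimizer are illustrative):

    import torch

    model = torch.nn.Linear(8, 8)
    opt = torch.optim.Adam(model.parameters())

    @torch.compile
    def step():
        opt.step()

    model(torch.randn(2, 8)).sum().backward()
    # Tracing opt.step() goes through OptimizerVariable, which marks the
    # parameter and state tensors static with an explicit guard=True.
    step()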