[dynamo] Provide helper functions for guard filter hook (#155083)

Collection of ready-made guard filters. One issue is that they are not composable - `filter1(filter2(guard))`. On the other hand, they are easy to use.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155083
Approved by: https://github.com/zhxchen17, https://github.com/jansel
This commit is contained in:
Animesh Jain 2025-06-15 07:37:14 -07:00 committed by PyTorch MergeBot
parent 0935a97d95
commit 54976bca10
3 changed files with 199 additions and 1 deletions

View File

@ -25,4 +25,8 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
is_compiling
is_dynamo_compiling
is_exporting
```
skip_guard_on_inbuilt_nn_modules_unsafe
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
```

View File

@ -12462,6 +12462,111 @@ fn
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(fn(*inputs), inputs[0])
def test_guard_filter_inbuilt_nn_modules(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.norm = torch.nn.LayerNorm(8)
def forward(self, x):
return self.norm(x)
mod = Mod()
opt_mod = torch.compile(
mod,
options={
"guard_filter_fn": torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe
},
)
x = torch.rand(4, 8)
opt_mod(x)
mod.norm.eps = 1e-02
# Since the guards are skipped on inbuilt nn modules, we should not recompile
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_mod(x)
def test_guard_filter_nn_modules(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 2
self.norm = torch.nn.LayerNorm(8)
def forward(self, x):
return self.norm(x) + self.c
mod = Mod()
opt_mod = torch.compile(
mod,
options={
"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe
},
)
x = torch.rand(4, 8)
opt_mod(x)
mod.c = 3
# Since the guards are skipped on all nn modules, we should not recompile
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_mod(x)
def test_guard_filter_tensors(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 2.0
self.norm = torch.nn.LayerNorm(8)
def forward(self, x):
return self.norm(x) + self.c
mod = Mod()
opt_mod = torch.compile(
mod,
options={
"guard_filter_fn": torch.compiler.keep_tensor_guards_unsafe,
},
)
x = torch.rand(4, 8)
opt_mod(x)
mod.c = 3.0
# Since the guards are skipped on all tensors
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_mod(x)
def test_guard_filter_globals(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 2
self.norm = torch.nn.LayerNorm(8)
def forward(self, x):
return self.norm(x) + self.c + GLOBAL_INT
mod = Mod()
opt_mod = torch.compile(
mod,
options={
"guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe,
},
)
global GLOBAL_INT
GLOBAL_INT = 1
x = torch.rand(4, 8)
opt_mod(x)
GLOBAL_INT = 2
# Since the guards are skipped on globals, we should not recompile
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_mod(x)
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -28,6 +28,10 @@ __all__ = [
"is_exporting",
"save_cache_artifacts",
"load_cache_artifacts",
"skip_guard_on_inbuilt_nn_modules_unsafe",
"skip_guard_on_all_nn_modules_unsafe",
"keep_tensor_guards_unsafe",
"skip_guard_on_globals_unsafe",
]
@ -477,3 +481,88 @@ def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
if artifacts is not None:
return CacheArtifactManager.populate_caches(artifacts)
return None
def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries):
"""
A common function to skip guards on the inbuilt nn modules like
torch.nn.Linear. This is unsafe to use by default. But for majority of
torch.compile users, the model code does not modify the inbuilt nn module
attributes. They can benefit from reduction in guard latency overhead using
this API.
To use this API, use guard_filter_fn argument while calling torch.compile
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},
>> )
"""
return [
not entry.orig_guard.source.is_unspecialized_builtin_nn_module()
for entry in guard_entries
]
def skip_guard_on_all_nn_modules_unsafe(guard_entries):
"""
A common function to skip guards on all nn modules, both user defined as
well inbuilt nn modules (like torch.nn.Linear). This is unsafe to use by
default. But for majority of torch.compile users, the model code does not
modify the nn module attributes. They can benefit from reduction in guard
latency overhead using this API.
To use this API, use guard_filter_fn argument while calling torch.compile
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},
>> )
"""
return [
not entry.orig_guard.source.is_unspecialized_nn_module()
for entry in guard_entries
]
def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False):
"""
A common function to keep tensor guards on all tensors. This is unsafe to
use by default. But if you don't expect any changes in the model code, you
can just keep the tensor guards.
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.keep_tensor_guards},
>> )
"""
keep_flags = []
for entry in guard_entries:
if entry.guard_type == "TENSOR_MATCH":
if not isinstance(entry.value, torch.nn.Parameter):
keep_flags.append(True)
elif keep_parameters:
keep_flags.append(True)
else:
keep_flags.append(False)
else:
keep_flags.append(False)
return keep_flags
def skip_guard_on_globals_unsafe(guard_entries):
"""
A common function to skip guards on all globals. This is unsafe to use by
default. But if you don't expect any changes in the globals, you can just
keep the tensor guards.
>> opt_mod = torch.compile(
>> mod,
>> options={"guard_filter_fn": torch.compiler.skip_guard_on_globals},
>> )
"""
return [not entry.is_global for entry in guard_entries]