mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Provide helper functions for guard filter hook (#155083)
A collection of ready-made guard filters. One issue is that they are not composable (you cannot chain them as `filter1(filter2(guard))`); on the other hand, they are easy to use. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155083 Approved by: https://github.com/zhxchen17, https://github.com/jansel
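Since each helper simply maps a list of guard entries to a list of keep/drop booleans, a caller who needs composition today can AND the outputs of several filters instead of chaining them. A minimal sketch relying only on the `guard_filter_fn` contract used by the helpers in this PR; `combine_guard_filters` is a hypothetical wrapper, not part of the API:

```python
import torch


def combine_guard_filters(*filters):
    # Hypothetical combinator: keep a guard only if every filter keeps it.
    def combined(guard_entries):
        masks = [f(guard_entries) for f in filters]
        return [all(flags) for flags in zip(*masks)]

    return combined


mod = torch.nn.Sequential(torch.nn.LayerNorm(8))
opt_mod = torch.compile(
    mod,
    options={
        "guard_filter_fn": combine_guard_filters(
            torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe,
            torch.compiler.skip_guard_on_globals_unsafe,
        )
    },
)
opt_mod(torch.rand(4, 8))
```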
This commit is contained in: parent 0935a97d95, commit 54976bca10
@ -25,4 +25,8 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
    is_compiling
    is_dynamo_compiling
    is_exporting
```
    skip_guard_on_inbuilt_nn_modules_unsafe
    skip_guard_on_all_nn_modules_unsafe
    keep_tensor_guards_unsafe
    skip_guard_on_globals_unsafe
```
@ -12462,6 +12462,111 @@ fn
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs), inputs[0])

    def test_guard_filter_inbuilt_nn_modules(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.norm = torch.nn.LayerNorm(8)

            def forward(self, x):
                return self.norm(x)

        mod = Mod()
        opt_mod = torch.compile(
            mod,
            options={
                "guard_filter_fn": torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe
            },
        )

        x = torch.rand(4, 8)
        opt_mod(x)

        mod.norm.eps = 1e-02
        # Since the guards are skipped on inbuilt nn modules, we should not recompile
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            opt_mod(x)
    def test_guard_filter_nn_modules(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = 2
                self.norm = torch.nn.LayerNorm(8)

            def forward(self, x):
                return self.norm(x) + self.c

        mod = Mod()
        opt_mod = torch.compile(
            mod,
            options={
                "guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe
            },
        )

        x = torch.rand(4, 8)
        opt_mod(x)

        mod.c = 3
        # Since the guards are skipped on all nn modules, we should not recompile
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            opt_mod(x)
    def test_guard_filter_tensors(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = 2.0
                self.norm = torch.nn.LayerNorm(8)

            def forward(self, x):
                return self.norm(x) + self.c

        mod = Mod()
        opt_mod = torch.compile(
            mod,
            options={
                "guard_filter_fn": torch.compiler.keep_tensor_guards_unsafe,
            },
        )

        x = torch.rand(4, 8)
        opt_mod(x)

        mod.c = 3.0
        # Since only tensor guards are kept, changing the non-tensor attribute
        # mod.c should not trigger a recompile
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            opt_mod(x)
    def test_guard_filter_globals(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = 2
                self.norm = torch.nn.LayerNorm(8)

            def forward(self, x):
                return self.norm(x) + self.c + GLOBAL_INT

        mod = Mod()
        opt_mod = torch.compile(
            mod,
            options={
                "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe,
            },
        )

        global GLOBAL_INT
        GLOBAL_INT = 1
        x = torch.rand(4, 8)
        opt_mod(x)

        GLOBAL_INT = 2
        # Since the guards are skipped on globals, we should not recompile
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            opt_mod(x)


class TestTracer(JitTestCase):
    def test_jit_save(self):
@ -28,6 +28,10 @@ __all__ = [
    "is_exporting",
    "save_cache_artifacts",
    "load_cache_artifacts",
    "skip_guard_on_inbuilt_nn_modules_unsafe",
    "skip_guard_on_all_nn_modules_unsafe",
    "keep_tensor_guards_unsafe",
    "skip_guard_on_globals_unsafe",
]
@ -477,3 +481,88 @@ def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
    if artifacts is not None:
        return CacheArtifactManager.populate_caches(artifacts)
    return None

def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries):
    """
    A common function to skip guards on inbuilt nn modules like
    torch.nn.Linear. This is unsafe to use by default, but the majority of
    torch.compile users do not modify inbuilt nn module attributes in their
    model code, and they can reduce guard latency overhead with this API.

    To use this API, pass it as the guard_filter_fn argument to torch.compile:

    >> opt_mod = torch.compile(
    >>     mod,
    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe},
    >> )
    """
    return [
        not entry.orig_guard.source.is_unspecialized_builtin_nn_module()
        for entry in guard_entries
    ]

def skip_guard_on_all_nn_modules_unsafe(guard_entries):
    """
    A common function to skip guards on all nn modules, both user-defined and
    inbuilt (like torch.nn.Linear). This is unsafe to use by default, but the
    majority of torch.compile users do not modify nn module attributes in
    their model code, and they can reduce guard latency overhead with this
    API.

    To use this API, pass it as the guard_filter_fn argument to torch.compile:

    >> opt_mod = torch.compile(
    >>     mod,
    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},
    >> )
    """
    return [
        not entry.orig_guard.source.is_unspecialized_nn_module()
        for entry in guard_entries
    ]

def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False):
    """
    A common function that keeps only tensor guards. This is unsafe to use by
    default, but if you don't expect any other changes in your model code, you
    can keep just the tensor guards.

    >> opt_mod = torch.compile(
    >>     mod,
    >>     options={"guard_filter_fn": torch.compiler.keep_tensor_guards_unsafe},
    >> )
    """
    keep_flags = []
    for entry in guard_entries:
        if entry.guard_type == "TENSOR_MATCH":
            # By default, drop guards on nn.Parameter tensors unless the
            # caller explicitly asks to keep them.
            if not isinstance(entry.value, torch.nn.Parameter):
                keep_flags.append(True)
            elif keep_parameters:
                keep_flags.append(True)
            else:
                keep_flags.append(False)
        else:
            keep_flags.append(False)
    return keep_flags

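Because `guard_filter_fn` is called with just the list of guard entries, the `keep_parameters` flag above cannot be set by passing the bare function. A caller who wants to keep guards on `nn.Parameter` tensors as well could bind the flag first, for example with `functools.partial`; this is a usage sketch under that assumption, not something added by this PR:

```python
import functools

import torch

mod = torch.nn.Linear(8, 8)
opt_mod = torch.compile(
    mod,
    options={
        # Bind keep_parameters=True so TENSOR_MATCH guards on nn.Parameter
        # tensors are kept along with guards on other tensors.
        "guard_filter_fn": functools.partial(
            torch.compiler.keep_tensor_guards_unsafe, keep_parameters=True
        )
    },
)
opt_mod(torch.rand(2, 8))
```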
def skip_guard_on_globals_unsafe(guard_entries):
    """
    A common function to skip guards on all globals. This is unsafe to use by
    default, but if you don't expect any changes in the globals your compiled
    code reads, you can skip the guards on them.

    >> opt_mod = torch.compile(
    >>     mod,
    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe},
    >> )
    """
    return [not entry.is_global for entry in guard_entries]
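The fields these helpers read (`entry.guard_type`, `entry.value`, `entry.is_global`, `entry.orig_guard.source`) are also what a hand-written filter would use. As an illustration only, a hypothetical filter that keeps tensor guards everywhere but drops every other guard on globals might look like this:

```python
def keep_tensor_guards_and_skip_other_global_guards(guard_entries):
    # Hypothetical custom filter: always keep TENSOR_MATCH guards; for any
    # other guard, keep it only if it does not target a global variable.
    return [
        entry.guard_type == "TENSOR_MATCH" or not entry.is_global
        for entry in guard_entries
    ]
```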