Tweak dynamic=False behavior (#105715)
Previously, dynamic=False was a no-op and dynamic=True preemptively turned on dynamic shapes everywhere. Now, dynamic=False *disables* automatic dynamic, and an unset dynamic defaults to dynamic=None (which uses automatic dynamic). This seems to be more intuitive per https://github.com/pytorch/pytorch/issues/105634#issuecomment-1644883477

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105715
Approved by: https://github.com/voznesenskym
parent 0ab74044c2
commit 3045e84e67
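
To make the new behavior concrete, here is a small sketch that mirrors the updated test below (it is not part of the commit; the count_recompiles helper and the particular sizes are illustrative). It compiles the same function under each dynamic setting and reports how many frames Dynamo compiled.

import torch
import torch._dynamo
import torch._dynamo.testing


def foo(x, y):
    return x * y


def count_recompiles(dynamic):
    # Compile foo under the given `dynamic` setting, run it on several
    # distinct 1-D sizes, and report how many frames Dynamo compiled.
    torch._dynamo.reset()
    cnt = torch._dynamo.testing.CompileCounter()
    opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
    for size in (2, 3, 4, 5, 6, 7):
        opt(torch.randn(size), torch.randn(size))
    return cnt.frame_count


# dynamic=None (the new default): compile a static kernel first, then detect
# the size change and recompile once with a more dynamic kernel.
print(count_recompiles(dynamic=None))
# dynamic=True: attempt a fully dynamic kernel up front.
print(count_recompiles(dynamic=True))
# dynamic=False: automatic dynamic is disabled, so every new size specializes
# and triggers another recompile.
print(count_recompiles(dynamic=False))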

test/dynamo/test_recompiles.py
@@ -13,12 +13,12 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
         def foo(x, y):
             return x * y
 
-        def run_foo_6_times_and_count_recompiles():
+        def run_foo_6_times_and_count_recompiles(dynamic=None):
             cnt = torch._dynamo.testing.CompileCounter()
 
             x = torch.randn([2])
             y = torch.randn([2])
-            opt = torch._dynamo.optimize(cnt)(foo)
+            opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
             opt(x, y)
             x = torch.randn([3])
             y = torch.randn([3])
@@ -51,9 +51,21 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(without.frame_count, 5)
         self.assertEqual(without.op_count, 5)
         torch._dynamo.reset()
+        without = run_foo_6_times_and_count_recompiles(dynamic=False)
+        self.assertEqual(without.frame_count, 5)
+        self.assertEqual(without.op_count, 5)
+        torch._dynamo.reset()
         with_automatic = run_with_automatic()
         self.assertEqual(with_automatic.frame_count, 2)
         self.assertEqual(with_automatic.op_count, 2)
+        torch._dynamo.reset()
+        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
+        self.assertEqual(with_automatic.frame_count, 2)
+        self.assertEqual(with_automatic.op_count, 2)
+        torch._dynamo.reset()
+        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
+        self.assertEqual(with_dynamic.frame_count, 1)
+        self.assertEqual(with_dynamic.op_count, 1)
 
     @patch.object(torch._dynamo.config, "assume_static_by_default", True)
     def test_recompiles_true_false_flop(self):
@@ -98,7 +110,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
             return run_foo_6_times_and_count_recompiles()
 
         without = run_without_automatic()
-        self.assertEqual(without.frame_count, 2)
+        self.assertEqual(without.frame_count, 5)
         self.assertEqual(without.op_count, 5)
         torch._dynamo.reset()
         with_automatic = run_with_automatic()
@@ -210,3 +222,9 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(cmp_result, eager_result)
         # Recompile, alias changed
         self.assertEqual(cnt.frame_count, 2)
 
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
 
     run_tests()

torch/__init__.py
@@ -1588,7 +1588,7 @@ class _TorchCompileWrapper:
 
 def compile(model: Optional[Callable] = None, *,
             fullgraph: builtins.bool = False,
-            dynamic: builtins.bool = False,
+            dynamic: Optional[builtins.bool] = None,
             backend: Union[str, Callable] = "inductor",
             mode: Union[str, None] = None,
             options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
@@ -1610,12 +1610,13 @@ def compile(model: Optional[Callable] = None, *,
     Args:
        model (Callable): Module/function to optimize
        fullgraph (bool): Whether it is ok to break model into several subgraphs
-       dynamic (bool): Use dynamic shape tracing. When this is True, we will up-front attempt
+       dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
        to generate a kernel that is as dynamic as possible to avoid recompilations when
        sizes change. This may not always work as some operations/optimizations will
        force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
-       In particular, if you use "reduce-overhead", this will force sizes to be static
-       even with dynamic=True.
+       When this is False, we will NEVER generate dynamic kernels, we will always specialize.
+       By default (None), we automatically detect if dynamism has occurred and compile a more
+       dynamic kernel upon recompile.
        backend (str or Callable): backend to be used
        - "inductor" is the default backend, which is a good balance between performance and overhead
        - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
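
At the public torch.compile surface, the signature and docstring changes above boil down to dynamic now being a three-valued switch. A minimal sketch of the three call forms (TinyModel and the input sizes are illustrative, not taken from the diff):

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x * 2)


model = TinyModel()

# Default (dynamic=None): automatic dynamic detects varying sizes and
# compiles a more dynamic kernel on the recompile.
auto_compiled = torch.compile(model)

# dynamic=True: try to compile a kernel that is as dynamic as possible up front.
dyn_compiled = torch.compile(model, dynamic=True)

# dynamic=False: never generate dynamic kernels; always specialize on observed sizes.
static_compiled = torch.compile(model, dynamic=False)

auto_compiled(torch.randn(4, 8))
auto_compiled(torch.randn(6, 8))  # size change -> one recompile with a dynamic kernel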

torch/_dynamo/eval_frame.py
@@ -185,15 +185,18 @@ def innermost_fn(fn):
 
 
 @contextlib.contextmanager
-def enable_dynamic(enable: bool = True, export: bool = False):
-    if not enable:
+def enable_dynamic(enable: Optional[bool] = None, export: bool = False):
+    if enable is None:
         yield
-        return
-    # dynamic=True used to mean fully dynamic. However, with automatic dynamic, the default flipped to
-    # deriving dynamism. For back compat, and forward compat for when dynamic=True is default, we take
-    # dynamic=True here to mean "fully dynamic from the start".
-    with config.patch(assume_static_by_default=False):
-        yield
+    elif enable:
+        # Assume everything is dynamic by default
+        with config.patch(assume_static_by_default=False):
+            yield
+    else:
+        with config.patch(
+            automatic_dynamic_shapes=False, assume_static_by_default=True
+        ):
+            yield
 
 
 class _TorchDynamoContext:
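
Going by the rewritten enable_dynamic above, the three settings roughly translate to the following dynamo config state. The config names come straight from the diff, but patching them directly like this is only an illustrative sketch, not a supported recipe:

import torch
import torch._dynamo
from torch._dynamo import config


def mul(x, y):
    return x * y


compiled = torch.compile(mul)

# Roughly what dynamic=False now selects: automatic dynamic off, always static.
with config.patch(automatic_dynamic_shapes=False, assume_static_by_default=True):
    compiled(torch.randn(2), torch.randn(2))
    compiled(torch.randn(3), torch.randn(3))  # new size -> another fully static compile

# Roughly what dynamic=True selects: assume shapes are dynamic from the start.
torch._dynamo.reset()
with config.patch(assume_static_by_default=False):
    compiled(torch.randn(4), torch.randn(4))

# dynamic=None leaves both knobs at their defaults: automatic dynamic on recompile.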
@@ -206,7 +209,7 @@ class _TorchDynamoContext:
         first_ctx=False,
         *,
         export=False,
-        dynamic=False,
+        dynamic=None,
         compiler_config=None,
     ):
         super().__init__()
@@ -379,7 +382,7 @@ class OptimizeContext(_TorchDynamoContext):
         first_ctx=False,
         *,
         export=False,
-        dynamic=False,
+        dynamic=None,
         compiler_config=None,
     ):
         def on_enter():
@@ -475,7 +478,7 @@ def _optimize_catch_errors(
     hooks: Hooks,
     backend_ctx_ctor=null_context,
     export=False,
-    dynamic=False,
+    dynamic=None,
     compiler_config=None,
 ):
     return OptimizeContext(
@@ -529,7 +532,7 @@ def optimize(
     guard_export_fn=None,
     guard_fail_fn=None,
     disable=False,
-    dynamic=False,
+    dynamic=None,
 ):
     """
     The main entrypoint of TorchDynamo. Do graph capture and call
@@ -547,7 +550,9 @@ def optimize(
         nopython: If True, graph breaks will be errors and there will
             be a single whole-program graph.
         disable: If True, turn this decorator into a no-op
-        dynamic: If True, turn on dynamic shapes support
+        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
+            disable all dynamic shapes support (always specialize). If None, automatically
+            detect when sizes vary and generate dynamic kernels upon recompile.
 
     Example Usage::
 
@@ -1169,7 +1174,7 @@ def optimize_assert(
     hooks=Hooks(None, None),
     export=False,
     export_constraints=None,
-    dynamic=False,
+    dynamic=None,
 ):
     """
     The same as `torch._dynamo.optimize(backend, nopython=True)`