Tweak dynamic=False behavior (#105715)

Previously, dynamic=False was a no-op, and dynamic=True preemptively
turned on dynamic shapes everywhere.

Now, dynamic=False *disables* automatic dynamic shapes, and an unset
dynamic defaults to dynamic=None (which uses automatic dynamic shapes).
This seems more intuitive per
https://github.com/pytorch/pytorch/issues/105634#issuecomment-1644883477
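
A minimal sketch of the resulting torch.compile semantics (illustrative
only; the function f and the sizes below are made up, and exact recompile
behavior depends on backend and configuration):

    import torch

    def f(x):
        return x * 2

    # dynamic=None (the new default): compile statically first, then
    # automatically recompile a dynamic kernel once a size actually varies.
    auto = torch.compile(f, dynamic=None)

    # dynamic=True: attempt an up-front dynamic kernel.
    fully_dynamic = torch.compile(f, dynamic=True)

    # dynamic=False: never generate dynamic kernels; always specialize.
    static_only = torch.compile(f, dynamic=False)

    for n in (2, 3, 4):
        auto(torch.randn(n))
        fully_dynamic(torch.randn(n))
        static_only(torch.randn(n))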

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105715
Approved by: https://github.com/voznesenskym
Edward Z. Yang 2023-07-24 14:10:17 +00:00 committed by PyTorch MergeBot
parent 0ab74044c2
commit 3045e84e67
3 changed files with 45 additions and 21 deletions

View File

@@ -13,12 +13,12 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
def foo(x, y):
return x * y
def run_foo_6_times_and_count_recompiles():
def run_foo_6_times_and_count_recompiles(dynamic=None):
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn([2])
y = torch.randn([2])
opt = torch._dynamo.optimize(cnt)(foo)
opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
opt(x, y)
x = torch.randn([3])
y = torch.randn([3])
@@ -51,9 +51,21 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
without = run_foo_6_times_and_count_recompiles(dynamic=False)
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 2)
self.assertEqual(with_automatic.op_count, 2)
torch._dynamo.reset()
with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
self.assertEqual(with_automatic.frame_count, 2)
self.assertEqual(with_automatic.op_count, 2)
torch._dynamo.reset()
with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
self.assertEqual(with_dynamic.frame_count, 1)
self.assertEqual(with_dynamic.op_count, 1)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def test_recompiles_true_false_flop(self):
@@ -98,7 +110,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 2)
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
@@ -210,3 +222,9 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cmp_result, eager_result)
# Recompile, alias changed
self.assertEqual(cnt.frame_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
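
For reference, a hedged standalone sketch of the same check the test above
performs, outside the unittest harness (the helper name count_recompiles
and the sizes are illustrative; the commented counts follow the assertions
above under the default configuration):

    import torch
    import torch._dynamo
    import torch._dynamo.testing

    def foo(x, y):
        return x * y

    def count_recompiles(dynamic):
        torch._dynamo.reset()
        cnt = torch._dynamo.testing.CompileCounter()
        opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
        for n in range(2, 8):  # six calls, each with a new input size
            opt(torch.randn([n]), torch.randn([n]))
        return cnt.frame_count

    print(count_recompiles(dynamic=False))  # one compile per new size
    print(count_recompiles(dynamic=None))   # static first, then one dynamic recompile
    print(count_recompiles(dynamic=True))   # a single up-front dynamic compile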

View File

@@ -1588,7 +1588,7 @@ class _TorchCompileWrapper:
def compile(model: Optional[Callable] = None, *,
fullgraph: builtins.bool = False,
dynamic: builtins.bool = False,
dynamic: Optional[builtins.bool] = None,
backend: Union[str, Callable] = "inductor",
mode: Union[str, None] = None,
options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
@@ -1610,12 +1610,13 @@ def compile(model: Optional[Callable] = None, *,
Args:
model (Callable): Module/function to optimize
fullgraph (bool): Whether it is ok to break model into several subgraphs
dynamic (bool): Use dynamic shape tracing. When this is True, we will up-front attempt
dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
to generate a kernel that is as dynamic as possible to avoid recompilations when
sizes change. This may not always work as some operations/optimizations will
force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
In particular, if you use "reduce-overhead", this will force sizes to be static
even with dynamic=True.
When this is False, we will NEVER generate dynamic kernels, we will always specialize.
By default (None), we automatically detect if dynamism has occurred and compile a more
dynamic kernel upon recompile.
backend (str or Callable): backend to be used
- "inductor" is the default backend, which is a good balance between performance and overhead
- Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
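
The updated docstring above also notes that some compile modes force
specialization regardless of the dynamic setting. A hedged sketch of that
interaction (f is a made-up function; on CPU the "reduce-overhead" mode
simply skips CUDA graphs):

    import torch

    def f(x):
        return x + 1

    # Even with dynamic=True, "reduce-overhead" (CUDA graphs) needs static
    # sizes, so shapes can end up specialized anyway; run with
    # TORCH_LOGS=dynamic set in the environment to see why a given size
    # was specialized.
    g = torch.compile(f, dynamic=True, mode="reduce-overhead")
    g(torch.randn(8))
    g(torch.randn(16))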

View File

@@ -185,15 +185,18 @@ def innermost_fn(fn):
@contextlib.contextmanager
def enable_dynamic(enable: bool = True, export: bool = False):
if not enable:
yield
return
# dynamic=True used to mean fully dynamic. However, with automatic dynamic, the default flipped to
# deriving dynamism. For back compat, and forward compat for when dynamic=True is default, we take
# dynamic=True here to mean "fully dynamic from the start".
with config.patch(assume_static_by_default=False):
def enable_dynamic(enable: Optional[bool] = None, export: bool = False):
if enable is None:
yield
elif enable:
# Assume everything is dynamic by default
with config.patch(assume_static_by_default=False):
yield
else:
with config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
yield
class _TorchDynamoContext:
@@ -206,7 +209,7 @@ class _TorchDynamoContext:
first_ctx=False,
*,
export=False,
dynamic=False,
dynamic=None,
compiler_config=None,
):
super().__init__()
@@ -379,7 +382,7 @@ class OptimizeContext(_TorchDynamoContext):
first_ctx=False,
*,
export=False,
dynamic=False,
dynamic=None,
compiler_config=None,
):
def on_enter():
@@ -475,7 +478,7 @@ def _optimize_catch_errors(
hooks: Hooks,
backend_ctx_ctor=null_context,
export=False,
dynamic=False,
dynamic=None,
compiler_config=None,
):
return OptimizeContext(
@@ -529,7 +532,7 @@ def optimize(
guard_export_fn=None,
guard_fail_fn=None,
disable=False,
dynamic=False,
dynamic=None,
):
"""
The main entrypoint of TorchDynamo. Do graph capture and call
@@ -547,7 +550,9 @@ def optimize(
nopython: If True, graph breaks will be errors and there will
be a single whole-program graph.
disable: If True, turn this decorator into a no-op
dynamic: If True, turn on dynamic shapes support
dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
disable all dynamic shapes support (always specialize). If None, automatically
detect when sizes vary and generate dynamic kernels upon recompile.
Example Usage::
@@ -1169,7 +1174,7 @@ def optimize_assert(
hooks=Hooks(None, None),
export=False,
export_constraints=None,
dynamic=False,
dynamic=None,
):
"""
The same as `torch._dynamo.optimize(backend, nopython=True)`