In Inductor, be willing to generate deferred runtime asserts when unbacked (#138804)

As the title says; additionally, we avoid calling defer_runtime_assert when we statically know the guard result.
timing for pnasnet5large

```
TIMING: code_gen:21.79672 inductor_compile:39.57726 backend_compile:65.30649 entire_frame_compile:95.22052 total_wall_time:95.22052
```
This matches the timing without the diff:
```
TIMING: code_gen:21.89314 inductor_compile:39.72298 backend_compile:65.38539 entire_frame_compile:95.0854 total_wall_time:95.0854
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138804
Approved by: https://github.com/ezyang
This commit is contained in:
Laith Sakka 2024-10-27 11:15:24 -07:00 committed by PyTorch MergeBot
parent 7cb3cef05f
commit c056dc4cb8
3 changed files with 26 additions and 10 deletions

View File

@ -35,7 +35,7 @@ basic_modules_ListOfLinears_inductor_gpu, compile_time_instruct
update_hint_regression, compile_time_instruction_count, 1853008305, 0.02
update_hint_regression, compile_time_instruction_count, 1795333141, 0.02

1 add_loop_eager compile_time_instruction_count 3004749893 0.015
35
36
37
38
39
40
41

View File

@ -247,9 +247,11 @@ class SizeVarAllocator:
# for which "strides" don't make sense so we ignore them here.
# NOTE: These expressions may still block merging dims in the sound
# substitution test performed in can_merge_dims.
self.stride_vars(x, index_vars)
if isinstance(x, sympy.Expr)
else [0] * len(index_vars)
(
self.stride_vars(x, index_vars)
if isinstance(x, sympy.Expr)
else [0] * len(index_vars)
)
for x in index_formulas
]
assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
@ -415,14 +417,29 @@ class SizeVarAllocator:
left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type]
if isinstance(right, Expr):
right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type]
assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
expr = sympy.Eq(left, right)
static_expr = self.shape_env._maybe_evaluate_static(expr)
if static_expr is not None:
assert bool(static_expr)
return left
assert self.shape_env.defer_runtime_assert(expr, "guard_equals")
return left
def guard_leq(self, left: Expr, right: Expr) -> None:
    """Guard that ``left <= right``, statically or via a deferred runtime assert.

    Delegates to ``guard_lt`` using the strict-inequality form
    ``left < right + 1``; the two are equivalent assuming the expressions
    are integer-valued (typical for sizes/strides here) — TODO confirm.
    """
    upper_exclusive = right + 1
    return self.guard_lt(left, upper_exclusive)
def guard_lt(self, left: Expr, right: Expr) -> None:
    """Guard that ``left < right``.

    If the relation can be decided statically by the shape environment,
    assert that it holds; otherwise install a deferred runtime assert so
    the condition is checked when concrete values are available.
    """
    relation = sympy.Lt(left, right)
    known = self.shape_env._maybe_evaluate_static(relation)
    if known is None:
        # Not statically decidable — defer the check to runtime.
        assert self.shape_env.defer_runtime_assert(relation, "guard_lt")
        return
    # Statically decidable: it must evaluate to true.
    assert bool(known)
def guarded_order(self, seq):
"""

View File

@ -6289,6 +6289,7 @@ class ShapeEnv:
for ra in ras:
ra.stack.cleanup()
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def defer_runtime_assert(
self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None
@ -6326,7 +6327,6 @@ class ShapeEnv:
# NB: Don't use new_expr as expr; it could contain gunk like shape0
# which we don't want to guard on
# OK, we're definitely doing a runtime assert now
if (
self._translation_validation_enabled
and fx_node is not None
@ -6340,10 +6340,9 @@ class ShapeEnv:
if not self._suppress_guards_tls():
# If you're here because of this assert, read Note [Backwards runtime asserts]
# in torch/_inductor/graph.py
assert not self.runtime_asserts_frozen, expr
if self.runtime_asserts_frozen:
log.warning("runtime_asserts_frozen but then got %s", expr)
self._check_frozen(expr, sympy.true)
# eliminate symbols on equality tests / refine ranges
if isinstance(expr, sympy.Rel):
self._maybe_guard_rel(expr)