In Inductor, be willing to generate deferred runtime asserts when unbacked (#138804)
In addition to the change described in the title, we avoid calling defer_runtime_assert when we statically know the guard result.

Timing for pnasnet5large:

```
TIMING: code_gen:21.79672 inductor_compile:39.57726 backend_compile:65.30649 entire_frame_compile:95.22052 total_wall_time:95.22052
```

This matches the timing without the diff:

```
TIMING: code_gen:21.89314 inductor_compile:39.72298 backend_compile:65.38539 entire_frame_compile:95.0854 total_wall_time:95.0854
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138804
Approved by: https://github.com/ezyang
This commit is contained in:
parent 7cb3cef05f → commit c056dc4cb8
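The core pattern of the change: evaluate a guard statically first, and only register a deferred runtime assert when the result is unknown at compile time (typically because the expression contains unbacked symbols). Below is a minimal self-contained sketch of that control flow; `maybe_evaluate_static` and `deferred_runtime_asserts` are hypothetical stand-ins for `ShapeEnv._maybe_evaluate_static` and the ShapeEnv's deferred-assert bookkeeping.

```python
import sympy

# Hypothetical stand-in for ShapeEnv._maybe_evaluate_static: return a sympy
# boolean if the expression is decidable from static facts alone, or None
# if it depends on unbacked (runtime-only) symbols.
def maybe_evaluate_static(expr):
    simplified = sympy.simplify(expr)
    if simplified in (sympy.true, sympy.false):
        return simplified
    return None

deferred_runtime_asserts = []  # stand-in for the ShapeEnv's deferred asserts

def guard_lt(left, right):
    expr = sympy.Lt(left, right)
    static_expr = maybe_evaluate_static(expr)
    if static_expr is not None:
        # Statically decided: nothing to check at runtime, but it must hold.
        assert bool(static_expr)
        return
    # Undecidable at compile time: record a deferred runtime assert instead
    # of failing, which is what this PR makes Inductor willing to do.
    deferred_runtime_asserts.append((expr, "guard_lt"))

u0 = sympy.Symbol("u0", integer=True, positive=True)  # an "unbacked" size
guard_lt(sympy.Integer(2), sympy.Integer(4))  # statically true: nothing deferred
guard_lt(u0, u0 + 1)                          # statically true for any u0
guard_lt(u0, sympy.Integer(100))              # unknown: deferred to runtime
print(deferred_runtime_asserts)               # [(u0 < 100, 'guard_lt')]
```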
```diff
@@ -35,7 +35,7 @@ basic_modules_ListOfLinears_inductor_gpu, compile_time_instruct
-update_hint_regression, compile_time_instruction_count, 1853008305, 0.02
+update_hint_regression, compile_time_instruction_count, 1795333141, 0.02
```
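This first hunk only updates the expected `compile_time_instruction_count` baseline for the `update_hint_regression` benchmark, which drops by roughly 3% with this change.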
```diff
@@ -247,9 +247,11 @@ class SizeVarAllocator:
             # for which "strides" don't make sense so we ignore them here.
             # NOTE: These expressions may still block merging dims in the sound
             # substitution test performed in can_merge_dims.
-            self.stride_vars(x, index_vars)
-            if isinstance(x, sympy.Expr)
-            else [0] * len(index_vars)
+            (
+                self.stride_vars(x, index_vars)
+                if isinstance(x, sympy.Expr)
+                else [0] * len(index_vars)
+            )
             for x in index_formulas
         ]
         assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
```
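The `SizeVarAllocator` hunk above is formatting-only: each comprehension element becomes a parenthesized conditional expression, with no behavior change.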
```diff
@@ -415,14 +417,29 @@ class SizeVarAllocator:
             left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
         if isinstance(right, Expr):
             right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
-        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
+
+        expr = sympy.Eq(left, right)
+        static_expr = self.shape_env._maybe_evaluate_static(expr)
+
+        if static_expr is not None:
+            assert bool(static_expr)
+            return left
+
+        assert self.shape_env.defer_runtime_assert(expr, "guard_equals")
         return left

     def guard_leq(self, left: Expr, right: Expr) -> None:
         return self.guard_lt(left, right + 1)

     def guard_lt(self, left: Expr, right: Expr) -> None:
-        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
+        expr = sympy.Lt(left, right)
+        static_expr = self.shape_env._maybe_evaluate_static(expr)
+
+        if static_expr is not None:
+            assert bool(static_expr)
+            return
+
+        assert self.shape_env.defer_runtime_assert(expr, "guard_lt")

     def guarded_order(self, seq):
         """
```
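For end-to-end context (not from the PR), a hedged sketch of where unbacked symbols arise: `.item()` on a tensor produces an unbacked SymInt whose value Inductor cannot know at compile time, so size conditions on it can only be enforced as deferred runtime asserts.

```python
import torch

# .item() only traces under this flag; otherwise Dynamo graph-breaks on it.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x):
    u = x.item()             # u is an unbacked SymInt, unknown at compile time
    torch._check_is_size(u)  # cannot be proven statically: deferred to runtime
    return torch.zeros(u)

print(f(torch.tensor(5)).shape)  # torch.Size([5])
```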
```diff
@@ -6289,6 +6289,7 @@ class ShapeEnv:
                 for ra in ras:
                     ra.stack.cleanup()

     @lru_cache(256)
+    @record_shapeenv_event(save_tracked_fakes=True)
     def defer_runtime_assert(
         self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None
```
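The newly added `@record_shapeenv_event(save_tracked_fakes=True)` decorator turns `defer_runtime_assert` into a recorded ShapeEnv event; presumably this keeps recorded ShapeEnv state replayable now that asserts can also be deferred from Inductor.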
```diff
@@ -6326,7 +6327,6 @@ class ShapeEnv:
         # NB: Don't use new_expr as expr; it could contain gunk like shape0
         # which we don't want to guard on

-        # OK, we're definitely doing a runtime assert now
         if (
             self._translation_validation_enabled
             and fx_node is not None
```
```diff
@@ -6340,10 +6340,9 @@ class ShapeEnv:
         if not self._suppress_guards_tls():
-            # If you're here because of this assert, read Note [Backwards runtime asserts]
-            # in torch/_inductor/graph.py
-            assert not self.runtime_asserts_frozen, expr
+            if self.runtime_asserts_frozen:
+                log.warning("runtime_asserts_frozen but then got %s", expr)
             self._check_frozen(expr, sympy.true)

             # eliminate symbols on equality tests / refine ranges
             if isinstance(expr, sympy.Rel):
                 self._maybe_guard_rel(expr)
```
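A note on the last hunk's design change: deferring an assert after `runtime_asserts_frozen` is set no longer hard-fails with an `AssertionError`; it logs a warning, and the expression is still validated against the frozen set via `_check_frozen`. This relaxation is presumably what allows Inductor to generate deferred runtime asserts at lowering time, after the point where the old assert would have fired.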