diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index d87ec1d40dc..0b20dd98117 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -1359,9 +1359,7 @@ class Reduction(Loops):
         src_dtype: torch.dtype,
     ) -> Callable[[Sequence[_IntLike]], OpsValue]:
         """Convert inner_fn from a reduction to an pointwise"""
-        reduction_ranges = [
-            V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges
-        ]
+        reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
 
         combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
 
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index 4b14989c372..b7dd22aa307 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -438,17 +438,17 @@ def convolution(
     dilation = tuple(dilation)
     output_padding = tuple(output_padding)
     if not isinstance(groups, int):
-        groups = V.graph.sizevars.evaluate_static_shape(groups)
+        groups = V.graph.sizevars.guard_int(groups)
     assert isinstance(groups, int)
 
     # Need use hint for triton template since the template does not
     # work with a dynamic shape.
     #
-    # No need to evaluate_static_shape for dilation and output_padding
+    # No need to guard_int for dilation and output_padding
     # since the template is only used when dilation is 1 and output_padding
     # is 0.
-    stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
-    padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))
+    stride = tuple(V.graph.sizevars.guard_int_seq(stride))
+    padding = tuple(V.graph.sizevars.guard_int_seq(padding))
 
     kwargs: ConvLayoutParams = {
         "stride": stride,
@@ -468,9 +468,7 @@ def convolution(
         dim=0,
     )
 
-    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
-        weight.get_size()
-    )
+    out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size())
 
     # Always convert conv1D to 2D for Intel GPU.
     # Only conv2D can be converted to channel last layout,
@@ -568,7 +566,7 @@ def convolution(
         args = [x, weight, bias]
         bias.realize()
         bias.freeze_layout()
-        V.graph.sizevars.evaluate_static_shapes(bias.get_size())
+        V.graph.sizevars.guard_int_seq(bias.get_size())
 
     choices = []
     if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 22751fe1004..1a8ee177275 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -1182,8 +1182,8 @@ def lower_cpu(
 
     skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False)
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
-    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
-    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
+    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
+    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
     assert V.graph.sizevars.evaluate_expr(
         sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
     ), (
@@ -1254,18 +1254,18 @@ def set_head_dim_values(
         kernel_options: Dictionary to populate with options
        qk_head_dim: Query/Key head dimension
         v_head_dim: Value head dimension
-        graph_sizevars: Graph size variables object with evaluate_static_shape method
+        graph_sizevars: Graph size variables object with guard_int method
     """
     # QK dimensions
-    qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim)
+    qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
     kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
     kernel_options.setdefault(
         "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
     )
 
     # V dimensions
-    v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim)
+    v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
     kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
     kernel_options.setdefault(
         "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
     )
@@ -1359,9 +1359,7 @@ def flex_attention(
     kernel_options = dict(kernel_options)
     # Mark symbols in custom kernel options as static shapes and add guards.
     kernel_options = {
-        k: V.graph.sizevars.evaluate_static_shape(v)
-        if isinstance(v, sympy.Symbol)
-        else v
+        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
         for k, v in kernel_options.items()
     }
     kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
@@ -1473,12 +1471,12 @@ def flex_attention(
     choices: list[Any] = []
 
     dtype = query.get_dtype()
-    head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
+    head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
     configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype)
 
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
-    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
-    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
+    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
+    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
 
     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
@@ -2468,9 +2466,7 @@ def flex_attention_backward(*args, **kwargs):
     kernel_options = dict(kernel_options)
     # Mark symbols in custom kernel options as static shapes and add guards.
     kernel_options = {
-        k: V.graph.sizevars.evaluate_static_shape(v)
-        if isinstance(v, sympy.Symbol)
-        else v
+        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
         for k, v in kernel_options.items()
     }
     kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
@@ -2585,13 +2581,13 @@ def flex_attention_backward(*args, **kwargs):
 
     set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)
 
-    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
-    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
+    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
+    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
 
     choices: list[Any] = []
 
     dtype = query.get_dtype()
-    head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
+    head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
     configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype)
 
     # Default config for warp specialization
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index 4ad0de7f6f0..7e0aef98185 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -363,9 +363,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
     kernel_options = dict(kernel_options)
     # Mark symbols in custom kernel options as static shapes and add guards.
     kernel_options = {
-        k: V.graph.sizevars.evaluate_static_shape(v)
-        if isinstance(v, sympy.Symbol)
-        else v
+        k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
         for k, v in kernel_options.items()
     }
 
@@ -416,7 +414,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
     choices: list[Any] = []
 
     dtype = key.get_dtype()
-    head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1])
+    head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
     configs = V.choices.get_flex_decode_configs(head_dim, dtype)
 
     # TODO: fix autotuning.
@@ -494,7 +492,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
     # TODO: This feels sketchy
     kernel_options.setdefault("SAFE_N_BOUNDARY", True)
     # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards.
-    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
+    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
     original_kernel_options = kernel_options.copy()
 
     # Note, we don't need to pass in the captured buffers explicitly
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 805e4eabcb7..89c2cb3a002 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -976,9 +976,9 @@ def squeeze(x, dim=None):
         return TensorBox(SqueezeView.create(x.data))
 
     dim = (
-        V.graph.sizevars.evaluate_static_shape(dim)
+        V.graph.sizevars.guard_int(dim)
         if isinstance(dim, (int, sympy.Expr))
-        else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim)
+        else tuple(V.graph.sizevars.guard_int(d) for d in dim)
     )
     dim = canonicalize_dims(len(x.get_size()), dim)  # type: ignore[call-overload]
     dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
@@ -1769,9 +1769,7 @@ def split(x, sizes, dim=0):
     # by computing what the actual size of each chunk should be.
     if not isinstance(sizes, (list, tuple)):
         x_size = x.get_size()[dim]
-        chunks = V.graph.sizevars.evaluate_static_shape(
-            FloorDiv(x_size + sizes - 1, sizes)
-        )
+        chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
         sizes_ = [sizes] * chunks
         # The last chunk might have a smaller size than the rest.
         sizes_[-1] = x_size - (chunks - 1) * sizes
@@ -1797,7 +1795,7 @@ def split_with_sizes(x, sizes, dim=0):
 @register_lowering(aten.unbind, type_promotion_kind=None)
 def unbind(x, dim=0):
     dim = _validate_dim(x, dim, 0)
-    x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
+    x_size = V.graph.sizevars.guard_int(x.get_size()[dim])
     result = [select(x, dim, i) for i in range(x_size)]
     return result
 
@@ -1861,7 +1859,7 @@ def _validate_dim(x, dim, offset=0):
 def glu(x, dim=-1):
     dim = _validate_dim(x, dim, 0)
     # TODO: don't guard on static shape here
-    new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2
+    new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
     a = slice_(x, dim, 0, new_len)
     b = slice_(x, dim, new_len, new_len * 2)
     return mul(a, sigmoid(b))
@@ -4017,7 +4015,7 @@ def upsample_nearestnd(
     x_loader = x.make_loader()
     i_sizes = x.get_size()[-n:]
     batch = x.get_size()[:-n]
-    i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes]
+    i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes]
 
     assert len(scales_x) == n
     o_sizes = output_size
@@ -4913,8 +4911,8 @@ def _adaptive_avg_pool2d(x, output_size):
 
     *batch, h_in, w_in = x.get_size()
 
-    h_in = V.graph.sizevars.evaluate_static_shape(h_in)
-    w_in = V.graph.sizevars.evaluate_static_shape(w_in)
+    h_in = V.graph.sizevars.guard_int(h_in)
+    w_in = V.graph.sizevars.guard_int(w_in)
 
     h_out, w_out = output_size
 
@@ -4988,8 +4986,8 @@ def adaptive_max_pool2d(x, output_size):
 
     *batch, h_in, w_in = x.get_size()
 
-    h_in = V.graph.sizevars.evaluate_static_shape(h_in)
-    w_in = V.graph.sizevars.evaluate_static_shape(w_in)
+    h_in = V.graph.sizevars.guard_int(h_in)
+    w_in = V.graph.sizevars.guard_int(w_in)
 
     h_out, w_out = output_size
 
@@ -5165,8 +5163,8 @@ def upsample_nearest2d_backward(
     x.realize_hint()
 
     *_batch, inp_h, inp_w = x.get_size()
-    inp_h = V.graph.sizevars.evaluate_static_shape(inp_h)
-    inp_w = V.graph.sizevars.evaluate_static_shape(inp_w)
+    inp_h = V.graph.sizevars.guard_int(inp_h)
+    inp_w = V.graph.sizevars.guard_int(inp_w)
 
     *_batch, out_h, out_w = input_size
 
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index d6324259229..064f0f18264 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -492,15 +492,24 @@ class SizeVarAllocator:
         min_val = self.evaluate_min(left, right)
         return right if min_val is left else left
 
-    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
-        if isinstance(left, int):
-            return left
-        right = self.size_hint_or_throw(left)
-        self.check_equals(left, sympy.Integer(right))
-        return int(right)
+    def guard_int(self, expr: Union[Expr, int]) -> int:
+        """
+        Similar to guard_int in symbolic_shapes.py, except this function works with SymPy
+        expressions instead of SymNodes. It extracts the value represented by expr from the
+        ShapeEnv, specializes the compiled graph on that value, and raises an error if the
+        value cannot be determined due to unhinted or unbacked symbols.
+        """
+        if isinstance(expr, int):
+            return expr
+        val = self.size_hint_or_throw(expr)
+        self.check_equals(expr, sympy.Integer(val))
+        return int(val)
 
-    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
-        return [self.evaluate_static_shape(x) for x in left]
+    def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]:
+        """
+        Apply guard_int to each element of a sequence of inputs.
+        """
+        return [self.guard_int(x) for x in left]
 
     def remove_precomputed_replacements(self, expr: Expr) -> Expr:
         if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]