unify dynamic shapes API namings 3 (guard_int, guard_int_seq) (#155973)

evaluate_static_shape -> guard_int
evaluate_static_shapes -> guard_int_seq

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155973
Approved by: https://github.com/bobrenjc93
Authored by Laith Sakka on 2025-06-24 19:56:31 -07:00; committed by PyTorch MergeBot
parent 61f6aa36b9
commit 1e7e21ec5d
6 changed files with 52 additions and 55 deletions
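For orientation, here is a minimal before/after sketch of how call sites change under this rename (illustrative only; the call sites are taken from the hunks below, and V.graph.sizevars is Inductor's SizeVarAllocator for the current graph):

# Before this commit:
#   head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
#   stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
# After this commit:
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])  # single Expr or int -> int
stride = tuple(V.graph.sizevars.guard_int_seq(stride))  # sequence of Expr/int -> list[int]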


@@ -1359,9 +1359,7 @@ class Reduction(Loops):
src_dtype: torch.dtype,
) -> Callable[[Sequence[_IntLike]], OpsValue]:
"""Convert inner_fn from a reduction to an pointwise"""
- reduction_ranges = [
-     V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges
- ]
+ reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)


@@ -438,17 +438,17 @@ def convolution(
dilation = tuple(dilation)
output_padding = tuple(output_padding)
if not isinstance(groups, int):
- groups = V.graph.sizevars.evaluate_static_shape(groups)
+ groups = V.graph.sizevars.guard_int(groups)
assert isinstance(groups, int)
# Need use hint for triton template since the template does not
# work with a dynamic shape.
#
- # No need to evaluate_static_shape for dilation and output_padding
+ # No need to guard_int for dilation and output_padding
# since the template is only used when dilation is 1 and output_padding
# is 0.
- stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
- padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))
+ stride = tuple(V.graph.sizevars.guard_int_seq(stride))
+ padding = tuple(V.graph.sizevars.guard_int_seq(padding))
kwargs: ConvLayoutParams = {
"stride": stride,
@@ -468,9 +468,7 @@ def convolution(
dim=0,
)
- out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
-     weight.get_size()
- )
+ out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size())
# Always convert conv1D to 2D for Intel GPU.
# Only conv2D can be converted to channel last layout,
@@ -568,7 +566,7 @@ def convolution(
args = [x, weight, bias]
bias.realize()
bias.freeze_layout()
- V.graph.sizevars.evaluate_static_shapes(bias.get_size())
+ V.graph.sizevars.guard_int_seq(bias.get_size())
choices = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):


@@ -1182,8 +1182,8 @@ def lower_cpu(
skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False)
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
- SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
- SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
+ SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
+ SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), (
@@ -1254,18 +1254,18 @@ def set_head_dim_values(
kernel_options: Dictionary to populate with options
qk_head_dim: Query/Key head dimension
v_head_dim: Value head dimension
- graph_sizevars: Graph size variables object with evaluate_static_shape method
+ graph_sizevars: Graph size variables object with guard_int method
"""
# QK dimensions
- qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim)
+ qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
kernel_options.setdefault(
"QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
)
# V dimensions
- v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim)
+ v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
kernel_options.setdefault(
"V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
@@ -1359,9 +1359,7 @@ def flex_attention(
kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards.
kernel_options = {
- k: V.graph.sizevars.evaluate_static_shape(v)
- if isinstance(v, sympy.Symbol)
- else v
+ k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
for k, v in kernel_options.items()
}
kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
@@ -1473,12 +1471,12 @@ def flex_attention(
choices: list[Any] = []
dtype = query.get_dtype()
- head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
+ head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype)
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
- SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
- SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
+ SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
+ SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
@@ -2468,9 +2466,7 @@ def flex_attention_backward(*args, **kwargs):
kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards.
kernel_options = {
- k: V.graph.sizevars.evaluate_static_shape(v)
- if isinstance(v, sympy.Symbol)
- else v
+ k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
for k, v in kernel_options.items()
}
kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
@@ -2585,13 +2581,13 @@ def flex_attention_backward(*args, **kwargs):
set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)
- SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
- SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
+ SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
+ SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
choices: list[Any] = []
dtype = query.get_dtype()
- head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
+ head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype)
# Default config for warp specialization


@@ -363,9 +363,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards.
kernel_options = {
- k: V.graph.sizevars.evaluate_static_shape(v)
- if isinstance(v, sympy.Symbol)
- else v
+ k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
for k, v in kernel_options.items()
}
@@ -416,7 +414,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
choices: list[Any] = []
dtype = key.get_dtype()
- head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1])
+ head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
configs = V.choices.get_flex_decode_configs(head_dim, dtype)
# TODO: fix autotuning.
@@ -494,7 +492,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
# TODO: This feels sketchy
kernel_options.setdefault("SAFE_N_BOUNDARY", True)
# Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards.
- SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
+ SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
original_kernel_options = kernel_options.copy()
# Note, we don't need to pass in the captured buffers explicitly


@@ -976,9 +976,9 @@ def squeeze(x, dim=None):
return TensorBox(SqueezeView.create(x.data))
dim = (
- V.graph.sizevars.evaluate_static_shape(dim)
+ V.graph.sizevars.guard_int(dim)
if isinstance(dim, (int, sympy.Expr))
- else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim)
+ else tuple(V.graph.sizevars.guard_int(d) for d in dim)
)
dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload]
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
@@ -1769,9 +1769,7 @@ def split(x, sizes, dim=0):
# by computing what the actual size of each chunk should be.
if not isinstance(sizes, (list, tuple)):
x_size = x.get_size()[dim]
- chunks = V.graph.sizevars.evaluate_static_shape(
-     FloorDiv(x_size + sizes - 1, sizes)
- )
+ chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
sizes_ = [sizes] * chunks
# The last chunk might have a smaller size than the rest.
sizes_[-1] = x_size - (chunks - 1) * sizes
@@ -1797,7 +1795,7 @@ def split_with_sizes(x, sizes, dim=0):
@register_lowering(aten.unbind, type_promotion_kind=None)
def unbind(x, dim=0):
dim = _validate_dim(x, dim, 0)
- x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
+ x_size = V.graph.sizevars.guard_int(x.get_size()[dim])
result = [select(x, dim, i) for i in range(x_size)]
return result
@@ -1861,7 +1859,7 @@ def _validate_dim(x, dim, offset=0):
def glu(x, dim=-1):
dim = _validate_dim(x, dim, 0)
# TODO: don't guard on static shape here
- new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2
+ new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
a = slice_(x, dim, 0, new_len)
b = slice_(x, dim, new_len, new_len * 2)
return mul(a, sigmoid(b))
@@ -4017,7 +4015,7 @@ def upsample_nearestnd(
x_loader = x.make_loader()
i_sizes = x.get_size()[-n:]
batch = x.get_size()[:-n]
- i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes]
+ i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes]
assert len(scales_x) == n
o_sizes = output_size
@@ -4913,8 +4911,8 @@ def _adaptive_avg_pool2d(x, output_size):
*batch, h_in, w_in = x.get_size()
- h_in = V.graph.sizevars.evaluate_static_shape(h_in)
- w_in = V.graph.sizevars.evaluate_static_shape(w_in)
+ h_in = V.graph.sizevars.guard_int(h_in)
+ w_in = V.graph.sizevars.guard_int(w_in)
h_out, w_out = output_size
@@ -4988,8 +4986,8 @@ def adaptive_max_pool2d(x, output_size):
*batch, h_in, w_in = x.get_size()
- h_in = V.graph.sizevars.evaluate_static_shape(h_in)
- w_in = V.graph.sizevars.evaluate_static_shape(w_in)
+ h_in = V.graph.sizevars.guard_int(h_in)
+ w_in = V.graph.sizevars.guard_int(w_in)
h_out, w_out = output_size
@@ -5165,8 +5163,8 @@ def upsample_nearest2d_backward(
x.realize_hint()
*_batch, inp_h, inp_w = x.get_size()
- inp_h = V.graph.sizevars.evaluate_static_shape(inp_h)
- inp_w = V.graph.sizevars.evaluate_static_shape(inp_w)
+ inp_h = V.graph.sizevars.guard_int(inp_h)
+ inp_w = V.graph.sizevars.guard_int(inp_w)
*_batch, out_h, out_w = input_size


@@ -492,15 +492,24 @@ class SizeVarAllocator:
min_val = self.evaluate_min(left, right)
return right if min_val is left else left
- def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
-     if isinstance(left, int):
-         return left
-     right = self.size_hint_or_throw(left)
-     self.check_equals(left, sympy.Integer(right))
-     return int(right)
+ def guard_int(self, expr: Union[Expr, int]) -> int:
+     """
+     Similar to guard_int in symbolic_shapes.py, except this function works with SymPy
+     expressions instead of SymNodes. It extracts the value that expr represents from the
+     ShapeEnv and specializes the compiled graph on it. Raises an error if the result
+     cannot be determined due to unhinted or unbacked symbols.
+     """
+     if isinstance(expr, int):
+         return expr
+     val = self.size_hint_or_throw(expr)
+     self.check_equals(expr, sympy.Integer(val))
+     return int(val)
- def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
-     return [self.evaluate_static_shape(x) for x in left]
+ def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]:
+     """
+     Apply guard_int to each element of a sequence of inputs.
+     """
+     return [self.guard_int(x) for x in left]
def remove_precomputed_replacements(self, expr: Expr) -> Expr:
if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined]
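For context, a rough usage sketch of the renamed helpers based on the docstrings and bodies above (the SymPy symbol and the hint value 64 are illustrative assumptions, not part of this commit):

import sympy

sizevars = V.graph.sizevars  # SizeVarAllocator for the current Inductor graph
s0 = sympy.Symbol("s0", integer=True, positive=True)  # assume the ShapeEnv has hint s0 = 64

n = sizevars.guard_int(s0 * 2)  # returns 128 and guards s0 * 2 == 128, specializing the graph on it
dims = sizevars.guard_int_seq([s0, 32, s0 + 1])  # [64, 32, 65]; applies guard_int element-wise
k = sizevars.guard_int(7)  # plain Python ints pass through unchanged
# Unhinted or unbacked symbols raise (via size_hint_or_throw) instead of silently guessing.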