mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
unify dynamic shapes API namings 3 (guard_int, guard_int_seq) (#155973)
Renames: evaluate_static_shape -> guard_int; evaluate_static_shapes -> guard_int_seq. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155973 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
61f6aa36b9
commit
1e7e21ec5d
|
|
@ -1359,9 +1359,7 @@ class Reduction(Loops):
|
|||
src_dtype: torch.dtype,
|
||||
) -> Callable[[Sequence[_IntLike]], OpsValue]:
|
||||
"""Convert inner_fn from a reduction to an pointwise"""
|
||||
reduction_ranges = [
|
||||
V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges
|
||||
]
|
||||
reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
|
||||
|
||||
combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -438,17 +438,17 @@ def convolution(
|
|||
dilation = tuple(dilation)
|
||||
output_padding = tuple(output_padding)
|
||||
if not isinstance(groups, int):
|
||||
groups = V.graph.sizevars.evaluate_static_shape(groups)
|
||||
groups = V.graph.sizevars.guard_int(groups)
|
||||
assert isinstance(groups, int)
|
||||
|
||||
# Need to use a hint for the triton template since the template does not
|
||||
# work with a dynamic shape.
|
||||
#
|
||||
# No need to evaluate_static_shape for dilation and output_padding
|
||||
# No need to guard_int for dilation and output_padding
|
||||
# since the template is only used when dilation is 1 and output_padding
|
||||
# is 0.
|
||||
stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
|
||||
padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))
|
||||
stride = tuple(V.graph.sizevars.guard_int_seq(stride))
|
||||
padding = tuple(V.graph.sizevars.guard_int_seq(padding))
|
||||
|
||||
kwargs: ConvLayoutParams = {
|
||||
"stride": stride,
|
||||
|
|
@ -468,9 +468,7 @@ def convolution(
|
|||
dim=0,
|
||||
)
|
||||
|
||||
out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
|
||||
weight.get_size()
|
||||
)
|
||||
out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size())
|
||||
|
||||
# Always convert conv1D to 2D for Intel GPU.
|
||||
# Only conv2D can be converted to channel last layout,
|
||||
|
|
@ -568,7 +566,7 @@ def convolution(
|
|||
args = [x, weight, bias]
|
||||
bias.realize()
|
||||
bias.freeze_layout()
|
||||
V.graph.sizevars.evaluate_static_shapes(bias.get_size())
|
||||
V.graph.sizevars.guard_int_seq(bias.get_size())
|
||||
|
||||
choices = []
|
||||
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
|
||||
|
|
|
|||
|
|
@ -1182,8 +1182,8 @@ def lower_cpu(
|
|||
|
||||
skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False)
|
||||
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
|
||||
), (
|
||||
|
|
@ -1254,18 +1254,18 @@ def set_head_dim_values(
|
|||
kernel_options: Dictionary to populate with options
|
||||
qk_head_dim: Query/Key head dimension
|
||||
v_head_dim: Value head dimension
|
||||
graph_sizevars: Graph size variables object with evaluate_static_shape method
|
||||
graph_sizevars: Graph size variables object with guard_int method
|
||||
|
||||
"""
|
||||
# QK dimensions
|
||||
qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim)
|
||||
qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
|
||||
kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
|
||||
kernel_options.setdefault(
|
||||
"QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
|
||||
)
|
||||
|
||||
# V dimensions
|
||||
v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim)
|
||||
v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
|
||||
kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
|
||||
kernel_options.setdefault(
|
||||
"V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
|
||||
|
|
@ -1359,9 +1359,7 @@ def flex_attention(
|
|||
kernel_options = dict(kernel_options)
|
||||
# Mark symbols in custom kernel options as static shapes and add guards.
|
||||
kernel_options = {
|
||||
k: V.graph.sizevars.evaluate_static_shape(v)
|
||||
if isinstance(v, sympy.Symbol)
|
||||
else v
|
||||
k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
|
||||
for k, v in kernel_options.items()
|
||||
}
|
||||
kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
|
||||
|
|
@ -1473,12 +1471,12 @@ def flex_attention(
|
|||
choices: list[Any] = []
|
||||
|
||||
dtype = query.get_dtype()
|
||||
head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
|
||||
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
|
||||
configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype)
|
||||
|
||||
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
|
||||
|
||||
# Note, we don't need to pass in the captured buffers explicitly
|
||||
# because they're implicitly added by the score_mod function
|
||||
|
|
@ -2468,9 +2466,7 @@ def flex_attention_backward(*args, **kwargs):
|
|||
kernel_options = dict(kernel_options)
|
||||
# Mark symbols in custom kernel options as static shapes and add guards.
|
||||
kernel_options = {
|
||||
k: V.graph.sizevars.evaluate_static_shape(v)
|
||||
if isinstance(v, sympy.Symbol)
|
||||
else v
|
||||
k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
|
||||
for k, v in kernel_options.items()
|
||||
}
|
||||
kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
|
||||
|
|
@ -2585,13 +2581,13 @@ def flex_attention_backward(*args, **kwargs):
|
|||
|
||||
set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)
|
||||
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
choices: list[Any] = []
|
||||
|
||||
dtype = query.get_dtype()
|
||||
head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
|
||||
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
|
||||
configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype)
|
||||
|
||||
# Default config for warp specialization
|
||||
|
|
|
|||
|
|
@ -363,9 +363,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
kernel_options = dict(kernel_options)
|
||||
# Mark symbols in custom kernel options as static shapes and add guards.
|
||||
kernel_options = {
|
||||
k: V.graph.sizevars.evaluate_static_shape(v)
|
||||
if isinstance(v, sympy.Symbol)
|
||||
else v
|
||||
k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
|
||||
for k, v in kernel_options.items()
|
||||
}
|
||||
|
||||
|
|
@ -416,7 +414,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
|
||||
choices: list[Any] = []
|
||||
dtype = key.get_dtype()
|
||||
head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1])
|
||||
head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
|
||||
configs = V.choices.get_flex_decode_configs(head_dim, dtype)
|
||||
|
||||
# TODO: fix autotuning.
|
||||
|
|
@ -494,7 +492,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
# TODO: This feels sketchy
|
||||
kernel_options.setdefault("SAFE_N_BOUNDARY", True)
|
||||
# Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards.
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
|
||||
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
original_kernel_options = kernel_options.copy()
|
||||
# Note, we don't need to pass in the captured buffers explicitly
|
||||
|
|
|
|||
|
|
@ -976,9 +976,9 @@ def squeeze(x, dim=None):
|
|||
return TensorBox(SqueezeView.create(x.data))
|
||||
|
||||
dim = (
|
||||
V.graph.sizevars.evaluate_static_shape(dim)
|
||||
V.graph.sizevars.guard_int(dim)
|
||||
if isinstance(dim, (int, sympy.Expr))
|
||||
else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim)
|
||||
else tuple(V.graph.sizevars.guard_int(d) for d in dim)
|
||||
)
|
||||
dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload]
|
||||
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
|
||||
|
|
@ -1769,9 +1769,7 @@ def split(x, sizes, dim=0):
|
|||
# by computing what the actual size of each chunk should be.
|
||||
if not isinstance(sizes, (list, tuple)):
|
||||
x_size = x.get_size()[dim]
|
||||
chunks = V.graph.sizevars.evaluate_static_shape(
|
||||
FloorDiv(x_size + sizes - 1, sizes)
|
||||
)
|
||||
chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
|
||||
sizes_ = [sizes] * chunks
|
||||
# The last chunk might have a smaller size than the rest.
|
||||
sizes_[-1] = x_size - (chunks - 1) * sizes
|
||||
|
|
@ -1797,7 +1795,7 @@ def split_with_sizes(x, sizes, dim=0):
|
|||
@register_lowering(aten.unbind, type_promotion_kind=None)
|
||||
def unbind(x, dim=0):
|
||||
dim = _validate_dim(x, dim, 0)
|
||||
x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
|
||||
x_size = V.graph.sizevars.guard_int(x.get_size()[dim])
|
||||
result = [select(x, dim, i) for i in range(x_size)]
|
||||
return result
|
||||
|
||||
|
|
@ -1861,7 +1859,7 @@ def _validate_dim(x, dim, offset=0):
|
|||
def glu(x, dim=-1):
|
||||
dim = _validate_dim(x, dim, 0)
|
||||
# TODO: don't guard on static shape here
|
||||
new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2
|
||||
new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
|
||||
a = slice_(x, dim, 0, new_len)
|
||||
b = slice_(x, dim, new_len, new_len * 2)
|
||||
return mul(a, sigmoid(b))
|
||||
|
|
@ -4017,7 +4015,7 @@ def upsample_nearestnd(
|
|||
x_loader = x.make_loader()
|
||||
i_sizes = x.get_size()[-n:]
|
||||
batch = x.get_size()[:-n]
|
||||
i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes]
|
||||
i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes]
|
||||
|
||||
assert len(scales_x) == n
|
||||
o_sizes = output_size
|
||||
|
|
@ -4913,8 +4911,8 @@ def _adaptive_avg_pool2d(x, output_size):
|
|||
|
||||
*batch, h_in, w_in = x.get_size()
|
||||
|
||||
h_in = V.graph.sizevars.evaluate_static_shape(h_in)
|
||||
w_in = V.graph.sizevars.evaluate_static_shape(w_in)
|
||||
h_in = V.graph.sizevars.guard_int(h_in)
|
||||
w_in = V.graph.sizevars.guard_int(w_in)
|
||||
|
||||
h_out, w_out = output_size
|
||||
|
||||
|
|
@ -4988,8 +4986,8 @@ def adaptive_max_pool2d(x, output_size):
|
|||
|
||||
*batch, h_in, w_in = x.get_size()
|
||||
|
||||
h_in = V.graph.sizevars.evaluate_static_shape(h_in)
|
||||
w_in = V.graph.sizevars.evaluate_static_shape(w_in)
|
||||
h_in = V.graph.sizevars.guard_int(h_in)
|
||||
w_in = V.graph.sizevars.guard_int(w_in)
|
||||
|
||||
h_out, w_out = output_size
|
||||
|
||||
|
|
@ -5165,8 +5163,8 @@ def upsample_nearest2d_backward(
|
|||
x.realize_hint()
|
||||
|
||||
*_batch, inp_h, inp_w = x.get_size()
|
||||
inp_h = V.graph.sizevars.evaluate_static_shape(inp_h)
|
||||
inp_w = V.graph.sizevars.evaluate_static_shape(inp_w)
|
||||
inp_h = V.graph.sizevars.guard_int(inp_h)
|
||||
inp_w = V.graph.sizevars.guard_int(inp_w)
|
||||
|
||||
*_batch, out_h, out_w = input_size
|
||||
|
||||
|
|
|
|||
|
|
@ -492,15 +492,24 @@ class SizeVarAllocator:
|
|||
min_val = self.evaluate_min(left, right)
|
||||
return right if min_val is left else left
|
||||
|
||||
def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
|
||||
if isinstance(left, int):
|
||||
return left
|
||||
right = self.size_hint_or_throw(left)
|
||||
self.check_equals(left, sympy.Integer(right))
|
||||
return int(right)
|
||||
def guard_int(self, expr: Union[Expr, int]) -> int:
|
||||
"""
|
||||
Similar to guard_int in symbolic_shapes.py, except this function works with SymPy
|
||||
expressions instead of SymNodes. It extracts the value represented by expr from shapeEnv
|
||||
and specializes the compiled graph on it. Raises an error if the result cannot be
|
||||
determined due to unhinted or unbacked symbols.
|
||||
"""
|
||||
if isinstance(expr, int):
|
||||
return expr
|
||||
val = self.size_hint_or_throw(expr)
|
||||
self.check_equals(expr, sympy.Integer(val))
|
||||
return int(val)
|
||||
|
||||
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
|
||||
return [self.evaluate_static_shape(x) for x in left]
|
||||
def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]:
|
||||
"""
|
||||
Apply guard_int on a sequence of inputs.
|
||||
"""
|
||||
return [self.guard_int(x) for x in left]
|
||||
|
||||
def remove_precomputed_replacements(self, expr: Expr) -> Expr:
|
||||
if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user