[dynamic shapes] handle Max(*,1) for inductor layout contiguity (#160578)
Differential Revision: D80214882
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160578
Approved by: https://github.com/ZixinYang, https://github.com/bobrenjc93
This commit is contained in:
parent 4cae9cf2df
commit f7ad69f59c
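Why Max(*, 1) shows up at all: strides of a contiguous tensor are computed with each size clamped to at least one, so that a dimension which might be zero still gets a well-defined stride. With dynamic shapes that clamp survives as a sympy.Max(1, size) factor in the stride expression, which a plain product of trailing sizes does not match symbolically. A minimal sketch of the mismatch, assuming an illustrative unbacked size u0 that is only known to be a nonnegative integer (the symbol name and shapes are not taken from the patch):

import sympy

# Illustrative only: u0 stands in for a dynamic size that may be zero.
u0 = sympy.Symbol("u0", integer=True, nonnegative=True)

shape = (4, u0, 8)
# Contiguous strides as a plain product of trailing sizes:
plain_strides = (8 * u0, 8, 1)
# Contiguous strides with each size clamped to >= 1, as they appear
# for tensors whose dims may be empty:
clamped_strides = (8 * sympy.Max(1, u0), 8, 1)

# The leading strides are not provably equal (u0 could be 0), so a check
# that only compares against the plain product rejects the clamped form
# even though both describe a contiguous layout.
print(sympy.simplify(clamped_strides[0] - plain_strides[0]) == 0)  # False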
@@ -3558,12 +3558,21 @@ class IndexingConstant(BaseConstant):
 def is_contiguous_strides_for_shape(
     stride: Sequence[_IntLike], shape: Sequence[_IntLike]
 ) -> bool:
-    return all(
-        size == 1 or left == right
-        for left, right, size in zip(
-            stride, FlexibleLayout.contiguous_strides(shape), shape
-        )
-    )
+    expected_stride = 1
+    expected_stride_max = 1
+    for x, y in reversed(tuple(zip(shape, stride))):
+        if x == 1:
+            continue
+
+        if not V.graph.sizevars.statically_known_equals(
+            y, expected_stride
+        ) and not V.graph.sizevars.statically_known_equals(y, expected_stride_max):
+            return False
+
+        expected_stride_max *= sympy.Max(1, x)
+        expected_stride *= x
+
+    return True
 
 
 def get_align_for_dtype(dtype: torch.dtype) -> int:
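To make the new check concrete, here is a minimal, self-contained sketch of the same loop outside inductor. It is not the patched function itself: V.graph.sizevars.statically_known_equals is replaced by a hypothetical stand-in that treats an equality as statically known when sympy simplifies the difference to zero, and the symbol u0 and the example shapes are assumptions for illustration.

import sympy


def statically_known_equals(a, b) -> bool:
    # Hypothetical stand-in for inductor's sizevars oracle: an equality is
    # "statically known" here only if sympy simplifies the difference to zero.
    return sympy.simplify(sympy.sympify(a) - sympy.sympify(b)) == 0


def is_contiguous_strides_for_shape(stride, shape) -> bool:
    # Walk dims from innermost to outermost; accept a stride if it matches
    # either the plain running product of sizes or the Max(1, size) product.
    expected_stride = 1
    expected_stride_max = 1
    for x, y in reversed(tuple(zip(shape, stride))):
        if x == 1:
            continue
        if not statically_known_equals(
            y, expected_stride
        ) and not statically_known_equals(y, expected_stride_max):
            return False
        expected_stride_max *= sympy.Max(1, x)
        expected_stride *= x
    return True


u0 = sympy.Symbol("u0", integer=True, nonnegative=True)
shape = (4, u0, 8)
print(is_contiguous_strides_for_shape((8 * u0, 8, 1), shape))                # True
print(is_contiguous_strides_for_shape((8 * sympy.Max(1, u0), 8, 1), shape))  # True (now accepted)
print(is_contiguous_strides_for_shape((16 * u0, 8, 1), shape))               # False

Under the old comparison against FlexibleLayout.contiguous_strides(shape), the Max(1, u0) stride in the second call would not have matched the plain size product and the layout would have been treated as non-contiguous, which is the case the commit title refers to.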