Make Inductor scheduler aware of _scaled_mm (#146992)

This is used, for example, to estimate runtime when overlapping communication with computation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146992
Approved by: https://github.com/drisspg, https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
Luca Wehrstedt 2025-02-19 16:31:42 +00:00 committed by PyTorch MergeBot
parent 9da250aada
commit f9b8121350
2 changed files with 18 additions and 7 deletions

View File

@ -2390,22 +2390,32 @@ def as_storage_and_layout(
allow_padding=allow_padding,
exact_strides=exact_strides,
)
if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
if isinstance(x, StorageBox):
_, layout = as_storage_and_layout(
x.data,
freeze=freeze,
want_contiguous=want_contiguous,
stride_order=stride_order,
allow_padding=allow_padding,
exact_strides=exact_strides,
)
return x, x.data.get_layout()
if isinstance(x, Buffer):
if freeze:
if want_contiguous:
x.data.freeze_layout()
assert x.data.get_layout().is_contiguous()
x.freeze_layout()
assert x.get_layout().is_contiguous()
elif stride_order is not None:
x.data.freeze_layout_with_stride_order(
x.freeze_layout_with_stride_order(
stride_order, allow_padding=allow_padding
)
elif exact_strides is not None:
x.data.freeze_layout_with_exact_strides(
x.freeze_layout_with_exact_strides(
exact_strides, allow_padding=allow_padding
)
else:
x.data.decide_layout()
return x, x.data.get_layout()
x.decide_layout()
return StorageBox(x), x.get_layout()
if isinstance(x, ReinterpretView):
# making the base of x contiguous or stride_ordered will not necessarily make
# the ReinterpretView either, so don't pass along those arguments

View File

@ -938,6 +938,7 @@ kernel_name_to_op = {
"extern_kernels.mm": torch.ops.aten.mm,
"extern_kernels.bmm": torch.ops.aten.bmm,
"extern_kernels.addmm": torch.ops.aten.addmm,
"extern_kernels._scaled_mm": torch.ops.aten._scaled_mm,
}