Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Make Inductor scheduler aware of _scaled_mm (#146992)

This is used, for example, to estimate runtime when doing comms overlap (overlapping communication with compute).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146992
Approved by: https://github.com/drisspg, https://github.com/eellison, https://github.com/shunting314
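For context: torch._scaled_mm is ATen's FP8 matrix multiply, which Inductor emits as an extern kernel. A hedged usage sketch follows; the FP8-capable-GPU requirement and the exact signature (scales as required positional arguments) are assumptions about recent PyTorch releases, not something this PR specifies.

import torch

# Hedged sketch, not from the PR: requires an FP8-capable CUDA GPU (e.g. SM89/SM90).
a = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)      # row-major
b = torch.randn(32, 128, device="cuda").to(torch.float8_e4m3fn).t()  # mat2 must be column-major
scale_a = torch.tensor(1.0, device="cuda")  # float32 per-tensor scale
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 32])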
This commit is contained in:
parent 9da250aada
commit f9b8121350
@@ -2390,22 +2390,32 @@ def as_storage_and_layout(
             allow_padding=allow_padding,
             exact_strides=exact_strides,
         )
-    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
+    if isinstance(x, StorageBox):
+        _, layout = as_storage_and_layout(
+            x.data,
+            freeze=freeze,
+            want_contiguous=want_contiguous,
+            stride_order=stride_order,
+            allow_padding=allow_padding,
+            exact_strides=exact_strides,
+        )
+        return x, x.data.get_layout()
+    if isinstance(x, Buffer):
         if freeze:
             if want_contiguous:
-                x.data.freeze_layout()
-                assert x.data.get_layout().is_contiguous()
+                x.freeze_layout()
+                assert x.get_layout().is_contiguous()
             elif stride_order is not None:
-                x.data.freeze_layout_with_stride_order(
+                x.freeze_layout_with_stride_order(
                     stride_order, allow_padding=allow_padding
                 )
             elif exact_strides is not None:
-                x.data.freeze_layout_with_exact_strides(
+                x.freeze_layout_with_exact_strides(
                     exact_strides, allow_padding=allow_padding
                 )
             else:
-                x.data.decide_layout()
-        return x, x.data.get_layout()
+                x.decide_layout()
+        return StorageBox(x), x.get_layout()
     if isinstance(x, ReinterpretView):
         # making the base of x contiguous or stride_ordered will not necessarily make
         # the ReinterpretView either, so don't pass along those arguments
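The hunk above makes the StorageBox branch recurse into its payload instead of only matching a StorageBox that directly wraps a Buffer, and it wraps a bare Buffer into a StorageBox on return. A minimal sketch of that dispatch pattern, using simplified stand-ins rather than Inductor's real IR classes:

# Simplified stand-ins; illustrative only, not Inductor's real IR.
class Buffer:
    def __init__(self, layout):
        self.layout = layout

    def get_layout(self):
        return self.layout


class StorageBox:
    def __init__(self, data):
        self.data = data


def as_storage_and_layout(x):
    if isinstance(x, StorageBox):
        # Recurse into the payload so any wrapper chain is handled,
        # not just a StorageBox directly around a Buffer.
        as_storage_and_layout(x.data)
        return x, x.data.get_layout()
    if isinstance(x, Buffer):
        # A bare Buffer is wrapped on return, mirroring `return StorageBox(x), ...`.
        return StorageBox(x), x.get_layout()
    raise NotImplementedError(type(x))


box, layout = as_storage_and_layout(StorageBox(Buffer(layout="contiguous")))
print(type(box).__name__, layout)  # StorageBox contiguous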
@@ -938,6 +938,7 @@ kernel_name_to_op = {
     "extern_kernels.mm": torch.ops.aten.mm,
     "extern_kernels.bmm": torch.ops.aten.bmm,
     "extern_kernels.addmm": torch.ops.aten.addmm,
+    "extern_kernels._scaled_mm": torch.ops.aten._scaled_mm,
 }
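This one-line addition is what lets the scheduler resolve the generated name "extern_kernels._scaled_mm" back to an ATen op when estimating runtime. A hedged sketch of how such a mapping can be consumed: torch.utils.flop_counter.flop_registry is PyTorch's real FLOP-formula registry, but the lookup loop below is illustrative, not Inductor's actual estimator.

import torch
from torch.utils.flop_counter import flop_registry

# The mapping from the diff: generated extern-kernel names -> ATen ops.
kernel_name_to_op = {
    "extern_kernels.mm": torch.ops.aten.mm,
    "extern_kernels.bmm": torch.ops.aten.bmm,
    "extern_kernels.addmm": torch.ops.aten.addmm,
    "extern_kernels._scaled_mm": torch.ops.aten._scaled_mm,
}

# An extern kernel becomes "estimable" once its ATen op has a registered
# FLOP formula; runtime estimation for comms overlap leans on that.
for name, op in kernel_name_to_op.items():
    print(f"{name} -> {op} | FLOP formula registered: {op in flop_registry}")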