diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b6865d06930..ae628a1e3ea 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2390,22 +2390,32 @@ def as_storage_and_layout( allow_padding=allow_padding, exact_strides=exact_strides, ) - if isinstance(x, StorageBox) and isinstance(x.data, Buffer): + if isinstance(x, StorageBox): + _, layout = as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x, x.data.get_layout() + if isinstance(x, Buffer): if freeze: if want_contiguous: - x.data.freeze_layout() - assert x.data.get_layout().is_contiguous() + x.freeze_layout() + assert x.get_layout().is_contiguous() elif stride_order is not None: - x.data.freeze_layout_with_stride_order( + x.freeze_layout_with_stride_order( stride_order, allow_padding=allow_padding ) elif exact_strides is not None: - x.data.freeze_layout_with_exact_strides( + x.freeze_layout_with_exact_strides( exact_strides, allow_padding=allow_padding ) else: - x.data.decide_layout() - return x, x.data.get_layout() + x.decide_layout() + return StorageBox(x), x.get_layout() if isinstance(x, ReinterpretView): # making the base of x contiguous or stride_ordered will not necessarily make # the ReinterpretView either, so don't pass along those arguments diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index ad99a1ecc45..8e29fd586d8 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -938,6 +938,7 @@ kernel_name_to_op = { "extern_kernels.mm": torch.ops.aten.mm, "extern_kernels.bmm": torch.ops.aten.bmm, "extern_kernels.addmm": torch.ops.aten.addmm, + "extern_kernels._scaled_mm": torch.ops.aten._scaled_mm, }