Include cublasLt as an option in max_autotune mode (#92915)

Differential Revision: D42720376 (has some internal results)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92915
Approved by: https://github.com/Chillee
Author: Jason Ansel (committed by PyTorch MergeBot)
Date: 2023-01-26 06:08:17 +00:00
Parent: d88bc38b0c
Commit: ccad2e5000

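For context, the code path changed below is only exercised when Inductor's max-autotune GEMM selection is enabled. A minimal usage sketch follows (not part of this commit; the model, dtype, and shapes are illustrative) showing one way to route a bias-add matmul through the tuned addmm lowering:

    import torch

    # Compile with the max-autotune preset so Inductor benchmarks GEMM choices.
    model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
    x = torch.randn(2048, 1024, device="cuda", dtype=torch.float16)
    compiled = torch.compile(model, mode="max-autotune")
    out = compiled(x)

    # Alternatively, the Inductor config flag can be set directly:
    # torch._inductor.config.max_autotune = True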

@@ -71,9 +71,25 @@ mm_template = TritonTemplate(
 )
 
 aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
 aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")
 
 
+def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
+    """
+    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
+    kernel under the hood. There are a few shapes where this is slower,
+    but they are rare.
+    """
+    if inp.stride(0) == 0 or inp.size(0) == 1:
+        return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
+    return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
+
+
+aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
+
+
 @register_lowering(aten.mm)
 def tuned_mm(mat1, mat2, *, layout=None):
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
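A small illustration (not from the commit) of why the stride(0) == 0 check in bias_addmm identifies a row-broadcast bias that can be "unexpanded" back to 1D with inp[0]:

    import torch

    bias = torch.randn(8)             # original 1D bias
    expanded = bias.expand(4, 8)      # broadcast view, analogous to inp_expanded below
    print(expanded.stride())          # (0, 1): dim 0 is a zero-stride broadcast
    print(torch.equal(expanded[0], bias))  # True: indexing row 0 recovers the 1D bias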
@@ -96,26 +112,31 @@ def tuned_mm(mat1, mat2, *, layout=None):
 @register_lowering(aten.addmm)
 def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
-    # don't expand inp to make sure fused addmm from cublasLt is used
     if not use_triton_template(layout):
         choices = [aten_addmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
         return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
 
-    # TODO this is not quite fair benchmarking because we won't use fused cublasLt addmm
     # options to tune from
     choices = [
         aten_addmm.bind((inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta)
     ]
-    if use_triton_template(layout):
-        for config in mm_configs():
-            choices.append(
-                mm_template.generate(
-                    (inp_expanded, mat1, mat2),
-                    layout,
-                    **mm_options(config, k, layout),
-                    prefix_args=1,
-                    epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
-                )
+    if inp_expanded.get_stride()[0] == 0 and inp_expanded.get_device().type == "cuda":
+        # unexpand inp to make sure fused addmm from cublasLt is used
+        choices.insert(
+            0,
+            aten_bias_addmm.bind(
+                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
+            ),
+        )
+    for config in mm_configs():
+        choices.append(
+            mm_template.generate(
+                (inp_expanded, mat1, mat2),
+                layout,
+                **mm_options(config, k, layout),
+                prefix_args=1,
+                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
+            )
             )
+        )
     return autotune_select_algorithm(choices, [inp_expanded, mat1, mat2], layout)
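A rough benchmark sketch of the claim in the bias_addmm docstring (hypothetical, not part of this commit; shapes, dtype, and results are illustrative and GPU-dependent): passing the bias as a 1D tensor lets torch.addmm dispatch to the fused cublasLt epilogue, while the stride-0 expanded 2D view used for the Triton templates does not.

    import torch

    def bench(fn, iters=100):
        # crude CUDA event timing with a short warmup
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(10):
            fn()
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    bias = torch.randn(4096, device="cuda", dtype=torch.float16)
    bias_expanded = bias.expand(4096, 4096)  # stride-0 broadcast, like inp_expanded

    print("1D bias:      ", bench(lambda: torch.addmm(bias, a, b)), "ms")
    print("expanded bias:", bench(lambda: torch.addmm(bias_expanded, a, b)), "ms")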