mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Include cublasLt as an option in max_autotune mode (#92915)
Differential Revision: D42720376 (has some internal results)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92915
Approved by: https://github.com/Chillee
This commit is contained in:
parent
d88bc38b0c
commit
ccad2e5000
|
|
@ -71,9 +71,25 @@ mm_template = TritonTemplate(
|
|||
)
|
||||
|
||||
# Extern-kernel choice for plain matmul: wraps torch.mm and carries the
# name "at::mm_out" (presumably the C++ out-variant emitted by codegen --
# TODO confirm against ExternKernelChoice).
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
|
||||
|
||||
|
||||
# Extern-kernel choice for addmm: wraps torch.addmm and carries the name
# "at::addmm_out" (presumably the C++ out-variant emitted by codegen --
# TODO confirm against ExternKernelChoice).
aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")
|
||||
|
||||
|
||||
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    Call ``torch.addmm``, handing it a single row of ``inp`` whenever
    ``inp`` is row-broadcast.

    Passing the bias to torch.addmm as a 1D tensor selects a different
    (faster) cublasLt kernel under the hood.  A few shapes are slower
    on that path, but they are rare.

    Args:
        inp: bias operand; treated as row-broadcast when its dim-0
            stride is 0 or its dim-0 size is 1.
        mat1, mat2: matrices forwarded unchanged to ``torch.addmm``.
        out: optional output tensor forwarded to ``torch.addmm``.
        alpha, beta: scaling factors forwarded to ``torch.addmm``.

    Returns:
        Whatever ``torch.addmm`` returns for the selected bias operand.
    """
    # Every row is identical in this case, so the first row alone
    # carries all the information torch.addmm needs to broadcast.
    rows_are_broadcast = inp.stride(0) == 0 or inp.size(0) == 1
    bias = inp[0] if rows_are_broadcast else inp
    return torch.addmm(bias, mat1, mat2, out=out, alpha=alpha, beta=beta)
|
||||
|
||||
|
||||
# Extern-kernel choice wrapping the Python-level bias_addmm helper.  None is
# passed where aten_mm / aten_addmm pass an "at::..." name (presumably there
# is no C++ kernel name for this wrapper -- TODO confirm).
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
|
||||
|
||||
|
||||
@register_lowering(aten.mm)
|
||||
def tuned_mm(mat1, mat2, *, layout=None):
|
||||
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
|
|
@ -96,26 +112,31 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
@register_lowering(aten.addmm)
|
||||
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
|
||||
# don't expand inp to make sure fused addmm from cublasLt is used
|
||||
if not use_triton_template(layout):
|
||||
choices = [aten_addmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
|
||||
return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
|
||||
|
||||
# TODO this is not quite fair benchmarking because we won't use fused cublasLt addmm
|
||||
# options to tune from
|
||||
choices = [
|
||||
aten_addmm.bind((inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta)
|
||||
]
|
||||
if use_triton_template(layout):
|
||||
for config in mm_configs():
|
||||
choices.append(
|
||||
mm_template.generate(
|
||||
(inp_expanded, mat1, mat2),
|
||||
layout,
|
||||
**mm_options(config, k, layout),
|
||||
prefix_args=1,
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
)
|
||||
if inp_expanded.get_stride()[0] == 0 and inp_expanded.get_device().type == "cuda":
|
||||
# unexpand inp to make sure fused addmm from cublasLt is used
|
||||
choices.insert(
|
||||
0,
|
||||
aten_bias_addmm.bind(
|
||||
(inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
|
||||
),
|
||||
)
|
||||
|
||||
for config in mm_configs():
|
||||
choices.append(
|
||||
mm_template.generate(
|
||||
(inp_expanded, mat1, mat2),
|
||||
layout,
|
||||
**mm_options(config, k, layout),
|
||||
prefix_args=1,
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
)
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(choices, [inp_expanded, mat1, mat2], layout)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user