[Inductor] Stop raising an error in bmm out_dtype lowering caused by template heuristics (#166457)

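Fixes https://github.com/pytorch/pytorch/issues/165892

In tuned_bmm, the Triton template is now only added to the candidate list when
out_dtype is None or equal to mat1's dtype. Previously the Triton branch
asserted that out_dtype was None and raised when it was set; that assertion
has been removed.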

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166457
Approved by: https://github.com/coconutruben
This commit is contained in:
PaulZhang12 2025-10-28 16:07:03 -07:00 committed by PyTorch MergeBot
parent 5849eea129
commit c2e3cc7aed

View File

@ -239,9 +239,10 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
templates_to_use.append(aten_handler) templates_to_use.append(aten_handler)
kwarg_overrides[aten_handler.uid] = aten_extra_kwargs kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
if use_triton_template(layout, check_max_autotune=False): if use_triton_template(layout, check_max_autotune=False) and (
out_dtype is None or out_dtype == mat1.get_dtype()
):
# TODO: add out_dtype support for Triton Template # TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
templates_to_use.append(bmm_template) templates_to_use.append(bmm_template)
# Single unified call for all templates # Single unified call for all templates