mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] No longer throw error in bmm out_dtype lowering due to template heuristics (#166457)
Fixes https://github.com/pytorch/pytorch/issues/165892 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166457 Approved by: https://github.com/coconutruben
This commit is contained in:
parent
5849eea129
commit
c2e3cc7aed
|
|
@ -239,9 +239,10 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
templates_to_use.append(aten_handler)
|
||||
kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
|
||||
|
||||
if use_triton_template(layout, check_max_autotune=False):
|
||||
if use_triton_template(layout, check_max_autotune=False) and (
|
||||
out_dtype is None or out_dtype == mat1.get_dtype()
|
||||
):
|
||||
# TODO: add out_dtype support for Triton Template
|
||||
assert out_dtype is None, "out_dtype is not supported for Triton"
|
||||
templates_to_use.append(bmm_template)
|
||||
|
||||
# Single unified call for all templates
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user