Don't register wrong overload to prim decomp (#163138)

These decompositions take precedence over CompositeImplicitAutograd (CIA) decomps during fake tensor propagation, so we would hit this implementation for all `where` overloads, which is wrong in some cases. For the overloads that can't be implemented by this decomp, we now just fall back to the default CIA implementation. Previously this didn't matter because in post-dispatch IR aten.where would already have been decomposed, but when a user tries to preserve aten.where the issue surfaces because fake tensor starts seeing aten.where.
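For context (a minimal illustration, not part of the diff; only existing aten overloads are used): the ternary `where` overloads select elementwise between two operands, which is what this prim decomposition implements, while the single-argument `aten.where.default` overload is documented to behave like `torch.nonzero(condition, as_tuple=True)` and returns index tensors, so it cannot be expressed by the elementwise decomposition and has to fall back to the CIA implementation.

import torch

cond = torch.tensor([[True, False], [False, True]])
a = torch.ones(2, 2)
b = torch.zeros(2, 2)

# Ternary overload: elementwise select, the semantics the prim decomposition covers.
out = torch.ops.aten.where.self(cond, a, b)
assert torch.equal(out, torch.where(cond, a, b))

# Single-argument overload: equivalent to torch.nonzero(cond, as_tuple=True);
# it returns index tensors, so the elementwise decomposition cannot implement it.
idx = torch.ops.aten.where.default(cond)
expected = torch.nonzero(cond, as_tuple=True)
assert all(torch.equal(i, j) for i, j in zip(idx, expected))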

Differential Revision: [D82604702](https://our.internmc.facebook.com/intern/diff/D82604702)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163138
Approved by: https://github.com/henryoier, https://github.com/ezyang
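As background (a read-only sketch, not part of the commit; the exact overload list may vary by PyTorch version): `aten.where` is an `OpOverloadPacket`, and registering a decomposition against the packet applies it to every overload it contains, including `where.default`. The fix instead registers only the specific overloads the decomposition can actually implement.

import torch

aten = torch.ops.aten

# The packet groups all `where` overloads; a packet-level registration covers them all.
print(type(aten.where))        # expected: <class 'torch._ops.OpOverloadPacket'>
print(aten.where.overloads())  # e.g. includes 'self', 'ScalarSelf', 'ScalarOther', 'Scalar', 'default'

# A single OpOverload with the ternary signature the prim decomposition implements.
print(aten.where.self._schema)
# where.self(Tensor condition, Tensor self, Tensor other) -> Tensor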
Authored by Tugsbayasgalan Manlaibaatar on 2025-09-17 08:39:54 -07:00; committed by PyTorch MergeBot
parent af8c232b75
commit 56893ca1f6
3 changed files with 41 additions and 4 deletions


@@ -1458,6 +1458,40 @@ graph():
        ep = export(f, args, strict=False)
        self.assertEqual(ep.module()(*args), f(*args))

    def test_where_decomp(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.where.default(x > 0)

        test_module = TestModule()
        sample_input = (torch.randn(2, 3),)

        def auto_dynamic_shapes_from_args(args):  # pyre-ignore
            """
            This function creates dynamic shapes specification with Dim.AUTO
            in all dimensions of all tensors for given argument list.
            """
            if isinstance(args, list):
                return [auto_dynamic_shapes_from_args(arg) for arg in args]
            elif isinstance(args, tuple):
                return tuple(auto_dynamic_shapes_from_args(arg) for arg in args)
            elif isinstance(args, dict):
                return {k: auto_dynamic_shapes_from_args(v) for k, v in args.items()}
            elif isinstance(args, torch.Tensor):
                return {j: Dim.AUTO for j in range(args.dim())}
            else:
                print(f"args type: {type(args)}")
                return None

        ep = torch.export.export(
            test_module,
            sample_input,
            dynamic_shapes=auto_dynamic_shapes_from_args(sample_input),
        ).run_decompositions({})

    def test_basic_non_strict_fake_tensor(self):
        class Basic(torch.nn.Module):
            def __init__(self) -> None:


@@ -190,7 +190,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
    xfail("tril"),
    xfail("triu"),
    xfail("unfold_copy"),
    xfail("where"),
    # Output has dynamic shape.
    # Does not have a meta kernel implementation.
    skip("linalg.lstsq"),


@@ -1997,9 +1997,13 @@ def clamp_max(
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(aten.where)
@out_wrapper()
# TODO: implement where.default
@register_decomposition(aten.where.self)
@register_decomposition(aten.where.ScalarSelf)
@register_decomposition(aten.where.ScalarOther)
@register_decomposition(aten.where.Scalar)
@register_decomposition(aten.where.self_out)
@out_wrapper(exact_dtype=True)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,