mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Don't register wrong overload to prim decomp (#163138)
These decompositions take precedence before CIA decomps in fake tensor prop, as a result, we would hit this implementation for all where overloads which is wrong in some cases. For the overloads that can't be implemented by this decomp, we just run the default CIA impl. Previously this doesn't matter because in post-dispatch IR, aten.where would have decomposed but when user tries to preserve aten.where this issue will surface because fake tensor will start seeing aten.where. Differential Revision: [D82604702](https://our.internmc.facebook.com/intern/diff/D82604702) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163138 Approved by: https://github.com/henryoier, https://github.com/ezyang
This commit is contained in:
parent
af8c232b75
commit
56893ca1f6
|
|
@ -1458,6 +1458,40 @@ graph():
|
|||
ep = export(f, args, strict=False)
|
||||
self.assertEqual(ep.module()(*args), f(*args))
|
||||
|
||||
def test_where_decomp(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.where.default(x > 0)
|
||||
|
||||
test_module = TestModule()
|
||||
sample_input = (torch.randn(2, 3),)
|
||||
|
||||
def auto_dynamic_shapes_from_args(args): # pyre-ignore
|
||||
"""
|
||||
This function creates dynamic shapes specification with Dim.AUTO
|
||||
in all dimensions of all tensors for given argument list.
|
||||
"""
|
||||
if isinstance(args, list):
|
||||
return [auto_dynamic_shapes_from_args(arg) for arg in args]
|
||||
elif isinstance(args, tuple):
|
||||
return tuple(auto_dynamic_shapes_from_args(arg) for arg in args)
|
||||
elif isinstance(args, dict):
|
||||
return {k: auto_dynamic_shapes_from_args(v) for k, v in args.items()}
|
||||
elif isinstance(args, torch.Tensor):
|
||||
return {j: Dim.AUTO for j in range(args.dim())}
|
||||
else:
|
||||
print(f"args type: {type(args)}")
|
||||
return None
|
||||
|
||||
ep = torch.export.export(
|
||||
test_module,
|
||||
sample_input,
|
||||
dynamic_shapes=auto_dynamic_shapes_from_args(sample_input),
|
||||
).run_decompositions({})
|
||||
|
||||
def test_basic_non_strict_fake_tensor(self):
|
||||
class Basic(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -190,7 +190,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
|
|||
xfail("tril"),
|
||||
xfail("triu"),
|
||||
xfail("unfold_copy"),
|
||||
xfail("where"),
|
||||
# Output has dynamic shape.
|
||||
# Does not have a meta kernel implementation.
|
||||
skip("linalg.lstsq"),
|
||||
|
|
|
|||
|
|
@ -1997,9 +1997,13 @@ def clamp_max(
|
|||
|
||||
|
||||
# https://pytorch.org/docs/stable/generated/torch.where.html
|
||||
# TODO: implement alternate where
|
||||
@register_decomposition(aten.where)
|
||||
@out_wrapper()
|
||||
# TODO: implement where.default
|
||||
@register_decomposition(aten.where.self)
|
||||
@register_decomposition(aten.where.ScalarSelf)
|
||||
@register_decomposition(aten.where.ScalarOther)
|
||||
@register_decomposition(aten.where.Scalar)
|
||||
@register_decomposition(aten.where.self_out)
|
||||
@out_wrapper(exact_dtype=True)
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a", "b"),
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user