mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
use statically_known_true instead of guard_size_oblivious in pattern matcher (#147557)
We shouldn't add guards here. Use statically_known_true instead. Internal xref: https://fb.workplace.com/groups/1075192433118967/?multi_permalinks=1609560723015466&comment_id=1610040026300869¬if_id=1740082892544333¬if_t=work_feedback_reaction_generic&ref=notif Differential Revision: [D69950122](https://our.internmc.facebook.com/intern/diff/D69950122/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147557 Approved by: https://github.com/eellison
This commit is contained in:
parent
b246cd7b82
commit
6b44a91a62
|
|
@ -63,7 +63,7 @@ from torch._dynamo.utils import counters
|
||||||
from torch._prims_common import is_integer_dtype
|
from torch._prims_common import is_integer_dtype
|
||||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||||
from torch.fx.graph_module import _get_attr
|
from torch.fx.graph_module import _get_attr
|
||||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||||
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
||||||
|
|
@ -1345,7 +1345,7 @@ def register_replacement(
|
||||||
)
|
)
|
||||||
for v in itertools.chain(args[i].shape, args[i].stride()):
|
for v in itertools.chain(args[i].shape, args[i].stride()):
|
||||||
if isinstance(v, torch.SymInt) and all(
|
if isinstance(v, torch.SymInt) and all(
|
||||||
guard_size_oblivious(v != a) for a in sym_args
|
statically_known_true(v != a) for a in sym_args
|
||||||
):
|
):
|
||||||
sym_args.append(v)
|
sym_args.append(v)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user