mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578 Approved by: https://github.com/ezyang
31 lines
1.0 KiB
Python
31 lines
1.0 KiB
Python
import torch
|
|
|
|
|
|
# Check whether the pattern (nn.Module, F.function/torch.Tensor.method) is matched.
|
|
# Works for length 2 patterns with 1 module and 1 function/method.
|
|
def matches_module_function_pattern(pattern, node, modules):
    """Return whether ``node`` and its first argument match ``pattern``.

    The match is a two-node chain: a ``call_module`` node feeding, as the
    first positional argument, a ``call_function`` or ``call_method`` node.

    Args:
        pattern: Indexable with two entries. ``pattern[0]`` is the module
            class; it is compared against the exact ``type(...)`` of the
            module, so subclasses do not match. ``pattern[1]`` is compared
            against ``node.target`` with ``!=``.
        node: The candidate function/method node. Anything that is not a
            ``torch.fx.Node`` yields ``False``.
        modules: Mapping from ``call_module`` target strings to module
            instances.

    Returns:
        ``True`` only when every check passes, otherwise ``False``.
    """
    # Validate `node` before touching its attributes, so a non-Node argument
    # is rejected instead of raising AttributeError on `.args`.
    if not isinstance(node, torch.fx.Node):
        return False
    if not node.args:
        return False

    producer = node.args[0]
    if not isinstance(producer, torch.fx.Node):
        return False

    # The first node must be a call_module whose target resolves in `modules`
    # to an instance of exactly the expected class.
    if producer.op != "call_module":
        return False
    if not isinstance(producer.target, str):
        return False
    if producer.target not in modules:
        return False
    if type(modules[producer.target]) is not pattern[0]:
        return False

    # The second node must be a call_function or call_method with the
    # expected target.
    if node.op not in ("call_function", "call_method"):
        return False
    if node.target != pattern[1]:
        return False

    # The module's output must not be used by any node other than `node`.
    return len(producer.users) <= 1
|