This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs (coming soon) and resolve some internal build issues

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
91 lines
3.4 KiB
Python
import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten

aten = torch.ops.aten

def get_aten_target(node):
    # Normalize an OpOverload target (e.g. aten.add.Tensor) to its OverloadPacket (aten.add)
    if hasattr(node.target, 'overloadpacket'):
        return node.target.overloadpacket
    return node.target

rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma,
            aten.bernoulli, aten.multinomial, aten.native_dropout,
            aten.normal, aten.poisson, aten.binomial, aten.rrelu,
            aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]

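# Illustrative sketch (not part of the original file): get_aten_target lets membership
# tests against lists like rand_ops work regardless of which ATen overload was recorded
# during tracing. The helper name below is hypothetical, shown only for clarity.
def _demo_is_random_op(node: fx.Node) -> bool:
    # True if `node` calls one of the random ops listed above, in any overload form
    return get_aten_target(node) in rand_ops
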
# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token
    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change.
        # Do not CSE away random operations.
        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in rand_ops:
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function'; we should never see 'call_module' or 'call_method' here
            # substitute members of args and kwargs with their mapping in env, if one exists
            # specs can be used to reconstruct nested lists/dictionaries
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                return tuple(arg_list), spec
            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {"target": n.target, "args": args, "args_spec": args_spec,
                     "kwargs": kwargs, "kwargs_spec": kwargs_spec}

            # hash the substituted args to a number; do not hash specs because specs are not hashable
            hash_arg = hash((args, kwargs))
            hash_val = (n.target, hash_arg)

            # check if a node has a substitute and can be eliminated
            hash_val_in_hash_env = hash_val in hash_env
            if hash_val_in_hash_env and token_map[hash_val] == token:
                env[n] = hash_env[hash_val]
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    return new_graph

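# Illustrative usage sketch (not part of the original file): fx_graph_cse takes an
# fx.Graph and returns a new graph with duplicate call_function nodes merged. The demo
# helper name below is hypothetical; it shows one way to wrap the result back into a
# GraphModule so it can be executed.
def _demo_fx_graph_cse():
    def f(x):
        # the two identical torch.cos calls should be merged into a single node
        return torch.cos(x) + torch.cos(x)

    traced = fx.symbolic_trace(f)
    csed_graph = fx_graph_cse(traced.graph)
    return fx.GraphModule(traced, csed_graph)
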
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()

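# Illustrative sketch (not part of the original file): strip_overloads is typically
# applied to a GraphModule traced with make_fx, whose nodes carry ATen OpOverload
# targets (e.g. aten.sin.default); after stripping, targets are OverloadPackets
# (aten.sin), which can simplify target matching. The helper name is hypothetical.
def _demo_strip_overloads():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.sin(x) + 1

    gm = make_fx(f)(torch.randn(3))
    strip_overloads(gm)  # modifies gm in place and recompiles it
    return gm
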
def get_placeholders(graph):
    # Return all placeholder (input) nodes of the graph, in order
    return list(filter(lambda x: x.op == 'placeholder', graph.nodes))

def get_outputs(graph):
    # Return the flattened list of nodes fed into the graph's output node
    for node in graph.nodes:
        if node.op == 'output':
            return tree_flatten(node.args[0])[0]
    raise AssertionError("No output node found")
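
# Illustrative sketch (not part of the original file): get_placeholders / get_outputs
# are small conveniences for inspecting a graph's inputs and outputs. The helper name
# below is hypothetical, included only to show intended usage.
def _demo_inspect_graph():
    def f(x, y):
        return torch.cos(x) + torch.sin(y)

    traced = fx.symbolic_trace(f)
    inputs = get_placeholders(traced.graph)   # the [x, y] placeholder nodes
    outputs = get_outputs(traced.graph)       # the node(s) returned by the graph
    return inputs, outputs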