pytorch/torch/_functorch/compile_utils.py
Richard Zou 52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00

91 lines
3.4 KiB
Python

import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten
# Module-level shorthand for the ``torch.ops.aten`` operator namespace.
aten = torch.ops.aten
def get_aten_target(node):
    """Return the overload-independent target of ``node``.

    If ``node.target`` exposes an ``overloadpacket`` attribute, that packet is
    returned; otherwise ``node.target`` itself is returned unchanged.
    """
    target = node.target
    return getattr(target, 'overloadpacket', target)
# ATen operator packets treated as random operations; fx_graph_cse copies
# nodes targeting these as-is and never merges them with one another.
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]
# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Build a new graph with common subexpression elimination (CSE) applied.

    Nodes of ``fx_g`` are visited in order and copied into a fresh
    ``fx.Graph``. A node is dropped (and its users rewired to an earlier
    equivalent node) when an earlier node in the new graph has the same
    target and the same args/kwargs after substitution.

    Placeholder, output and get_attr nodes, as well as nodes whose target is
    listed in ``rand_ops``, are always copied and never merged.

    Args:
        fx_g: The graph to deduplicate. Its nodes are only read and copied
            from; all new nodes are created in the returned graph.

    Returns:
        The newly built ``fx.Graph``.
    """
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash key to the first node in the new graph with that key
    token_map = {}  # map from hash key to the token of that first node

    def remap(old_node):
        # arg_transform for node_copy: old-graph node -> new-graph node
        return env[old_node]

    def substitute(arg_tree):
        # substitute args and kwargs members to their mapping in env if it exists;
        # the returned spec can be used to reconstruct nested lists/dictionaries
        leaves, spec = tree_flatten(arg_tree)
        replaced = tuple(
            env[v] if isinstance(v, torch.fx.node.Node) and v in env else v
            for v in leaves
        )
        return replaced, spec

    def with_types(leaves):
        # Pair every leaf with its concrete type. In Python 1 == 1.0 == True
        # and all three hash identically, so comparing by value alone would
        # wrongly treat e.g. op(x, 1) and op(x, 1.0) as the same expression.
        return tuple((type(v), v) for v in leaves)

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change
        # do not CSE away random operations
        if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in rand_ops:
            env[n] = new_graph.node_copy(n, remap)
            continue

        # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)
        typed_args = with_types(args)
        typed_kwargs = with_types(kwargs)

        # each token corresponds to a unique node
        # nodes with the same token can be substituted
        token = {"target": n.target, "args": typed_args, "args_spec": args_spec,
                 "kwargs": typed_kwargs, "kwargs_spec": kwargs_spec}

        # hash substituted (type-tagged) args to a number, do not hash specs
        # because specs are not hashable
        hash_val = (n.target, hash((typed_args, typed_kwargs)))

        # check if a node has a substitute and can be eliminated
        seen = hash_val in hash_env
        if seen and token_map[hash_val] == token:
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, remap)
        env[n] = new_node
        # On a hash collision with a different token the first entry is kept.
        if not seen:
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    return new_graph
def strip_overloads(gm):
    """
    Replaces every ``OpOverload`` node target in :attr:`gm` with its overload
    packet, in place, and recompiles the module afterwards.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    overload_cls = torch._ops.OpOverload
    for graph_node in gm.graph.nodes:
        current = graph_node.target
        if not isinstance(current, overload_cls):
            continue
        graph_node.target = current.overloadpacket
    gm.recompile()
def get_placeholders(graph):
    """Return a list of the nodes in ``graph.nodes`` whose op is 'placeholder', in graph order."""
    return [candidate for candidate in graph.nodes if candidate.op == 'placeholder']
def get_outputs(graph):
    """Return the flattened leaves of the first output node's first argument.

    Raises:
        AssertionError: if ``graph.nodes`` contains no node with op 'output'.
    """
    output_node = next((cand for cand in graph.nodes if cand.op == 'output'), None)
    if output_node is None:
        raise AssertionError("No output node found")
    leaves, _spec = tree_flatten(output_node.args[0])
    return leaves