aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:
```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```
- CSE would de-duplicate the two empty nodes
- reinplacing would then add an additional clone (because it can't write to both tensors at the same time)
- the clone lowers into a new buffer plus a copy_ kernel
- the copy_ kernel is unnecessary because "empty" is special: all reinplacing needed was an additional buffer, and it doesn't matter what the values are.
We could attempt to fix this on the reinplacing side, but it seemed better handled as a partitioner heuristic; the reinplacing fix is also trickier (we'd need to identify that the op never reads from the empty node).
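To make the heuristic concrete, here is an illustrative sketch (not taken from the PR; it assumes `fx_graph_cse`, defined in the file below, is importable as `torch._functorch.compile_utils.fx_graph_cse`, and the function `g` is made up for demonstration). Duplicate `aten.empty` calls survive CSE, while an ordinary op like `sin` is still de-duplicated:
```py
# Illustrative only; the module path is an assumption about where this file lives.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import fx_graph_cse

def g(x):
    a = torch.empty(x.shape)  # aten.empty: kept even though it duplicates `b`
    b = torch.empty(x.shape)
    c = x.sin()               # ordinary op: the duplicate below is CSE'd away
    d = x.sin()
    return a, b, c, d

gm = make_fx(g)(torch.randn(3))
new_graph = fx_graph_cse(gm.graph)
print(new_graph)  # expect two empty nodes to remain, but only one sin node
```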
Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
helped).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134703
Approved by: https://github.com/yf225
ghstack dependencies: #134466, #134490, #134491
144 lines · 4.8 KiB · Python
# mypy: ignore-errors

from typing import Callable

import torch
import torch.fx as fx
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten


aten = torch.ops.aten


def get_aten_target(node: fx.Node) -> Callable:
    if hasattr(node.target, "overloadpacket"):
        return node.target.overloadpacket
    return node.target


rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]

# Return a new copy of torch.fx.graph.Graph with CSE applied to the input graph.
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token

    from torch._inductor.pattern_matcher import (
        compute_mutation_region_ids,
        same_mutation_regions,
    )

    compute_mutation_region_ids(fx_g)  # type: ignore[arg-type]

    for n in fx_g.nodes:
        # Placeholder, output, and get_attr nodes are copied to the new graph without change;
        # do not CSE away random operations.
        if (
            n.op == "placeholder"
            or n.op == "output"
            or n.op == "get_attr"
            or get_aten_target(n) in rand_ops
            # aten.empty is non-deterministic, so don't CSE it.
            # Also, aten.empty is almost always fusible into its consumer,
            # so it's not worth CSEing.
            or get_aten_target(n) is aten.empty
        ):
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function'; we should never see 'call_module' or 'call_method'
            # Substitute args and kwargs members with their mapping in env if one exists;
            # the specs can be used to reconstruct nested lists/dictionaries.
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                    if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                        arg_list[i] = v.node
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # Each token corresponds to a unique node;
            # nodes with the same token can be substituted.
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }

            # Hash the substituted args to a number; do not hash the specs because specs are not hashable.
            # We need to add the type into the hash to avoid situations like:
            # hash((primals_2, 1.0)) == hash((primals_2, 1))
            hash_arg = hash(
                (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
            )
            hash_val = (n.target, hash_arg)

            # Check if the node has a substitute and can be eliminated.
            hash_val_in_hash_env = hash_val in hash_env
            overwrite_due_to_mutation = False
            if hash_val_in_hash_env and token_map[hash_val] == token:
                duplicate_n_prev = hash_env[hash_val]
                if same_mutation_regions(n, duplicate_n_prev):
                    env[n] = duplicate_n_prev
                    continue
                else:
                    # Any future duplicates should be replaced with n, not duplicate_n_prev.
                    overwrite_due_to_mutation = True

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if overwrite_due_to_mutation or not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    return new_graph

def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input Fx graph module to be modified
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


def get_placeholders(graph):
    return graph.find_nodes(op="placeholder")


def get_outputs(graph):
    for node in graph.find_nodes(op="output"):
        return pytree.tree_leaves(node.args[0])
    raise AssertionError("No output node found")
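For reference, a minimal usage sketch of the remaining helpers (illustrative only, not part of the file above; the import path and the function `h` are assumptions for demonstration):
```py
# Illustrative sketch; assumes the helpers above are importable from
# torch._functorch.compile_utils (the presumed home of this file).
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import (
    get_outputs,
    get_placeholders,
    strip_overloads,
)

def h(a, b):
    return a * b + b

gm = make_fx(h)(torch.randn(2), torch.randn(2))
strip_overloads(gm)                      # e.g. aten.mul.Tensor -> aten.mul
print(list(get_placeholders(gm.graph)))  # the two input placeholder nodes
print(get_outputs(gm.graph))             # flattened list of output nodes
```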