pytorch/torch/_inductor/constant_folding.py
Oguz Ulgen 01b979fc9a [Inductor] Fix constant folding and extern kernel mutation tracking bugs (#115908)
This PR fixes two bugs:
1) Constant folding a triton kernel results in the kernel's inputs being returned back unmodified. Constant folding is disabled for triton kernels for now (see the triton_kernel_wrapper_functional_proxy check in run_node below); this needs more investigation.
2) NoneLayout buffers should not be deleted, as they do not exist.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115908
Approved by: https://github.com/aakhundov, https://github.com/jansel
2023-12-19 02:06:50 +00:00


import collections
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten
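

# Replace `node` with a get_attr to a newly registered buffer holding `constant`,
# named with the first unused `_frozen_param{i}` qualname on `gm`.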
def replace_node_with_constant(gm, node, constant):
    g = gm.graph

    if not hasattr(gm, "_frozen_param_count"):
        gm._frozen_param_count = 0

    i = gm._frozen_param_count

    while True:
        qualname = f"_frozen_param{i}"
        if not hasattr(gm, qualname):
            break
        i += 1

    gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


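# Interpreter that runs the graph with placeholder inputs treated as unknown and
# records, in `node_replacements`, each tensor-valued node whose result depends
# only on constants.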
class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8 weight and leave dq in the graph
            # to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            pytree.tree_map_only(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        if self.unknown_value in flattened_inputs:
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding a triton kernel returns its inputs unmodified -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning them into tensors would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or be on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = self.unknown_value
        return super().run(initial_env=env)


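# Fold constants in `gm` in place: run ConstantFolder, register each folded value
# as a frozen buffer (subject to `constraint_fn`), then erase get_attr nodes that
# are left without users and recompile.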
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
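
As a rough illustration (not part of the file above), here is a minimal sketch of how this pass might be exercised on a toy torch.fx graph; the module name Toy and the expectation that the folded buffer is named _frozen_param0 are assumptions made for the example, not guarantees:

import torch
import torch.fx
from torch._inductor.constant_folding import constant_fold


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # (self.w + self.w) depends only on a parameter, so it is foldable;
        # the outer add still depends on the runtime input x.
        return x + (self.w + self.w)


gm = torch.fx.symbolic_trace(Toy())
constant_fold(gm)
# The (w + w) node should now be a get_attr to a frozen buffer
# (typically _frozen_param0), while the add with x remains in the graph.
print(gm.graph)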