pytorch/torch/_inductor/constant_folding.py
Aaron Orenstein d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without
conflicting with some other intermediate change, so these 'type: ignore'
comments are being landed in pytorch ahead of when they are actually needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00


import collections
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
aten = torch.ops.aten
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
# The use case and more information can be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"
def replace_node_with_constant(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
constant: torch.Tensor,
name: Optional[str] = None,
) -> None:
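    """
    Replace ``node`` with a ``get_attr`` node that reads ``constant``, which is
    registered on ``gm`` as a frozen buffer (named ``name`` if given, otherwise
    ``_frozen_param<N>``).

    A minimal usage sketch (``gm`` and ``node`` stand in for a traced
    GraphModule and one of its nodes; they are illustrative, not part of this
    module)::

        folded = torch.ones(4)
        replace_node_with_constant(gm, node, folded)
        # users of `node` now read the new `_frozen_param*` buffer via get_attr
    """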
g = gm.graph
if name:
qualname = name
else:
if not hasattr(gm, "_frozen_param_count"):
gm._frozen_param_count = 0 # type: ignore[assignment]
i = gm._frozen_param_count
while True:
qualname = f"_frozen_param{i}"
if not hasattr(gm, qualname):
break
i += 1
gm._frozen_param_count = i + 1
with g.inserting_before(node):
new_input_node = g.create_node("get_attr", qualname, (), {})
node.replace_all_uses_with(new_input_node)
new_input_node.meta.update(node.meta)
g.erase_node(node)
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
gm.register_buffer(qualname, constant)
setattr(gm, qualname, constant)
def is_const_source(
node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
) -> bool:
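    """
    Return True if ``node`` supplies a constant value: either a ``get_attr``
    node or a placeholder whose name appears in ``lifted_constants``.
    """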
return node.op == "get_attr" or (
node.op == "placeholder"
and lifted_constants is not None
and node.name in lifted_constants
)
class ConstantFolder(torch.fx.Interpreter):
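    """
    An ``fx.Interpreter`` that evaluates nodes whose inputs are all known
    constants and records the resulting tensors in ``self.node_replacements``,
    while propagating an ``unknown_value`` sentinel through everything that
    cannot be computed ahead of time.
    """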
def __init__(
self,
gm: torch.fx.GraphModule,
skip_constructors: bool = False,
lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object()
self.skip_constructors: bool = skip_constructors
# overwrite this to deallocate env values if their only remaining use
# is the output
self.user_to_last_uses = self.node_to_last_non_output_use()
self.lifted_constants = lifted_constants
def _support_dynamic_shape(self) -> bool:
        # ConstantFolder does not support dynamic shapes yet
return False
def _deduce_value(self, node: torch.fx.Node) -> Any:
return super().run_node(node)
def is_impure(self, node: torch.fx.node.Node) -> bool:
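        """
        Treat selected quantization patterns as impure so that folding does not
        materialize higher-precision copies of quantized weights.
        """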
if (
node.target == torch.ops.prims.convert_element_type.default
and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type]
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
and node.args[1] == torch.bfloat16
):
            # int8_weight -> dq -> bf16_weight: keep this cast in the graph so
            # the weight stays stored as int8 instead of being folded to bf16.
            return True
if node.target in [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8_weight and leave dq in the graph
            # to be fused.
return True
return False
def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
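        """
        Map each node to the nodes whose last non-output use occurs at that
        node, so their env entries can be deallocated as soon as it has run.
        """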
last_non_output_use = collections.defaultdict(list)
seen_uses = set()
output_node = next(iter(reversed(self.module.graph.nodes)))
for node in reversed(self.module.graph.nodes):
if node.target == "output":
continue
def add_use(inp: torch.fx.Node) -> None:
if inp in seen_uses:
return
seen_uses.add(inp)
last_non_output_use[node].append(inp)
# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
            # if this node is only used by the output, we want to gc it right away
if len(node.users) == 1 and output_node in node.users:
last_non_output_use[node].append(node)
return last_non_output_use
def run_node(self, node: torch.fx.Node) -> Any:
if node.target == "output":
            # Because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter.
def set_env(arg: torch.fx.Node) -> None:
self.env[arg] = self.unknown_value
# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
return super().run_node(node)
args, kwargs = self.fetch_args_kwargs_from_env(node)
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
# We need to do this weird thing because in cases where flattened_inputs
# contains a ScriptObject, equality checking results in a type error if
# the types are different.
if any(
type(self.unknown_value) == type(input_) and self.unknown_value == input_
for input_ in flattened_inputs
):
return self.unknown_value
# TODO - fix errors with this
if (
node.op == "call_function"
and node.target == aten._efficientzerotensor.default
):
return self.unknown_value
# TODO - constant folding triton kernel returns the inputs -- fix this
if (
node.op == "call_function"
and node.name == "triton_kernel_wrapper_functional_proxy"
):
return self.unknown_value
        # Skip constructors, since inductor already generates optimal code for
        # them and turning them into tensor constants would add an extra global
        # memory read.
        # TODO - more complicated strategy
if (
self.skip_constructors
and not is_const_source(node, self.lifted_constants)
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
):
return self.unknown_value
# All mutations should either be removed or on inputs which we did not make constant
if (
isinstance(node.target, torch._ops.OpOverload)
and torch.Tag.nondeterministic_seeded in node.target.tags
):
return self.unknown_value
out = self._deduce_value(node)
if out == self.unknown_value:
return self.unknown_value
if not is_const_source(node, self.lifted_constants) and isinstance(
out, torch.Tensor
):
if out.device.type == "meta":
return out
if not self.insertable_tensor_check(out):
return out
if self.is_impure(node):
return self.unknown_value
self.add_node_replacement(node, out)
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
for n in flattened_node_inps:
if not isinstance(n, torch.fx.Node):
continue
self.replaced_uses[n] += 1
for to_delete in self.user_to_last_uses.get(node, []):
if self.replaced_uses[to_delete] == len(to_delete.users):
self.node_replacements.pop(to_delete, None)
return out
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
return True
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor
def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env)
return super().run(initial_env=env)
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
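        """
        Seed the interpreter environment: placeholders backed by
        ``lifted_constants`` get their tensors, all other placeholders get the
        ``unknown_value`` sentinel.
        """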
for n in self.module.graph.find_nodes(op="placeholder"):
if self.lifted_constants is not None and n.name in self.lifted_constants:
env[n] = self.lifted_constants[n.name]
else:
env[n] = self.unknown_value # type: ignore[assignment]
def constant_fold(
gm: torch.fx.GraphModule,
constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
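    """
    Fold constant-computable nodes of ``gm`` in place, replacing each folded
    node with a ``get_attr`` to a newly registered frozen buffer. An optional
    ``constraint_fn`` can veto folding of individual nodes.

    A minimal usage sketch (``M`` is a hypothetical module used only for
    illustration; it assumes parameters appear as ``get_attr`` nodes after
    tracing)::

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(torch.randn(4, 4))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # `self.w * 2` depends only on a constant source, so it can be folded
                return x @ (self.w * 2)

        gm = torch.fx.symbolic_trace(M())
        constant_fold(gm)  # `self.w * 2` becomes a frozen buffer read by get_attr
    """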
with torch.utils._python_dispatch._disable_current_modes():
cf = ConstantFolder(gm, skip_constructors=True)
cf.run()
for node, constant in cf.node_replacements.items():
if constraint_fn is not None and not constraint_fn(node):
continue
replace_node_with_constant(gm, node, constant)
erased_params = []
for node in gm.graph.find_nodes(op="get_attr"):
if len(node.users) == 0:
if hasattr(gm, node.target):
delattr(gm, node.target)
erased_params.append(node)
for node in erased_params:
gm.graph.erase_node(node)
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
def constant_graph_tag(
gm: torch.fx.GraphModule,
lifted_constants: Optional[Dict[str, Any]],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> None:
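    """
    Run ``ConstantFolder`` over ``gm`` and tag every node's ``meta[META_TAG]``
    with either ``CONST_MODULE_TAG`` (foldable, belongs to the constant
    subgraph) or ``MODULE_TAG`` (stays in the main module).
    """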
with torch.utils._python_dispatch._disable_current_modes():
cf = ConstantFolder(
gm, skip_constructors=True, lifted_constants=lifted_constants
)
cf.run()
for node in gm.graph.nodes:
if skip_folding_node_fn is not None and skip_folding_node_fn(node):
node.meta[META_TAG] = MODULE_TAG
continue
if (
is_const_source(node, lifted_constants)
or node in cf.node_replacements
or node in cf.replaced_uses
):
node.meta[META_TAG] = CONST_MODULE_TAG
else:
node.meta[META_TAG] = MODULE_TAG
def run_and_get_constant_graph(
gm: torch.fx.GraphModule,
lifted_constants: Optional[Dict[str, Any]],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
"""
    Construct a GraphModule corresponding to the portion of the provided gm
    that can be constant folded.
"""
constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)
def untag(node: torch.fx.Node) -> bool:
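        """
        Return True if some user of ``node`` was tagged for the constant
        graph; otherwise re-tag ``node`` for the main module and return False.
        """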
used_to_fold = False
for u in node.users:
if u.meta[META_TAG] == CONST_MODULE_TAG:
used_to_fold = True
break
if not used_to_fold:
node.meta[META_TAG] = MODULE_TAG
return used_to_fold
const_args = []
if lifted_constants is not None:
placeholders = list(gm.graph.find_nodes(op="placeholder"))
for node in placeholders:
if node.meta[META_TAG] == MODULE_TAG:
continue
if untag(node):
const_args.append(lifted_constants[node.name])
    # Rewrite the tags: if a constant is consumed directly, with no folding
    # opportunity, keep it in the main gm.
for node in gm.graph.find_nodes(op="get_attr"):
untag(node)
new_graph = torch.fx.Graph()
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
output_nodes = []
for node in gm.graph.nodes:
if node.meta[META_TAG] == MODULE_TAG:
continue
new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
node_remapping[node] = new_node
for user in node.users:
if user.meta[META_TAG] == MODULE_TAG:
output_nodes.append(new_node)
break
new_graph.output(tuple(output_nodes))
new_graph.lint()
new_gm = torch.fx.GraphModule(gm, new_graph)
const_result = new_gm(*const_args)
return new_gm, const_result
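
# A minimal usage sketch for run_and_get_constant_graph (``gm`` and the ``None``
# arguments below are placeholders, shown only for illustration). ``const_gm``
# is the extracted foldable subgraph and ``const_outputs`` holds its evaluated
# results, in the same order as the outputs of ``const_gm``:
#
#     const_gm, const_outputs = run_and_get_constant_graph(
#         gm, lifted_constants=None, skip_folding_node_fn=None
#     )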