Summary: This PR adds constant folding for quantize ops so that instead of storing the fp32 weight in the quantized model, we get the int8/int16 etc. weight.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_fold_quantize (will also be verified in ExecuTorch later)

Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343
Approved by: https://github.com/kimishpatel, https://github.com/jgong5
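To make the transformation concrete, here is a minimal sketch of the graph before and after folding (the weight names are illustrative; the dequantize targets are the ones matched in `is_impure` below):

# Before folding, the fp32 weight is stored and quantized on every run:
#
#     fp32_weight (get_attr)
#         -> quantized_decomposed.quantize_per_tensor
#         -> quantized_decomposed.dequantize_per_tensor
#         -> aten.linear
#
# After folding, the quantize call has been evaluated at compile time, so
# only the int8 result is stored; `is_impure` keeps the dequantize node in
# the graph so that a backend can still fuse it:
#
#     int8_weight (get_attr, the folded constant)
#         -> quantized_decomposed.dequantize_per_tensor
#         -> aten.linear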
import collections
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten

def replace_node_with_constant(gm, node, constant):
    """Replace `node` with a `get_attr` to a freshly registered constant."""
    g = gm.graph

    if not hasattr(gm, "_frozen_param_count"):
        gm._frozen_param_count = 0

    i = gm._frozen_param_count

    while True:
        qualname = f"_frozen_param{i}"
        if not hasattr(gm, qualname):
            break
        i += 1

    gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)

class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into int8_weight, and leave dq in the graph
            # to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            pytree.tree_map_only(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.tree_flatten((args, kwargs))[0]

        if self.unknown_value in flattened_inputs:
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them
        # already, and turning them into tensors would result in an additional
        # global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.tree_flatten((node.args, node.kwargs))[0]

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    # hook for subclasses to restrict which tensors get folded
    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        # seed placeholders as unknown so anything that depends on a runtime
        # input is never folded
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = self.unknown_value
        return super().run(initial_env=env)

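# `insertable_tensor_check` and `add_node_replacement` above are subclass
# hooks. A hedged sketch of how a caller might restrict folding (the class
# name and size cutoff are invented for illustration, not part of this PR):
#
#     class SmallConstantFolder(ConstantFolder):
#         def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
#             # skip large tensors to avoid bloating the module with
#             # materialized constants
#             return tensor.numel() <= 2**20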
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
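
# A minimal usage sketch (a toy example of my own, not from the PR's test
# plan). Because `ConstantFolder.run` seeds placeholders with an unknown
# sentinel, any subgraph that depends only on module attributes is folded,
# while anything touching a runtime input is left alone.
if __name__ == "__main__":

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # nn.Parameter so tracing records a get_attr node rather than
            # baking the tensor in at trace time
            self.w = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            # `self.w + 1` does not depend on `x`, so it is foldable;
            # `x @ ...` consumes a placeholder and is left untouched
            return x @ (self.w + 1)

    gm = torch.fx.symbolic_trace(M())
    constant_fold(gm)
    # the precomputed `self.w + 1` now lives on `gm` as a frozen attribute
    # (`_frozen_param0`) referenced by a get_attr node
    print(gm.graph)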