import torch
from torch.fx import (
    Graph,
    GraphModule,
    Node,
)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights

# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
    _is_activation_post_process_node,
)

import copy
import operator
from typing import Any, Callable, Dict, Optional, Tuple

__all__ = [
    "fold_bn_weights_into_conv_node",
    "get_aten_graph_module",
    "remove_tensor_overload_for_qdq_ops",
]

def _get_tensor_constant_from_node(node, m):
    if node is None:
        return None
    assert node.op == "get_attr"
    return getattr(m, node.target)

def _get_all_arguments(orig_args, orig_kwargs, args_schema):
    all_args = []
    for i, schema in enumerate(args_schema):
        if schema.name in orig_kwargs:
            all_args.append(orig_kwargs[schema.name])
        elif not schema.kwarg_only and i < len(orig_args):
            all_args.append(orig_args[i])
        else:
            all_args.append(schema.default_value)
    return all_args

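# Example (illustrative, with a made-up schema): for a node recorded as
# `op(x, bias=b)` against a schema (input, weight, bias, stride=1),
# _get_all_arguments returns [x, <weight's schema default>, b, 1]: a
# positional list aligned with the schema, preferring kwargs, then
# positional args, then schema defaults.
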
def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule
) -> None:
    # conv args: input, weight, bias, stride, padding, dilation, transposed, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = conv_node.args[6]

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif bn_node.target == torch.ops.aten._native_batch_norm_legit.default:
        eps_arg_index = 7
    else:
        raise ValueError(f"BN node target is unexpected: {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # fused_weight and fused_bias are nn.Parameter, so assigning them with
    # setattr registers them as parameters of the GraphModule
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    setattr(m, weight_attr_name, fused_weight)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
    else:
        bias_attr_name = weight_attr_name + "_bias"
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    setattr(m, bias_attr_name, fused_bias)  # type: ignore[arg-type]
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #           \ - (second output) -> users2
    #           \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    #       bn -
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination

    for user in bn_node.users:
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)

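# For reference, the folding performed by fuse_conv_bn_weights above computes
# the standard conv-bn fusion identity (stated here for clarity):
#
#     scale   = bn_w / sqrt(bn_rv + bn_eps)
#     fused_w = conv_w * scale   (scale broadcast over the output-channel dim)
#     fused_b = (conv_b - bn_rm) * scale + bn_b
#
# so that conv(x, fused_w, fused_b) == bn(conv(x, conv_w, conv_b)) in eval mode.
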
# fuse conv bn weights; modifies the graph_module and its graph in place
def _fuse_conv_bn_(m: GraphModule) -> None:
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
            continue
        bn_node = n
        n = bn_node.args[0]
        if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
            continue
        conv_node = n
        conv_weight_node = conv_node.args[1]
        conv_bias_node = conv_node.args[2]
        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

    m.graph.eliminate_dead_code()
    m.recompile()

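# Usage sketch (illustrative, assuming an exported eval-mode model, so that
# batch norm appears as _native_batch_norm_legit_no_training in the graph):
#
#     m = get_aten_graph_module(MyConvBnModel().eval(), example_inputs)
#     _fuse_conv_bn_(m)  # conv weights now fold in the bn statistics
#
# `MyConvBnModel` is a hypothetical module containing a conv -> bn chain.
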
# TODO: remove hack when we have better support for pattern matching
# move around the observer for addmm
def _rearrange_weight_observer_for_decomposed_linear(
    model: GraphModule,
) -> None:
    """
    Linear is decomposed to `t - addmm` (w/ bias) or `t - mm` (w/o bias)

    before:
        weight - t - observer \
          input - observer - addmm/mm

    after:
        weight - observer - t \
          input - observer - addmm/mm
    """
    aten = torch.ops.aten
    op_to_weight_obs_index = {
        aten.addmm.default: 2,
        aten.mm.default: 1,
    }
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        if node.target not in (aten.addmm.default, aten.mm.default):
            continue
        root_node = node
        maybe_weight_obs = root_node.args[op_to_weight_obs_index[root_node.target]]
        if not _is_activation_post_process_node(maybe_weight_obs, named_modules):
            continue
        transpose_node = maybe_weight_obs.args[0]
        if transpose_node.target != torch.ops.aten.t.default:
            continue
        # swap the order of transpose and observation:
        # first, observe the weight directly
        maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0])
        # then insert a new transpose node after the observer; the original
        # transpose node is removed later by dead code elimination
        with model.graph.inserting_after(maybe_weight_obs):
            args = list(transpose_node.args)
            args[0] = maybe_weight_obs
            new_transpose_node = model.graph.create_node(
                "call_function",
                torch.ops.aten.t.default,
                tuple(args),
                transpose_node.kwargs
            )
        root_node.replace_input_with(maybe_weight_obs, new_transpose_node)

    model.graph.eliminate_dead_code()
    model.graph.lint()
    model.recompile()

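# Illustrative before/after of the rewrite above (node names are made up):
#
#     t_1   = torch.ops.aten.t.default(weight)
#     obs_1 = activation_post_process_0(t_1)
#     out   = torch.ops.aten.addmm.default(bias, input_obs, obs_1)
#
# becomes
#
#     obs_1 = activation_post_process_0(weight)
#     t_1   = torch.ops.aten.t.default(obs_1)
#     out   = torch.ops.aten.addmm.default(bias, input_obs, t_1)
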
def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for n in model.graph.nodes:
        nn_module_stack = n.meta.get("nn_module_stack", None)
        current_scope = ("", type(None))
        if nn_module_stack:
            bt = list(nn_module_stack.values())[-1]
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope

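# Illustrative example (made-up metadata): if a node's meta carries
#     nn_module_stack = {"self_sub": ("L__self__.sub", Sub)}
# then the innermost (last) entry wins and the node maps to ("sub", Sub),
# i.e. the final component of the module path plus the module's type.
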
def get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    # Avoid circular imports
    import torch._dynamo
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern

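# Usage sketch (illustrative): trace a pattern callable into its decomposed
# aten form, e.g. for use as a match pattern with the subgraph rewriter:
#
#     def pattern(x):          # hypothetical example callable
#         return F.relu(x) + 1
#
#     gm = get_aten_graph_module(pattern, (torch.randn(2),))
#     print(gm.graph)  # expected to show aten.relu / aten.add calls
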
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """ Remove .tensor overload for quantize/dequantize ops so that we can
    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]

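# Note (for clarity): the values in _MAP are OpOverloadPackets rather than
# specific OpOverloads, so after this pass a pattern node that originally
# targeted, say, quantize_per_tensor.default or quantize_per_tensor.tensor
# targets the packet torch.ops.quantized_decomposed.quantize_per_tensor in
# both cases, and the two therefore compare equal when matching.
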
def _is_dropout_filter(
    match: "InternalMatch",  # type: ignore[name-defined]
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter that returns True if the matched
    graph includes all the ops used in the aten dropout pattern.
    """
    ops_to_match = {
        torch.ops.aten.empty_like.default,
        torch.ops.aten.bernoulli_.float,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.mul.Tensor,
    }
    for n in match.nodes_map.values():
        if n.target in ops_to_match:
            ops_to_match.remove(n.target)
    return len(ops_to_match) == 0

def _replace_dropout_for_eval(m: GraphModule):
    """
    Replace the aten training dropout pattern with a noop, intended for eval.

    For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
    effectively turns these dropout ops into noops. For exported models, however,
    this is not done automatically, since the aten dropout patterns previously generated
    for training remain in the graph. Here we rewrite these dropout patterns with noops
    to avoid incorrectly applying further dropout during eval.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        return F.dropout(x, p=0.5, training=False)

    example_inputs = (torch.randn(1),)
    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

    # Note: The match pattern looks like:
    #
    #   empty_like_default = torch.ops.aten.empty_like.default(x)
    #   bernoulli__float = torch.ops.aten.bernoulli_.float(empty_like_default)
    #   div__scalar = torch.ops.aten.div_.Scalar(bernoulli__float, 0.5)
    #   mul_tensor = torch.ops.aten.mul.Tensor(x, div__scalar)
    #
    # We need to use `ignore_literals=True` here to handle arbitrary dropout
    # probability (not just 0.5). However, without a match filter, this would
    # also match any mul op, since `div__scalar` is also a literal, e.g.:
    #
    #   mul_tensor = torch.ops.aten.mul.Tensor(x, 0.8)
    #
    # Therefore, we need both `ignore_literals=True` and `_is_dropout_filter`
    # to make sure we are in fact replacing the dropout pattern.

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[_is_dropout_filter],
        ignore_literals=True,
    )
    m.recompile()
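
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the module's
    # API): export a tiny conv-bn-dropout model, fold the bn stats into the
    # conv, and replace the training dropout pattern with its eval form.
    class _ConvBnDropout(torch.nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            # training=True so the traced graph contains the aten training
            # dropout pattern that _replace_dropout_for_eval rewrites
            return F.dropout(self.bn(self.conv(x)), p=0.5, training=True)

    example_inputs = (torch.randn(1, 3, 8, 8),)
    # .eval() so batch norm exports as _native_batch_norm_legit_no_training
    gm = get_aten_graph_module(_ConvBnDropout().eval(), example_inputs)
    _fuse_conv_bn_(gm)
    _replace_dropout_for_eval(gm)
    print(gm.graph)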