mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "Revert D29135358: [quant] Input-Weight Equaliaztion - convert modifications" (#60646)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60646
This reverts commit e60f9cfc58.
Test Plan: Imported from OSS
Reviewed By: supriyar
Differential Revision: D29361191
Pulled By: angelayi
fbshipit-source-id: 275d8691d8e47da4ab80bb21b51d77ec25a0f714
This commit is contained in:
parent
7188d84ccf
commit
e13a9587b4
|
|
@ -2,15 +2,16 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.quantization import default_qconfig
|
||||
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
|
||||
from torch.quantization.quantize_fx import prepare_fx
|
||||
from torch.quantization.quantize_fx import prepare_fx, convert_fx
|
||||
from torch.quantization.fx._equalize import (
|
||||
_InputEqualizationObserver,
|
||||
_WeightEqualizationObserver,
|
||||
calculate_equalization_scale,
|
||||
default_equalization_qconfig,
|
||||
_convert_equalization_ref
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_quantization import NodeSpec as ns
|
||||
from torch.testing._internal.common_quantization import NodeSpec as ns, skipIfNoFBGEMM
|
||||
from torch.testing._internal.common_quantization import QuantizationTestCase
|
||||
|
||||
# Standard Libraries
|
||||
|
|
@ -238,3 +239,59 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
m = M().eval()
|
||||
prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
|
||||
self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_input_weight_equalization_convert(self):
|
||||
"""
|
||||
"""
|
||||
qconfig_dict = {"": None,
|
||||
"object_type": [(nn.Linear, default_qconfig),
|
||||
(nn.functional.linear, default_qconfig)]}
|
||||
|
||||
default_equalization_qconfig_dict = {
|
||||
"": None,
|
||||
"object_type": [(nn.Linear, default_equalization_qconfig),
|
||||
(nn.functional.linear, default_equalization_qconfig)]
|
||||
}
|
||||
|
||||
# Basic test with one linear layer
|
||||
class LinearModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
# Test with two linear layer with a fp32 operation between
|
||||
class Linear2FP32Module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(2, 2)
|
||||
self.linear2 = nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = torch.add(x, torch.tensor([1, 2]))
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
tests = [(LinearModule, default_equalization_qconfig_dict),
|
||||
(Linear2FP32Module, default_equalization_qconfig_dict)]
|
||||
|
||||
for (M, equalization_qconfig_dict) in tests:
|
||||
m = M().eval()
|
||||
x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])
|
||||
prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
|
||||
output = prepared(x)
|
||||
|
||||
convert_ref = _convert_equalization_ref(prepared)
|
||||
convert_ref_output = convert_ref(x)
|
||||
|
||||
m = M().eval()
|
||||
prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
|
||||
prepared(x)
|
||||
convert_fx(prepared) # Check if compile?
|
||||
|
||||
self.assertEqual(output, convert_ref_output)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.graph import Node
|
||||
|
||||
from .utils import get_new_attr_name_with_prefix, maybe_get_next_module
|
||||
from ..observer import (
|
||||
PerChannelMinMaxObserver,
|
||||
_with_args,
|
||||
ObserverBase,
|
||||
)
|
||||
from ..utils import check_min_max_valid
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
import warnings
|
||||
|
||||
|
||||
|
|
@ -28,14 +31,8 @@ class _InputEqualizationObserver(nn.Module):
|
|||
The running minimum/maximum :math:`x_\text{min/max}` are computed in the
|
||||
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
|
||||
with the difference that the running min/max values are stored per column.
|
||||
|
||||
The qparams are calculated by multiplying the min/max input column values
|
||||
with the equalization scale, reducing to find the global min/max input
|
||||
values, and then calculating in the same way as in
|
||||
:class:`~torch.quantization.observer.MinMaxObserver`
|
||||
|
||||
.. note:: If the running minimum equals to the running maximum, the scales
|
||||
and zero_points are set to 1.0 and 0.
|
||||
This observer is intended to be used along with a WeightEqualizationObserver
|
||||
to calculate the equalization scale.
|
||||
"""
|
||||
|
||||
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
|
||||
|
|
@ -70,8 +67,7 @@ class _InputEqualizationObserver(nn.Module):
|
|||
self.equalization_scale = equalization_scale
|
||||
|
||||
def calculate_scaled_minmax(self):
|
||||
r"""
|
||||
Returns the scaled min/max inputs
|
||||
r""" Returns the scaled min/max inputs
|
||||
"""
|
||||
if self.equalization_scale.nelement() == 0:
|
||||
warnings.warn(
|
||||
|
|
@ -104,21 +100,13 @@ class _WeightEqualizationObserver(nn.Module):
|
|||
quant_max: Maximum quantization value. If unspecified, it will
|
||||
follow the 8-bit setup.
|
||||
|
||||
This observer is made up of 2 PerChannelMinMaxObservers
|
||||
- weight_col_obs: Used to record the running minimum and maximum of
|
||||
columns of incoming weight tensors
|
||||
- weight_row_obs: Used to record the running minimum and maximum of
|
||||
rows of incoming weight tensors
|
||||
This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
|
||||
to record the running minimum and maximum of columns of incoming weight
|
||||
tensors. This observer is intended to be used along with an
|
||||
InputEqualizationObserver to calculate the equalization scale.
|
||||
|
||||
The running minimum/maximum :math:`w_\text{min/max}` are computed in the
|
||||
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
|
||||
|
||||
The qparams are calculated by multiplying the min/max weight row values
|
||||
with the inverse of the equalization scale, and then calculating in the same
|
||||
way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
|
||||
|
||||
.. note:: If the running minimum equals to the running maximum, the scales
|
||||
and zero_points are set to 1.0 and 0.
|
||||
"""
|
||||
|
||||
def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
|
||||
|
|
@ -127,7 +115,7 @@ class _WeightEqualizationObserver(nn.Module):
|
|||
|
||||
self.dtype = dtype
|
||||
self.qscheme = qscheme
|
||||
self.ch_axis = 0
|
||||
self.ch_axis = 1
|
||||
|
||||
self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
|
||||
qscheme=qscheme,
|
||||
|
|
@ -224,3 +212,280 @@ def node_supports_equalization(node: Node, modules) -> bool:
|
|||
def is_equalization_observer(observer: nn.Module) -> bool:
|
||||
return (isinstance(observer, _InputEqualizationObserver) or
|
||||
isinstance(observer, _WeightEqualizationObserver))
|
||||
|
||||
def get_op_node_and_weight_eq_obs(
|
||||
input_eq_obs_node: Node,
|
||||
model: GraphModule,
|
||||
modules: Dict[str, nn.Module]
|
||||
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
|
||||
""" Gets the following weight equalization observer. There should always
|
||||
exist a weight equalization observer after an input equalization observer.
|
||||
|
||||
Returns the node containing the weight equalization observer, and the weight
|
||||
equalization observer if it has been newly created
|
||||
"""
|
||||
|
||||
# Find the op node that comes directly after the input equaliation observer
|
||||
op_node = None
|
||||
for user in input_eq_obs_node.users.keys():
|
||||
if node_supports_equalization(user, modules):
|
||||
op_node = user
|
||||
break
|
||||
|
||||
assert(op_node is not None)
|
||||
if op_node.op == 'call_module':
|
||||
# If the op_node is a nn.Linear layer, then it must have a
|
||||
# WeightEqualizationObserver configuration
|
||||
equalization_qconfig_map: Dict[str, Any] = model._equalization_qconfig_map # type: ignore[assignment]
|
||||
assert(equalization_qconfig_map.get(op_node.name, None) is not None)
|
||||
weight_eq_obs = equalization_qconfig_map.get(op_node.name, None).weight()
|
||||
|
||||
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
|
||||
return op_node, weight_eq_obs
|
||||
|
||||
elif op_node.op == 'call_function':
|
||||
# TODO
|
||||
return None, None
|
||||
|
||||
return None, None
|
||||
|
||||
def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
|
||||
""" Gets the following input equalization observer if it exists.
|
||||
|
||||
For example, in the case of connecting linear layers:
|
||||
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
|
||||
If the node being passed in is the linear1 node, then we want to return eq_obs2,
|
||||
the following equalization observer for linear2.
|
||||
|
||||
However, if there are no connecting layers:
|
||||
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
|
||||
Then we want to return None.
|
||||
"""
|
||||
|
||||
assert(node_supports_equalization(node, modules))
|
||||
|
||||
# Locate the following output observer if it exists
|
||||
maybe_obs_node = maybe_get_next_module(node, modules, ObserverBase)
|
||||
if maybe_obs_node is None:
|
||||
return None
|
||||
|
||||
maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
|
||||
if maybe_eq_obs_node is None:
|
||||
return None
|
||||
|
||||
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
|
||||
assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
|
||||
return maybe_eq_obs
|
||||
|
||||
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
|
||||
""" If the next next node is an InputEqualizationObserver then we want to
|
||||
return its equalization scale, else we return 1
|
||||
|
||||
This is used in the case where there are two connecting linear layers:
|
||||
linear1 -> LinearOutObs -> InputEqObs -> linear2
|
||||
In this case, the node given is linear1 and we want to locate the InputEqObs.
|
||||
"""
|
||||
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
|
||||
if next_inp_eq_obs:
|
||||
return next_inp_eq_obs.equalization_scale
|
||||
return None
|
||||
|
||||
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
|
||||
""" Scales the following input quantization observer's min/max values by
|
||||
updating the values with the scaled min/max values calculated by the input
|
||||
equalization observer
|
||||
"""
|
||||
input_eq_obs = modules[str(node.target)]
|
||||
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
|
||||
|
||||
input_quant_obs_node = node.args[0]
|
||||
assert(isinstance(input_quant_obs_node, Node))
|
||||
|
||||
input_quant_obs = modules[str(input_quant_obs_node.target)]
|
||||
if not isinstance(input_quant_obs, ObserverBase):
|
||||
return
|
||||
|
||||
min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
|
||||
input_quant_obs.min_val = min_input_scaled
|
||||
input_quant_obs.max_val = max_input_scaled
|
||||
|
||||
def scale_weight_node(
|
||||
node: Node,
|
||||
modules: Dict[str, nn.Module],
|
||||
equalization_scale: torch.Tensor,
|
||||
next_equalization_scale: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
""" Scale the weights for input-weight equalization by multiplying the
|
||||
weight by 1/equalization_scale and next_equalization_scale
|
||||
|
||||
Args:
|
||||
node: Current node whose weights we want to scale
|
||||
equalization_scale: Current node's calculated equalization scale
|
||||
next_equalization_scale: Next node's calculated equalization scale if
|
||||
the following node needs to be equalized, 1 otherwise
|
||||
"""
|
||||
assert(isinstance(node.target, str))
|
||||
|
||||
# Scale the weights for input-weight equalization
|
||||
# If the following layer needs to be equalized then we will multiply its scale
|
||||
weight = modules[node.target].weight
|
||||
assert(isinstance(weight, torch.Tensor))
|
||||
|
||||
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))
|
||||
|
||||
if next_equalization_scale is None:
|
||||
modules[node.target].weight = nn.Parameter(scaled_weight)
|
||||
return
|
||||
|
||||
scaled_weight = torch.mul(scaled_weight, next_equalization_scale)
|
||||
modules[node.target].weight = nn.Parameter(scaled_weight)
|
||||
|
||||
# TODO: The bias may need to be scaled for connecting linear layers
|
||||
bias = modules[node.target].bias
|
||||
assert(isinstance(bias, torch.Tensor))
|
||||
|
||||
scaled_bias = torch.mul(bias, next_equalization_scale)
|
||||
modules[node.target].bias = nn.Parameter(scaled_bias)
|
||||
|
||||
def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
|
||||
""" Update all of the observer's equalization scale. For each
|
||||
InputEqualizationObserver, we will find the location of the next
|
||||
WeightEqualizationObserver, create it, and calculate the equalization scale
|
||||
based on the two observers.
|
||||
|
||||
We will then return a dictionary mapping operation node names to
|
||||
the corresponding WeightEqualizationObservers for that operation.
|
||||
"""
|
||||
weight_eq_obs_dict = {}
|
||||
for node in model.graph.nodes:
|
||||
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
|
||||
input_eq_obs = modules[node.target]
|
||||
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
|
||||
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
|
||||
|
||||
if op_node is None or weight_eq_obs is None:
|
||||
continue
|
||||
|
||||
weight_eq_obs(modules[str(op_node.target)].weight)
|
||||
|
||||
# Calculate and set the equalization scale values
|
||||
equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
|
||||
input_eq_obs.set_equalization_scale(equalization_scale)
|
||||
weight_eq_obs.set_equalization_scale(equalization_scale)
|
||||
|
||||
weight_eq_obs_dict[op_node.name] = weight_eq_obs
|
||||
|
||||
return weight_eq_obs_dict
|
||||
|
||||
def convert_eq_obs(
|
||||
model: GraphModule,
|
||||
modules: Dict[str, nn.Module],
|
||||
weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
|
||||
) -> None:
|
||||
""" Converts the equalization operations and updates the other nodes in the
|
||||
following way:
|
||||
- Removes the input equalization observers and inserts a mul operator
|
||||
along with an equalization scale node wherever applicable (we do not
|
||||
want to insert a mul operator between connecting linear layers).
|
||||
- Updates the input quantization observers with the scaled input min/max
|
||||
values.
|
||||
- Scales the weights by the current and next equalization scales.
|
||||
- Removes the weight equalization observer node if it exists.
|
||||
|
||||
Before (after prepare):
|
||||
weight values
|
||||
|
|
||||
WeightQuantObs
|
||||
|
|
||||
WeightEqObs
|
||||
|
|
||||
x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
|
||||
|
||||
After this function:
|
||||
scaled weight values
|
||||
|
|
||||
equalization scale WeightQuantObs
|
||||
| |
|
||||
x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
|
||||
|
||||
After convert:
|
||||
equalization scale scaled weight values
|
||||
| |
|
||||
x -> mul -> quantize_per_tensor -> quantized::linear
|
||||
|
||||
Note that although the equalization observer appeared after the quantization
|
||||
observer after prepare_fx, the mul node appears before the quantization node
|
||||
after convert_fx. This is because placing the equalization observer after
|
||||
the quantization observer in prepare_fx would allow us to keep the invariant
|
||||
that the graph before the current node inserts its observers is not
|
||||
modified.
|
||||
|
||||
Having the equalization observer before the quantization observer would also
|
||||
cause some inconsistences between the ordering of the quantization and
|
||||
equalization observers.
|
||||
For example, a single linear layer would look like:
|
||||
x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
|
||||
But between two connected linear layers, it would look like:
|
||||
linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
|
||||
"""
|
||||
for node in model.graph.nodes:
|
||||
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
|
||||
inp_quant_obs_node = node.args[0]
|
||||
prev_node = inp_quant_obs_node.args[0]
|
||||
|
||||
# TODO: Possible special handling for connected linear layers
|
||||
# Update the following input quantization observer's min/max values
|
||||
scale_input_observer(node, modules)
|
||||
|
||||
# Remove the InputEqualization node and add a mul operator before
|
||||
# the quantization observer node that appears before the equalization node
|
||||
# Before: x -> input_quant_obs -> input_eq_obs -> linear
|
||||
# After: x -> mul -> input_quant_obs -> linear
|
||||
|
||||
# Create a node containing the equalization scale
|
||||
with model.graph.inserting_before(inp_quant_obs_node):
|
||||
get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
|
||||
name = get_new_eq_scale_name(modules)
|
||||
setattr(model, name, modules[node.target].equalization_scale)
|
||||
eq_scale_node = model.graph.create_node('get_attr', name)
|
||||
|
||||
# Create a node multiplying the input with the equalization scale
|
||||
with model.graph.inserting_after(eq_scale_node):
|
||||
inputs = (prev_node, eq_scale_node)
|
||||
mul_node = model.graph.create_node("call_function", torch.mul, inputs)
|
||||
|
||||
# Set the mul nod to be the input_quant_obs_node's input instead of
|
||||
# the previous node
|
||||
inp_quant_obs_node.replace_input_with(prev_node, mul_node)
|
||||
|
||||
# For all of the current node's users, replace the current node with
|
||||
# the input quantization observer node
|
||||
orig_users = list(node.users.keys())
|
||||
for user_node in orig_users:
|
||||
user_node.replace_input_with(node, inp_quant_obs_node)
|
||||
|
||||
# Erase the InputEqualizationObserver node
|
||||
model.graph.erase_node(node)
|
||||
|
||||
elif weight_eq_obs_dict.get(node.name, None) is not None:
|
||||
weight_eq_obs = weight_eq_obs_dict.get(node.name)
|
||||
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
|
||||
|
||||
equalization_scale = weight_eq_obs.equalization_scale
|
||||
|
||||
# Scales the weights and runs the weight quantization observers
|
||||
maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
|
||||
scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
|
||||
|
||||
def _convert_equalization_ref(model: GraphModule):
|
||||
""" Reference function which applies changes needed for equalization, but
|
||||
does not quantize the nodes
|
||||
"""
|
||||
modules = dict(model.named_modules(remove_duplicate=False))
|
||||
|
||||
# Calculate the equalization scale, update the observers with the scaled
|
||||
# inputs, and scale the weight
|
||||
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
|
||||
convert_eq_obs(model, modules, weight_eq_obs_dict)
|
||||
|
||||
return GraphModule(model, model.graph)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from .graph_module import (
|
|||
from .quantization_patterns import (
|
||||
QuantizeHandler,
|
||||
)
|
||||
from ._equalize import update_obs_for_equalization, convert_eq_obs
|
||||
from .utils import (
|
||||
is_get_tensor_info_node,
|
||||
node_return_type_is_int,
|
||||
|
|
@ -180,6 +181,13 @@ def convert(model: GraphModule, is_reference: bool = False,
|
|||
qconfig_map,
|
||||
custom_module_classes=custom_module_classes)
|
||||
|
||||
if model._equalization_qconfig_map is not None:
|
||||
# If we want to do equalization then do the following:
|
||||
# Calculate the equalization scale, update the observers with the scaled
|
||||
# inputs, and scale the weight
|
||||
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
|
||||
convert_eq_obs(model, modules, weight_eq_obs_dict)
|
||||
|
||||
quantized_graph = Graph()
|
||||
env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(lambda: defaultdict(Node)) # type: ignore[arg-type]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..utils import is_per_tensor, is_per_channel
|
||||
from ..quantize import is_activation_post_process
|
||||
|
||||
|
|
@ -10,7 +11,7 @@ from torch.fx.graph import (
|
|||
Node,
|
||||
)
|
||||
|
||||
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
|
||||
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
|
||||
import operator
|
||||
|
||||
# A dictionary for querying the weight index for a given op
|
||||
|
|
@ -481,3 +482,23 @@ def is_get_tensor_info_node(node: Node) -> bool:
|
|||
result: bool = \
|
||||
node.op == "call_function" and node.target == getattr and node.args[1] == "shape" # type: ignore[assignment]
|
||||
return result
|
||||
|
||||
def maybe_get_next_module(
|
||||
node: Node,
|
||||
modules: Dict[str, nn.Module],
|
||||
target_module_type: Type[nn.Module],
|
||||
) -> Optional[Node]:
|
||||
""" Gets the next module that matches what is needed in
|
||||
is_target_module_type if it exists
|
||||
|
||||
Args:
|
||||
node: The node whose users we want to look at
|
||||
is_target_module_type: Function that returns true if the given module
|
||||
matches the type specified in the function.
|
||||
"""
|
||||
|
||||
for user, _ in node.users.items():
|
||||
if user.op == 'call_module' and isinstance(modules[str(user.target)], target_module_type):
|
||||
return user
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user