[ao] Added generate report capability to ModelReport class

Summary: The ModelReport class in model_report.py combines the
functionality of the detectors and the ModelReportObserver. It creates
an end-to-end system where a user can pass in a prepared Graph Model to
insert the ModelReportObservers; then, after the user calibrates their
model, the calibrated model can be used by the ModelReport class to
generate reports based on what the user wished to gather information
on.
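
A minimal sketch of that flow follows (the import paths for the
_model_report pieces, the toy model, and the input shapes are
assumptions for illustration; the constructors and methods mirror the
ones exercised in TestFxModelReportClass below):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.ao.quantization import QConfigMapping, get_default_qconfig
    from torch.ao.quantization import quantize_fx
    # module paths below are assumptions based on this PR's file layout
    from torch.ao.quantization.fx._model_report.model_report import ModelReport
    from torch.ao.quantization.fx._model_report.detector import (
        DynamicStaticDetector,
        PerChannelDetector,
    )

    class TinyModel(nn.Module):  # placeholder model, not part of this PR
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(3, 3)

        def forward(self, x):
            return F.relu(self.linear(x))

    model = TinyModel()
    example_input = torch.randn(1, 3, 3)

    # build a qconfig mapping and prepare the model with FX, as in the test below
    # (assumes the fbgemm backend is available)
    q_config_mapping = QConfigMapping()
    q_config_mapping.set_global(get_default_qconfig("fbgemm"))
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

    # insert ModelReportObservers for the detectors of interest
    model_report = ModelReport({DynamicStaticDetector(), PerChannelDetector("fbgemm")})
    prepared_for_calibration = model_report.prepare_detailed_calibration(model_prep)

    # calibrate as usual; the inserted observers record statistics as data flows through
    for _ in range(10):
        prepared_for_calibration(torch.randn(1, 3, 3))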

This contains the implementation of, and tests for, the
generate_model_report method, which is used on a calibrated FX model to
generate reports based on the data collected by the inserted observers
during the calibration phase, and which can also remove those observers
afterwards if desired.
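
A hedged sketch of that call, continuing the example above (the
remove_inserted_observers keyword matches the new signature in
model_report.py; unpacking each per-detector entry into a textual
summary and a statistics dictionary is an assumption based on the
method's docstring):

    # generate the reports from the calibrated model and remove the inserted observers
    reports = model_report.generate_model_report(
        prepared_for_calibration, remove_inserted_observers=True
    )
    for detector_name, report_output in reports.items():
        summary_text, detector_stats = report_output  # assumed (summary, stats) pair
        print(detector_name, summary_text)

    # once the observers have been removed, calling generate_model_report again on
    # the same ModelReport instance raises, matching the _removed_observers guard below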

This also fixes the issue that previously caused this change to be
reverted.

Test Plan: python test/test_quantization.py TestFxModelReportClass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80054

Approved by: https://github.com/HDCharles
vspenubarthi 2022-06-22 10:59:27 -07:00 committed by PyTorch MergeBot
parent f714d8f574
commit 70be6f8470
3 changed files with 192 additions and 12 deletions

View File

@@ -827,6 +827,33 @@ class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
class TestFxModelReportClass(QuantizationTestCase):
# example model to use for tests
class ThreeOps(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.bn = nn.BatchNorm2d(3)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.bn(x)
x = self.relu(x)
return x
class TwoThreeOps(nn.Module):
def __init__(self):
super().__init__()
self.block1 = TestFxModelReportClass.ThreeOps()
self.block2 = TestFxModelReportClass.ThreeOps()
def forward(self, x):
x = self.block1(x)
y = self.block2(x)
z = x + y
z = F.relu(z)
return z
@skipIfNoFBGEMM
def test_constructor(self):
"""
@@ -958,3 +985,104 @@ class TestFxModelReportClass(QuantizationTestCase):
# ensure that we can prepare for calibration only once
with self.assertRaises(ValueError):
prepared_for_callibrate_model = model_report.prepare_detailed_calibration(model_prep)
def get_module_and_graph_cnts(self, callibrated_fx_module):
r"""
Calculates the number of ModelReportObserver modules in the model as well as the number of model_report nodes in the graph.
Returns a tuple of two elements:
int: The number of ModelReportObservers found in the model
int: The number of model_report nodes found in the graph
"""
# get the number of observers stored as modules
modules_observer_cnt = 0
for fqn, module in callibrated_fx_module.named_modules():
if isinstance(module, ModelReportObserver):
modules_observer_cnt += 1
# get number of observers in the graph
model_report_str_check = "model_report"
graph_observer_cnt = 0
# also make sure arguments for observers in the graph are proper
for node in callibrated_fx_module.graph.nodes:
# not all node targets are strings, so check
if isinstance(node.target, str) and model_report_str_check in node.target:
# increment if we found a graph observer
graph_observer_cnt += 1
return (modules_observer_cnt, graph_observer_cnt)
@skipIfNoFBGEMM
def test_generate_report(self):
"""
Tests model_report.generate_model_report to ensure reports are generated correctly
Specifically looks at:
- Whether the correct number of reports is being generated
- Whether observers are being properly removed if specified
- Whether report generation is correctly blocked a second time once observers are removed
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# check whether the correct number of reports are being generated
filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)])
single_detector_set = set([DynamicStaticDetector()])
# initialize one with filled detector
model_report_full = ModelReport(filled_detector_set)
# initialize another with a single detector set
model_report_single = ModelReport(single_detector_set)
# prepare and calibrate two different instances of the same model
# prepare the model
model_full = TestFxModelReportClass.TwoThreeOps()
model_single = TestFxModelReportClass.TwoThreeOps()
example_input = torch.randn(1, 3, 3, 3)
current_backend = torch.backends.quantized.engine
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)
# prepare the models for calibration
prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration(model_prep_full)
prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration(model_prep_single)
# now calibrate the two models
num_iterations = 10
for i in range(num_iterations):
example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float)
prepared_for_callibrate_model_full(example_input)
prepared_for_callibrate_model_single(example_input)
# now generate the reports
model_full_report = model_report_full.generate_model_report(
prepared_for_callibrate_model_full, True
)
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)
# check that sizes are appropriate
self.assertEqual(len(model_full_report), len(filled_detector_set))
self.assertEqual(len(model_single_report), len(single_detector_set))
# make sure observers are properly removed for the full report since we passed in the removal flag
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules
self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph
# make sure observers aren't being removed for single report since not specified
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
self.assertNotEqual(modules_observer_cnt, 0)
self.assertNotEqual(graph_observer_cnt, 0)
# make sure an error is raised when we rerun report generation for the full report but not the single report
with self.assertRaises(Exception):
model_full_report = model_report_full.generate_model_report(
prepared_for_callibrate_model_full, False
)
# make sure we don't run into error for single report
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)

View File

@@ -73,7 +73,7 @@ class PerChannelDetector(DetectorBase):
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
}
def __init__(self, backend=torch.backends.quantized.engine):
def __init__(self, backend: str = torch.backends.quantized.engine):
super().__init__()
# store the backend information
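
For illustration only, a minimal hedged usage sketch matching the newly
annotated signature, assuming the detector lives at the import path
below and is given an explicit backend string such as "fbgemm":

    from torch.ao.quantization.fx._model_report.detector import PerChannelDetector

    # backend is now explicitly typed as str; defaults to torch.backends.quantized.engine
    per_channel_detector = PerChannelDetector(backend="fbgemm")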

View File

@@ -32,27 +32,28 @@ class ModelReport:
# keep the reports private so they can't be modified
self._desired_report_detectors = desired_report_detectors
self._desired_reports = set([detector.get_detector_name() for detector in desired_report_detectors])
self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
# keep a mapping of desired reports to observers of interest
# this is used both to get the readings and to remove the observers; the FQNs can be combined into one large set
# that set can then be used to traverse the graph and remove the added observers
self._report_name_to_observer_fqns: Dict[str, Set[str]] = {}
self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
# initialize each report to have empty set of observers of interest
for desired_report in self._desired_reports:
self._report_name_to_observer_fqns[desired_report] = set([])
for desired_report in self._desired_detector_names:
self._detector_name_to_observer_fqns[desired_report] = set([])
# flags to ensure that we can only prepare once
# flags to ensure that we can only prepare and generate report once
self._prepared_flag = False
self._removed_observers = False
def get_desired_reports_names(self) -> Set[str]:
""" Returns a copy of the desired reports for viewing """
return self._desired_reports.copy()
return self._desired_detector_names.copy()
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
""" Returns a copy of the observers of interest for viewing """
return self._report_name_to_observer_fqns.copy()
return self._detector_name_to_observer_fqns.copy()
def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphModule:
r"""
@@ -61,7 +62,7 @@ class ModelReport:
Each observer is inserted based on the desired_reports into the relevant locations
Right now, each report in self._desired_reports has independent insertions
Right now, each report in self._desired_detector_names has independent insertions
However, if a module already has an Observer of the same type, the insertion will not occur
This is because all Observers of the same type collect the same information, so redundant
@@ -84,7 +85,7 @@ class ModelReport:
# map each insert point to the observer to use
insert_observers_fqns.update(obs_fqn_to_info)
# update the set of observers this report cares about
self._report_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
# now insert all the observers at their desired locations
for observer_fqn in insert_observers_fqns:
@@ -142,7 +143,20 @@ class ModelReport:
Returns the Node object of the given node_fqn if it exists, otherwise raises a ValueError
"""
pass
node_to_return = None
for node in fx_model.graph.nodes:
# if the target matches the fqn, it's the node we are looking for
if node.target == node_fqn:
node_to_return = node
break
if node_to_return is None:
raise ValueError("The node_fqn is was not found within the module.")
# assert for MyPy
assert isinstance(node_to_return, torch.fx.node.Node)
return node_to_return
def generate_model_report(
self, calibrated_fx_model: GraphModule, remove_inserted_observers: bool
@@ -161,4 +175,42 @@ class ModelReport:
The textual summary of that report information
A dictionary containing relevant statistics or information for that report
"""
pass
# if we already removed the observers, we cannot generate report
if self._removed_observers:
raise Exception("Cannot generate report on model you already removed observers from")
# keep track of all the reports of interest and their outputs
reports_of_interest = {}
for detector in self._desired_report_detectors:
# generate the individual report for the detector
report_output = detector.generate_detector_report(calibrated_fx_model)
reports_of_interest[detector.get_detector_name()] = report_output
# if user wishes to remove inserted observers, go ahead and remove
if remove_inserted_observers:
self._removed_observers = True
# get the set of all Observers inserted by this instance of ModelReport
all_observers_of_interest: Set[str] = set([])
for desired_report in self._detector_name_to_observer_fqns:
observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
all_observers_of_interest.update(observers_of_interest)
# go through all_observers_of_interest and remove them from the graph and model
for observer_fqn in all_observers_of_interest:
# remove the observer from the model
calibrated_fx_model.delete_submodule(observer_fqn)
# remove the observer from the graph structure
node_obj = self._get_node_from_fqn(calibrated_fx_model, observer_fqn)
if node_obj:
calibrated_fx_model.graph.erase_node(node_obj)
else:
raise ValueError("Node no longer exists in GraphModule structure")
# remember to recompile the model
calibrated_fx_model.recompile()
# return the reports of interest
return reports_of_interest