mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ao] Added generate report capability to ModelReport class
Summary: The ModelReport class in model_report.py combines the functionality of the detectors and the ModelReportObserver. It creates an end-to-end system where a user can pass in a prepared Graph Model to insert the ModelReportObservers, then after the user callibrates their model, the callibrated model can then be used by the ModelReport class to generate reports based on what the user wished to gather information on. This contains the implementation and the tests for the generate_report method which is used on a callibrated fx model to generate reports based on data collected by the inserted observers during the callibration phase and also potentially remove those observers if desired. This also addresses and fixes a revert issue that has been fixed. Test Plan: python test/test_quantization.py TestFxModelReportClass Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/80054 Approved by: https://github.com/HDCharles
This commit is contained in:
parent
f714d8f574
commit
70be6f8470
|
|
@ -827,6 +827,33 @@ class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
|
||||||
|
|
||||||
class TestFxModelReportClass(QuantizationTestCase):
|
class TestFxModelReportClass(QuantizationTestCase):
|
||||||
|
|
||||||
|
# example model to use for tests
|
||||||
|
class ThreeOps(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(3, 3)
|
||||||
|
self.bn = nn.BatchNorm2d(3)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TwoThreeOps(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.block1 = TestFxModelReportClass.ThreeOps()
|
||||||
|
self.block2 = TestFxModelReportClass.ThreeOps()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.block1(x)
|
||||||
|
y = self.block2(x)
|
||||||
|
z = x + y
|
||||||
|
z = F.relu(z)
|
||||||
|
return z
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -958,3 +985,104 @@ class TestFxModelReportClass(QuantizationTestCase):
|
||||||
# ensure that we can prepare for callibration only once
|
# ensure that we can prepare for callibration only once
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
prepared_for_callibrate_model = model_report.prepare_detailed_calibration(model_prep)
|
prepared_for_callibrate_model = model_report.prepare_detailed_calibration(model_prep)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_and_graph_cnts(self, callibrated_fx_module):
|
||||||
|
r"""
|
||||||
|
Calculates number of ModelReportObserver modules in the model as well as the graph structure.
|
||||||
|
Returns a tuple of two elements:
|
||||||
|
int: The number of ModelReportObservers found in the model
|
||||||
|
int: The number of model_report nodes found in the graph
|
||||||
|
"""
|
||||||
|
# get the number of observers stored as modules
|
||||||
|
modules_observer_cnt = 0
|
||||||
|
for fqn, module in callibrated_fx_module.named_modules():
|
||||||
|
if isinstance(module, ModelReportObserver):
|
||||||
|
modules_observer_cnt += 1
|
||||||
|
|
||||||
|
# get number of observers in the graph
|
||||||
|
model_report_str_check = "model_report"
|
||||||
|
graph_observer_cnt = 0
|
||||||
|
# also make sure arguments for observers in the graph are proper
|
||||||
|
for node in callibrated_fx_module.graph.nodes:
|
||||||
|
# not all node targets are strings, so check
|
||||||
|
if isinstance(node.target, str) and model_report_str_check in node.target:
|
||||||
|
# increment if we found a graph observer
|
||||||
|
graph_observer_cnt += 1
|
||||||
|
|
||||||
|
return (modules_observer_cnt, graph_observer_cnt)
|
||||||
|
|
||||||
|
@skipIfNoFBGEMM
|
||||||
|
def test_generate_report(self):
|
||||||
|
"""
|
||||||
|
Tests model_report.generate_model_report to ensure report generation
|
||||||
|
Specifically looks at:
|
||||||
|
- Whether correct number of reports are being generated
|
||||||
|
- Whether observers are being properly removed if specified
|
||||||
|
- Whether correct blocking from generating report twice if obs removed
|
||||||
|
"""
|
||||||
|
|
||||||
|
with override_quantized_engine('fbgemm'):
|
||||||
|
# set the backend for this test
|
||||||
|
torch.backends.quantized.engine = "fbgemm"
|
||||||
|
|
||||||
|
# check whether the correct number of reports are being generated
|
||||||
|
filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)])
|
||||||
|
single_detector_set = set([DynamicStaticDetector()])
|
||||||
|
|
||||||
|
# initialize one with filled detector
|
||||||
|
model_report_full = ModelReport(filled_detector_set)
|
||||||
|
# initialize another with a single detector set
|
||||||
|
model_report_single = ModelReport(single_detector_set)
|
||||||
|
|
||||||
|
# prepare and callibrate two different instances of same model
|
||||||
|
# prepare the model
|
||||||
|
model_full = TestFxModelReportClass.TwoThreeOps()
|
||||||
|
model_single = TestFxModelReportClass.TwoThreeOps()
|
||||||
|
example_input = torch.randn(1, 3, 3, 3)
|
||||||
|
current_backend = torch.backends.quantized.engine
|
||||||
|
q_config_mapping = QConfigMapping()
|
||||||
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
||||||
|
|
||||||
|
model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
|
||||||
|
model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)
|
||||||
|
|
||||||
|
# prepare the models for callibration
|
||||||
|
prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration(model_prep_full)
|
||||||
|
prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration(model_prep_single)
|
||||||
|
|
||||||
|
# now callibrate the two models
|
||||||
|
num_iterations = 10
|
||||||
|
for i in range(num_iterations):
|
||||||
|
example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float)
|
||||||
|
prepared_for_callibrate_model_full(example_input)
|
||||||
|
prepared_for_callibrate_model_single(example_input)
|
||||||
|
|
||||||
|
# now generate the reports
|
||||||
|
model_full_report = model_report_full.generate_model_report(
|
||||||
|
prepared_for_callibrate_model_full, True
|
||||||
|
)
|
||||||
|
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)
|
||||||
|
|
||||||
|
# check that sizes are appropriate
|
||||||
|
self.assertEqual(len(model_full_report), len(filled_detector_set))
|
||||||
|
self.assertEqual(len(model_single_report), len(single_detector_set))
|
||||||
|
|
||||||
|
# make sure observers are being properly removed for full report since we put flag in
|
||||||
|
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
|
||||||
|
self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules
|
||||||
|
self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph
|
||||||
|
|
||||||
|
# make sure observers aren't being removed for single report since not specified
|
||||||
|
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
|
||||||
|
self.assertNotEqual(modules_observer_cnt, 0)
|
||||||
|
self.assertNotEqual(graph_observer_cnt, 0)
|
||||||
|
|
||||||
|
# make sure error when try to rerun report generation for full report but not single report
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
model_full_report = model_report_full.generate_model_report(
|
||||||
|
prepared_for_callibrate_model_full, False
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure we don't run into error for single report
|
||||||
|
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ class PerChannelDetector(DetectorBase):
|
||||||
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, backend=torch.backends.quantized.engine):
|
def __init__(self, backend: str = torch.backends.quantized.engine):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# store the backend information
|
# store the backend information
|
||||||
|
|
|
||||||
|
|
@ -32,27 +32,28 @@ class ModelReport:
|
||||||
|
|
||||||
# keep the reports private so they can't be modified
|
# keep the reports private so they can't be modified
|
||||||
self._desired_report_detectors = desired_report_detectors
|
self._desired_report_detectors = desired_report_detectors
|
||||||
self._desired_reports = set([detector.get_detector_name() for detector in desired_report_detectors])
|
self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
|
||||||
|
|
||||||
# keep a mapping of desired reports to observers of interest
|
# keep a mapping of desired reports to observers of interest
|
||||||
# this is to get the readings, and to remove them, can create a large set
|
# this is to get the readings, and to remove them, can create a large set
|
||||||
# this set can then be used to traverse the graph and remove added observers
|
# this set can then be used to traverse the graph and remove added observers
|
||||||
self._report_name_to_observer_fqns: Dict[str, Set[str]] = {}
|
self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
# initialize each report to have empty set of observers of interest
|
# initialize each report to have empty set of observers of interest
|
||||||
for desired_report in self._desired_reports:
|
for desired_report in self._desired_detector_names:
|
||||||
self._report_name_to_observer_fqns[desired_report] = set([])
|
self._detector_name_to_observer_fqns[desired_report] = set([])
|
||||||
|
|
||||||
# flags to ensure that we can only prepare once
|
# flags to ensure that we can only prepare and generate report once
|
||||||
self._prepared_flag = False
|
self._prepared_flag = False
|
||||||
|
self._removed_observers = False
|
||||||
|
|
||||||
def get_desired_reports_names(self) -> Set[str]:
|
def get_desired_reports_names(self) -> Set[str]:
|
||||||
""" Returns a copy of the desired reports for viewing """
|
""" Returns a copy of the desired reports for viewing """
|
||||||
return self._desired_reports.copy()
|
return self._desired_detector_names.copy()
|
||||||
|
|
||||||
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
|
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
|
||||||
""" Returns a copy of the observers of interest for viewing """
|
""" Returns a copy of the observers of interest for viewing """
|
||||||
return self._report_name_to_observer_fqns.copy()
|
return self._detector_name_to_observer_fqns.copy()
|
||||||
|
|
||||||
def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphModule:
|
def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphModule:
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -61,7 +62,7 @@ class ModelReport:
|
||||||
|
|
||||||
Each observer is inserted based on the desired_reports into the relavent locations
|
Each observer is inserted based on the desired_reports into the relavent locations
|
||||||
|
|
||||||
Right now, each report in self._desired_reports has independent insertions
|
Right now, each report in self._desired_detector_names has independent insertions
|
||||||
However, if a module already has a Observer of the same type, the insertion will not occur
|
However, if a module already has a Observer of the same type, the insertion will not occur
|
||||||
This is because all of the same type of Observer collect same information, so redundant
|
This is because all of the same type of Observer collect same information, so redundant
|
||||||
|
|
||||||
|
|
@ -84,7 +85,7 @@ class ModelReport:
|
||||||
# map each insert point to the observer to use
|
# map each insert point to the observer to use
|
||||||
insert_observers_fqns.update(obs_fqn_to_info)
|
insert_observers_fqns.update(obs_fqn_to_info)
|
||||||
# update the set of observers this report cares about
|
# update the set of observers this report cares about
|
||||||
self._report_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
|
self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
|
||||||
|
|
||||||
# now insert all the observers at their desired locations
|
# now insert all the observers at their desired locations
|
||||||
for observer_fqn in insert_observers_fqns:
|
for observer_fqn in insert_observers_fqns:
|
||||||
|
|
@ -142,7 +143,20 @@ class ModelReport:
|
||||||
|
|
||||||
Returns the Node object of the given node_fqn otherwise returns None
|
Returns the Node object of the given node_fqn otherwise returns None
|
||||||
"""
|
"""
|
||||||
pass
|
node_to_return = None
|
||||||
|
for node in fx_model.graph.nodes:
|
||||||
|
# if the target matches the fqn, it's the node we are looking for
|
||||||
|
if node.target == node_fqn:
|
||||||
|
node_to_return = node
|
||||||
|
break
|
||||||
|
|
||||||
|
if node_to_return is None:
|
||||||
|
raise ValueError("The node_fqn is was not found within the module.")
|
||||||
|
|
||||||
|
# assert for MyPy
|
||||||
|
assert isinstance(node_to_return, torch.fx.node.Node)
|
||||||
|
|
||||||
|
return node_to_return
|
||||||
|
|
||||||
def generate_model_report(
|
def generate_model_report(
|
||||||
self, calibrated_fx_model: GraphModule, remove_inserted_observers: bool
|
self, calibrated_fx_model: GraphModule, remove_inserted_observers: bool
|
||||||
|
|
@ -161,4 +175,42 @@ class ModelReport:
|
||||||
The textual summary of that report information
|
The textual summary of that report information
|
||||||
A dictionary containing relavent statistics or information for that report
|
A dictionary containing relavent statistics or information for that report
|
||||||
"""
|
"""
|
||||||
pass
|
# if we already removed the observers, we cannot generate report
|
||||||
|
if self._removed_observers:
|
||||||
|
raise Exception("Cannot generate report on model you already removed observers from")
|
||||||
|
|
||||||
|
# keep track of all the reports of interest and their outputs
|
||||||
|
reports_of_interest = {}
|
||||||
|
|
||||||
|
for detector in self._desired_report_detectors:
|
||||||
|
# generate the individual report for the detector
|
||||||
|
report_output = detector.generate_detector_report(calibrated_fx_model)
|
||||||
|
reports_of_interest[detector.get_detector_name()] = report_output
|
||||||
|
|
||||||
|
# if user wishes to remove inserted observers, go ahead and remove
|
||||||
|
if remove_inserted_observers:
|
||||||
|
self._removed_observers = True
|
||||||
|
# get the set of all Observers inserted by this instance of ModelReport
|
||||||
|
all_observers_of_interest: Set[str] = set([])
|
||||||
|
for desired_report in self._detector_name_to_observer_fqns:
|
||||||
|
observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
|
||||||
|
all_observers_of_interest.update(observers_of_interest)
|
||||||
|
|
||||||
|
# go through all_observers_of_interest and remove them from the graph and model
|
||||||
|
for observer_fqn in all_observers_of_interest:
|
||||||
|
# remove the observer from the model
|
||||||
|
calibrated_fx_model.delete_submodule(observer_fqn)
|
||||||
|
|
||||||
|
# remove the observer from the graph structure
|
||||||
|
node_obj = self._get_node_from_fqn(calibrated_fx_model, observer_fqn)
|
||||||
|
|
||||||
|
if node_obj:
|
||||||
|
calibrated_fx_model.graph.erase_node(node_obj)
|
||||||
|
else:
|
||||||
|
raise ValueError("Node no longer exists in GraphModule structure")
|
||||||
|
|
||||||
|
# remember to recompile the model
|
||||||
|
calibrated_fx_model.recompile()
|
||||||
|
|
||||||
|
# return the reports of interest
|
||||||
|
return reports_of_interest
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user