[Minimizer] allow overriding of ShapeProp logic by subclasses of _MinimizerBase (#148784)

Summary:
The changes contained in this diff
- allow subclass Minimizer implementations to override the default shape propagation logic with custom logic
- copies over the meta attribute on get_attr graph nodes during the graph splitting step
- for both changes, behavior for existing classes do not change

Test Plan: CI

Differential Revision: D70799942

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148784
Approved by: https://github.com/blaine-rister
This commit is contained in:
Qiaochu Yuan 2025-03-10 22:22:12 +00:00 committed by PyTorch MergeBot
parent fcb633fafa
commit 12a95390ae
2 changed files with 9 additions and 1 deletions

View File

@ -141,7 +141,7 @@ class _MinimizerBase:
callable_nodes = {
node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
}
ShapeProp(self.module).propagate(*self.sample_input)
self.run_shape_prop()
self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
# Check if number of input in sample_input matches the number of placeholders
@ -155,6 +155,13 @@ class _MinimizerBase:
self.a_outputs[name] = sample_input[i]
self.b_outputs[name] = sample_input[i]
def run_shape_prop(self) -> None:
"""
Helper function to run shape propagation on module. Can be overridden by
subclasses for custom shape propagation logic.
"""
ShapeProp(self.module).propagate(*self.sample_input)
def run_a(
self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
) -> TensorOrTensors:

View File

@ -212,6 +212,7 @@ def split_by_tags(
comp.getattr_maps[x] = comp.graph.get_attr(
x.target, type_expr=x.type
)
comp.getattr_maps[x].meta = copy.copy(x.meta)
return comp.getattr_maps[x]
# If input is not a placeholder, it should have been put into a component