[Minimizer] allow overriding of ShapeProp logic by subclasses of _MinimizerBase (#148784)
Summary: The changes contained in this diff:
- allow subclass Minimizer implementations to override the default shape propagation logic with custom logic
- copy over the meta attribute on get_attr graph nodes during the graph splitting step
- for both changes, the behavior of existing classes does not change

Test Plan: CI

Differential Revision: D70799942

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148784
Approved by: https://github.com/blaine-rister
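The first change turns shape propagation into an overridable hook, run_shape_prop(), shown in the first two hunks below. A minimal sketch of how a subclass might use that hook; the class name MetaAwareMinimizer and the "skip if tensor_meta is already present" heuristic are illustrative assumptions, not code from this PR:

# Sketch only: a hypothetical Minimizer subclass overriding the new
# run_shape_prop() hook added in this PR. The skip heuristic below is an
# assumption made for illustration.
from torch.fx.passes.net_min_base import _MinimizerBase
from torch.fx.passes.shape_prop import ShapeProp


class MetaAwareMinimizer(_MinimizerBase):
    def run_shape_prop(self) -> None:
        # Reuse existing tensor_meta when every call_function node already has
        # it; otherwise fall back to the default ShapeProp pass.
        if all(
            "tensor_meta" in node.meta
            for node in self.module.graph.nodes
            if node.op == "call_function"
        ):
            return
        ShapeProp(self.module).propagate(*self.sample_input)

Because only the shape-propagation step is overridden, the rest of the minimizer flow behaves exactly as before, which is the point of making it a dedicated hook.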
parent fcb633fafa
commit 12a95390ae
@@ -141,7 +141,7 @@ class _MinimizerBase:
         callable_nodes = {
             node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
         }
-        ShapeProp(self.module).propagate(*self.sample_input)
+        self.run_shape_prop()
         self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()

         # Check if number of input in sample_input matches the number of placeholders
@@ -155,6 +155,13 @@ class _MinimizerBase:
             self.a_outputs[name] = sample_input[i]
             self.b_outputs[name] = sample_input[i]

+    def run_shape_prop(self) -> None:
+        """
+        Helper function to run shape propagation on module. Can be overridden by
+        subclasses for custom shape propagation logic.
+        """
+        ShapeProp(self.module).propagate(*self.sample_input)
+
     def run_a(
         self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
     ) -> TensorOrTensors:
@@ -212,6 +212,7 @@ def split_by_tags(
                     comp.getattr_maps[x] = comp.graph.get_attr(
                         x.target, type_expr=x.type
                     )
+                    comp.getattr_maps[x].meta = copy.copy(x.meta)
                 return comp.getattr_maps[x]

             # If input is not a placeholder, it should have been put into a component
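The second change means that get_attr nodes created during graph splitting carry a copy of the original node's meta. A minimal sketch of what that looks like from the caller's side; the toy module, the "acc" tag, and the "source" meta key are illustrative assumptions, not taken from the PR or the PyTorch test suite:

# Sketch only: split a tiny traced module with split_by_tags and inspect the
# meta carried by get_attr nodes in the resulting components.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_utils import split_by_tags


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


gm = symbolic_trace(Tiny())
for node in gm.graph.nodes:
    node.tag = "acc"                  # split_by_tags reads a per-node tag
    node.meta["source"] = node.name   # stand-in for metadata worth preserving

split = split_by_tags(gm, ["acc"])
for name, sub in split.named_modules():
    if name and isinstance(sub, torch.fx.GraphModule):
        for node in sub.graph.nodes:
            if node.op == "get_attr":
                # With this change, meta is copied from the original get_attr node.
                print(name, node.name, node.meta)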