Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Reland of https://github.com/pytorch/pytorch/pull/114787

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115558
Approved by: https://github.com/zhxchen17, https://github.com/atalman
ghstack dependencies: #115556, #115557
296 lines
11 KiB
Python
import copy

import torch
import torch.utils._pytree as pytree
from torch._export.utils import _check_input_constraints_pre_hook
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from .exported_program import ExportedProgram


def _unlift(
    gm,
    inp_pos_to_param_buffer_name,
    in_spec,
    out_spec,
    state_dict,
    tensor_constants,
    buffers_to_mutate,
):
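    """
    Rewrite ``gm`` in place so that lifted parameters/buffers become module
    state again:

    1. Replace lifted placeholder inputs with ``get_attr`` nodes.
    2. Route mutated-buffer outputs into ``aten.copy_`` calls and return only
       the user outputs.
    3. Regenerate the pytree codegen for the reduced input signature.
    4. Recurse into ``cond``/``map_impl`` submodules and fix their state
       references.
    """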
    count = 0
    buffer_name_to_node = {}
    # Step 1: make lifted params as get_attr
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if count in inp_pos_to_param_buffer_name:
                with gm.graph.inserting_after(node):
                    getattr_node = gm.graph.get_attr(
                        inp_pos_to_param_buffer_name[count]
                    )
                    node.replace_all_uses_with(getattr_node)
                    metadata = node.meta
                    gm.graph.erase_node(node)
                    getattr_node.meta = metadata
                    buffer_name_to_node[
                        inp_pos_to_param_buffer_name[count]
                    ] = getattr_node

            count += 1
        # Step 2: Find all the buffers that were mutated and update them
        if node.op == "output":
            user_output_nodes = []
            # In the case that the same node is returned multiple times,
            # node.all_input_nodes will only iterate over that node once.
            for return_node in pytree.tree_flatten(node.args)[0]:
                return_node_name = return_node.name
                # we found a param/buffer mutation
                if return_node_name in buffers_to_mutate:
                    # TODO Fix situation here to replace dot with underscore...
                    buffer_node_name = buffers_to_mutate[return_node_name].replace(
                        ".", "_"
                    )
                    assert buffer_node_name in buffer_name_to_node
                    buffer_node = buffer_name_to_node[buffer_node_name]
                    with gm.graph.inserting_before(node):
                        buffer_update_node = gm.graph.call_function(
                            torch.ops.aten.copy_.default, (buffer_node, return_node)
                        )
                else:
                    user_output_nodes.append(return_node)
            with gm.graph.inserting_before(node):
                # Only return user outputs
                new_output = gm.graph.output(tuple(user_output_nodes))
                node.replace_all_uses_with(new_output)
                gm.graph.erase_node(node)

    # Step 3: Fix the input/output of the graph now that we deleted
    # some args.
    gm.graph.lint()

    if (
        in_spec.type == tuple
        and len(in_spec.children_specs) == 2
        and in_spec.children_specs[0].type == tuple
        and in_spec.children_specs[1].type == dict
    ):
        # if in_spec contains the args (tuple) and kwargs (dict)
        num_args = len(in_spec.children_specs[0].children_specs) + len(
            in_spec.children_specs[1].children_specs
        )
    else:
        num_args = len(in_spec.children_specs)

    names = [f"arg_{i}" for i in range(num_args)]

    gm.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            names,
            in_spec,
            out_spec,
        )
    )
    gm.recompile()

    # Step 4: Find state references in HigherOrderOps and recursively
    # fix them.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.cond:
            pred, true_graph, false_graph, operands = node.args
            true_gm = getattr(gm, true_graph.name)
            false_gm = getattr(gm, false_graph.name)
            inp_pos_to_param_buffer_name_for_submod = {}
            real_operands = []
            for ix, operand in enumerate(operands):
                if operand.target in inp_pos_to_param_buffer_name.values():
                    inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
                    if operand.target in state_dict:
                        value = state_dict[operand.target]
                    elif operand.target in tensor_constants:
                        value = tensor_constants[operand.target]
                    else:
                        raise RuntimeError(f"Unable to find value for {operand.target}")
                    true_gm.register_buffer(operand.target, value)
                    false_gm.register_buffer(operand.target, value)
                else:
                    real_operands.append(operand)
            node.args = (pred, true_graph, false_graph, real_operands)

            _, in_spec = pytree.tree_flatten(real_operands)

            _unlift(
                true_gm,
                inp_pos_to_param_buffer_name_for_submod,
                in_spec,
                None,
                state_dict,
                tensor_constants,
                buffers_to_mutate,
            )
            _unlift(
                false_gm,
                inp_pos_to_param_buffer_name_for_submod,
                in_spec,
                None,
                state_dict,
                tensor_constants,
                buffers_to_mutate,
            )
if node.op == "call_function" and node.target.__name__ == "map_impl":
|
|
body_graph, num_mapped, *operands = node.args
|
|
body_gm = getattr(gm, body_graph.name)
|
|
inp_pos_to_buffer_name_for_submod = {}
|
|
real_operands = []
|
|
# TODO Fix situation here to replace dot with underscore...
|
|
state_dict_for_lookup = {
|
|
key.replace(".", "_"): value for key, value in state_dict.items()
|
|
}
|
|
for ix, operand in enumerate(operands):
|
|
if operand.target in inp_pos_to_param_buffer_name.values():
|
|
inp_pos_to_buffer_name_for_submod[ix] = operand.target
|
|
if operand.target in state_dict_for_lookup:
|
|
value = state_dict_for_lookup[operand.target]
|
|
elif operand.target in tensor_constants:
|
|
value = tensor_constants[operand.target]
|
|
else:
|
|
raise RuntimeError(f"Unable to find value for {operand.target}")
|
|
body_gm.register_buffer(operand.target, value)
|
|
else:
|
|
real_operands.append(operand)
|
|
node.args = (body_graph, num_mapped, *real_operands)
|
|
|
|
_, in_spec = pytree.tree_flatten(real_operands)
|
|
|
|
_unlift(
|
|
body_gm,
|
|
inp_pos_to_buffer_name_for_submod,
|
|
in_spec,
|
|
None,
|
|
state_dict,
|
|
tensor_constants,
|
|
buffers_to_mutate,
|
|
)
|
|
    gm.graph.lint()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm


def _construct_inp_pos_to_param_buffer_name(
    new_gm, graph_signature, state_dict, tensor_constants=None
):
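    """
    Register the entries of ``state_dict`` (and any lifted tensor constants)
    as attributes on ``new_gm``, normalizing "." to "_" in names, and return a
    dict mapping the position of each lifted placeholder to the attribute name
    it should be replaced with.
    """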
    # TODO Fix the period in params/buffers names later
    # maybe a pass to replace graph signature with fixed names
    param_buffer_name_to_corrected_name = {}

    for name, value in state_dict.items():
        if name in graph_signature.buffers:
            if "." in name:
                new_gm.register_buffer(name.replace(".", "_"), value)
                param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
            else:
                new_gm.register_buffer(name, value)
        if name in graph_signature.parameters:
            if "." in name:
                new_gm.register_parameter(name.replace(".", "_"), value)
                param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
            else:
                new_gm.register_parameter(name, value)

    if tensor_constants is not None and len(tensor_constants) > 0:
        assert hasattr(graph_signature, "lifted_tensor_constants")
        for name, value in tensor_constants.items():
            if name in graph_signature.lifted_tensor_constants:
                new_gm.register_buffer(name, value)
                param_buffer_name_to_corrected_name[name] = name

    count = 0
    inp_pos_to_param_buffer_name = {}
    for node in new_gm.graph.nodes:
        if node.op == "placeholder":
            if node.name in graph_signature.inputs_to_buffers:
                buffer_name = graph_signature.inputs_to_buffers[node.name]
                if buffer_name in param_buffer_name_to_corrected_name:
                    inp_pos_to_param_buffer_name[
                        count
                    ] = param_buffer_name_to_corrected_name[buffer_name]
                else:
                    inp_pos_to_param_buffer_name[count] = buffer_name
            if node.name in graph_signature.inputs_to_parameters:
                param_name = graph_signature.inputs_to_parameters[node.name]
                if param_name in param_buffer_name_to_corrected_name:
                    inp_pos_to_param_buffer_name[
                        count
                    ] = param_buffer_name_to_corrected_name[param_name]
                else:
                    inp_pos_to_param_buffer_name[count] = param_name
            if hasattr(graph_signature, "inputs_to_lifted_tensor_constants"):
                if node.name in graph_signature.inputs_to_lifted_tensor_constants:
                    inp_pos_to_param_buffer_name[
                        count
                    ] = graph_signature.inputs_to_lifted_tensor_constants[node.name]
            count += 1

    return inp_pos_to_param_buffer_name


class _StatefulGraphModuleFactory(type):
    """
    Metaclass that ensures a private constructor for _StatefulGraphModule
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
        )

    def _create(cls, root, graph, range_constraints=None, equality_constraints=None):
        return super().__call__(
            root,
            graph,
            range_constraints=range_constraints,
            equality_constraints=equality_constraints,
        )


class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    def __init__(self, root, graph, range_constraints=None, equality_constraints=None):
        super().__init__(root, graph)
        self.range_constraints = range_constraints or []
        self.equality_constraints = equality_constraints or []


def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule, range_constraints, equality_constraints
):
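    """
    Wrap a plain GraphModule in a _StatefulGraphModule that carries the range
    and equality constraints and checks input constraints via a forward
    pre-hook.
    """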
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
        equality_constraints=equality_constraints,
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )
    return stateful_gm


def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
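    """
    Deep-copy the ExportedProgram's graph module, turn its lifted
    parameter/buffer inputs back into module state via _unlift, and return a
    stateful graph module whose forward takes only the user inputs.
    """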
    new_gm = copy.deepcopy(ep.graph_module)
    inp_pos_to_param_buffer_name = _construct_inp_pos_to_param_buffer_name(
        new_gm, ep.graph_signature, ep.state_dict, ep.tensor_constants
    )
    new_gm = _unlift(
        new_gm,
        inp_pos_to_param_buffer_name,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        ep.state_dict,
        ep.tensor_constants,
        ep.graph_signature.buffers_to_mutate,
    )
    unlift_gm = _create_stateful_graph_module(
        new_gm, ep.range_constraints, ep.equality_constraints
    )
    unlift_gm.meta.update(ep.graph_module.meta)
    return unlift_gm
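

# A minimal usage sketch (not part of the original module), assuming an
# ExportedProgram produced by torch.export.export; ``MyModule`` is a
# hypothetical nn.Module with registered parameters/buffers:
#
#     ep = torch.export.export(MyModule(), (torch.randn(3),))
#     unlifted = _unlift_exported_program_lifted_states(ep)
#     out = unlifted(torch.randn(3))  # params/buffers are module state again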