[pipelining] Remove qualname mapping (#127018)

`QualnameMapMixin` was intended to provide a mapping from new FQN of the piped model to the FQN of the original model. It was there because previous tracers and flattening during tracing would modify the FQNs.

Now that we use unflattener, the FQN of the stage modules are the same as the original FQNs. We don't need `QualnameMapMixin` any more.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127018
Approved by: https://github.com/H-Huang
This commit is contained in:
Ke Wen 2024-05-23 15:50:05 -07:00 committed by PyTorch MergeBot
parent 5f15110499
commit ed838793df
2 changed files with 3 additions and 128 deletions

View File

@ -17,7 +17,6 @@ from torch.fx.passes.split_module import split_module
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import QualnameMapMixin
from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec
@ -480,7 +479,7 @@ def _direct_serialization_reduce(self):
)
class Pipe(QualnameMapMixin, torch.nn.Module):
class Pipe(torch.nn.Module):
# Class variables
"""
args_chunk_spec:
@ -507,14 +506,11 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
def __init__(
self,
split_gm: fx.GraphModule,
splitter_qualname_map: Dict[str, str],
num_stages: int,
has_loss_and_backward: bool,
loss_spec,
tracer_qualname_map: Optional[Dict[str, str]] = None,
):
# TODO: is there a way not to hard wire init?
QualnameMapMixin.__init__(self, splitter_qualname_map, tracer_qualname_map)
torch.nn.Module.__init__(self)
self.split_gm: fx.GraphModule = split_gm
self.executor: DetachExecutor = DetachExecutor(self.split_gm)
@ -568,22 +564,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
submod = getattr(submod, atom)
setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
# Create qualname mapping for each submodule
# Dict looks like this:
# {submod_name : Dict{old_qualname : new_qualname}}
# We save this information here for use during pipeline stage creation.
self.submod_qualname_mappings: Dict[str, Dict[str, str]] = {}
for m_qualname, mod in self.split_gm.named_children():
# "submod_x." prefix
mod_prefix = m_qualname + "."
mod_qualname_mapping: Dict[str, str] = {}
for k, v in self.new_to_old_qualname_mapping.items():
if k.startswith(mod_prefix):
# Remove prefix
new_key = k[len(mod_prefix) :]
mod_qualname_mapping.setdefault(new_key, v)
self.submod_qualname_mappings[m_qualname] = mod_qualname_mapping
def throw(self, *args, **kwargs):
raise RuntimeError(
"To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
@ -726,11 +706,9 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
part_idx += 1
return part_idx
# Ask split_module to return mapping from new qualname to old qualname
splitter_qualname_map: Dict[str, str] = {}
# TODO: what does split do with module invocations? does it move the modules
# into the submodules?
split = split_module(traced, mod, split_callback, splitter_qualname_map)
split = split_module(traced, mod, split_callback)
# a (custom) tracer can produce dead code like orphan get_attr nodes
split.graph.eliminate_dead_code()
@ -814,19 +792,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
)
# logger.debug(f"Moved parameter {param_fqn} to {callee_name}")
# Update qualname mapping
# New qualname will have submodule prefix
new_qualname = f"{callee_name}.{param_fqn}"
if param_fqn in splitter_qualname_map:
# Just in case the target name is already in the splitter_qualname_map
# returned by split_module() -- we update the mapping using the
# new name as a new key
splitter_qualname_map[new_qualname] = splitter_qualname_map.pop(
param_fqn
)
else:
splitter_qualname_map[new_qualname] = param_fqn
# Next step is to replace placeholder of submodule with a get_attr.
# Those placeholders are created by `split_module` inside each
# submodule.
@ -1071,19 +1036,13 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
else:
logger.debug("Pipeline is in inference mode, backward pass not generated")
# Tracer may modify qualname, get the qualname mapping before and after tracing.
# This qualname mapping is different from the mapping before and after splitting.
tracer_qualname_map = Pipe._get_param_buffer_mapping(mod, traced)
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
return Pipe(
split,
splitter_qualname_map,
num_stages,
has_loss_and_backward,
generated_loss_spec,
tracer_qualname_map,
)
def print_readable(self):
@ -1205,44 +1164,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
)
return self.pipe_info
# TODO: this util comes from pytorch/pytorch#115462, delete it from PiPPy
# when PyTorch 2.3 comes with support, or when PiPPy migrates from
# `_export_to_torch_ir` to export + unflattener.
@staticmethod
def _get_param_buffer_mapping(
original_module: torch.nn.Module,
traced_module: torch.nn.Module,
) -> Dict[str, str]:
"""
Returns a mapping of parameter/buffer names from the new module to the
original model. This is to help with restoring the FQN for parameter/buffers
of a traced module to what the original module contains.
"""
param_lookup: Dict[int, List[str]] = {}
buffer_lookup: Dict[int, List[str]] = {}
for name, param in original_module.named_parameters(remove_duplicate=False):
param_lookup.setdefault(id(param), []).append(name)
for name, buffer in original_module.named_buffers(remove_duplicate=False):
buffer_lookup.setdefault(id(buffer), []).append(name)
param_buffer_table: Dict[str, str] = {}
for dynamo_name, dynamo_param in traced_module.named_parameters(
remove_duplicate=False
):
assert dynamo_name not in param_buffer_table
if id(dynamo_param) in param_lookup:
param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop()
for dynamo_name, dynamo_buffer in traced_module.named_buffers(
remove_duplicate=False
):
assert dynamo_name not in param_buffer_table
if id(dynamo_buffer) in buffer_lookup:
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()
return param_buffer_table
class SplitPoint(Enum):
BEGINNING = 1

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Tuple, Union
import torch
from torch import fx
@ -88,52 +88,6 @@ def modify_graph_op_device(
gm.recompile()
class QualnameMapMixin:
"""
A mixin class that helps a `Pipe` object to remap its qualnames back to
original qualnames.
"""
def __init__(
self,
splitter_qualname_map: Optional[Dict[str, str]] = None,
tracer_qualname_map: Optional[Dict[str, str]] = None,
):
self.new_to_old_qualname_mapping: Dict[str, str] = splitter_qualname_map or {}
self.tracer_qualname_map = tracer_qualname_map
def remap_qualname(self, qualname: str):
# TODO: annoying
if qualname.startswith("split_gm."):
qualname = qualname[len("split_gm.") :]
name_before_split = None
if qualname in self.new_to_old_qualname_mapping:
name_before_split = self.new_to_old_qualname_mapping[qualname]
else:
# The qualname map does not store recursive items, thus,
# when passed a qualname with leaves, we need to perform longest prefix match
# Split from the right, one each time
split_names = qualname.rsplit(".", 1)
leaf = split_names[-1]
while len(split_names) > 1:
prefix = split_names[0]
if prefix in self.new_to_old_qualname_mapping:
old_prefix = self.new_to_old_qualname_mapping[prefix]
name_before_split = ".".join([old_prefix, leaf])
break
split_names = prefix.rsplit(".", 1)
leaf = ".".join([split_names[-1], leaf])
if name_before_split is None:
raise RuntimeError(f"Could not find mapping for {qualname}")
if self.tracer_qualname_map is not None:
return self.tracer_qualname_map[name_before_split]
else:
return name_before_split
class PipeliningShapeError(RuntimeError):
"""Shape mismatch between configured and runtime values."""