mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pipelining] Remove qualname mapping (#127018)
`QualnameMapMixin` was intended to provide a mapping from the new FQNs of the piped model to the FQNs of the original model. It was there because previous tracers, and the flattening done during tracing, would modify the FQNs. Now that we use the unflattener, the FQNs of the stage modules are the same as the original FQNs, so we don't need `QualnameMapMixin` anymore. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127018 Approved by: https://github.com/H-Huang
This commit is contained in:
parent
5f15110499
commit
ed838793df
|
|
@ -17,7 +17,6 @@ from torch.fx.passes.split_module import split_module
|
|||
|
||||
from ._backward import _null_coalesce_accumulate, stage_backward
|
||||
from ._unflatten import _outline_submodules
|
||||
from ._utils import QualnameMapMixin
|
||||
from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec
|
||||
|
||||
|
||||
|
|
@ -480,7 +479,7 @@ def _direct_serialization_reduce(self):
|
|||
)
|
||||
|
||||
|
||||
class Pipe(QualnameMapMixin, torch.nn.Module):
|
||||
class Pipe(torch.nn.Module):
|
||||
# Class variables
|
||||
"""
|
||||
args_chunk_spec:
|
||||
|
|
@ -507,14 +506,11 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
split_gm: fx.GraphModule,
|
||||
splitter_qualname_map: Dict[str, str],
|
||||
num_stages: int,
|
||||
has_loss_and_backward: bool,
|
||||
loss_spec,
|
||||
tracer_qualname_map: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
# TODO: is there a way not to hard wire init?
|
||||
QualnameMapMixin.__init__(self, splitter_qualname_map, tracer_qualname_map)
|
||||
torch.nn.Module.__init__(self)
|
||||
self.split_gm: fx.GraphModule = split_gm
|
||||
self.executor: DetachExecutor = DetachExecutor(self.split_gm)
|
||||
|
|
@ -568,22 +564,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
submod = getattr(submod, atom)
|
||||
setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
|
||||
|
||||
# Create qualname mapping for each submodule
|
||||
# Dict looks like this:
|
||||
# {submod_name : Dict{old_qualname : new_qualname}}
|
||||
# We save this information here for use during pipeline stage creation.
|
||||
self.submod_qualname_mappings: Dict[str, Dict[str, str]] = {}
|
||||
for m_qualname, mod in self.split_gm.named_children():
|
||||
# "submod_x." prefix
|
||||
mod_prefix = m_qualname + "."
|
||||
mod_qualname_mapping: Dict[str, str] = {}
|
||||
for k, v in self.new_to_old_qualname_mapping.items():
|
||||
if k.startswith(mod_prefix):
|
||||
# Remove prefix
|
||||
new_key = k[len(mod_prefix) :]
|
||||
mod_qualname_mapping.setdefault(new_key, v)
|
||||
self.submod_qualname_mappings[m_qualname] = mod_qualname_mapping
|
||||
|
||||
def throw(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
|
||||
|
|
@ -726,11 +706,9 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
part_idx += 1
|
||||
return part_idx
|
||||
|
||||
# Ask split_module to return mapping from new qualname to old qualname
|
||||
splitter_qualname_map: Dict[str, str] = {}
|
||||
# TODO: what does split do with module invocations? does it move the modules
|
||||
# into the submodules?
|
||||
split = split_module(traced, mod, split_callback, splitter_qualname_map)
|
||||
split = split_module(traced, mod, split_callback)
|
||||
# a (custom) tracer can produce dead code like orphan get_attr nodes
|
||||
split.graph.eliminate_dead_code()
|
||||
|
||||
|
|
@ -814,19 +792,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
)
|
||||
# logger.debug(f"Moved parameter {param_fqn} to {callee_name}")
|
||||
|
||||
# Update qualname mapping
|
||||
# New qualname will have submodule prefix
|
||||
new_qualname = f"{callee_name}.{param_fqn}"
|
||||
if param_fqn in splitter_qualname_map:
|
||||
# Just in case the target name is already in the splitter_qualname_map
|
||||
# returned by split_module() -- we update the mapping using the
|
||||
# new name as a new key
|
||||
splitter_qualname_map[new_qualname] = splitter_qualname_map.pop(
|
||||
param_fqn
|
||||
)
|
||||
else:
|
||||
splitter_qualname_map[new_qualname] = param_fqn
|
||||
|
||||
# Next step is to replace placeholder of submodule with a get_attr.
|
||||
# Those placeholders are created by `split_module` inside each
|
||||
# submodule.
|
||||
|
|
@ -1071,19 +1036,13 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
else:
|
||||
logger.debug("Pipeline is in inference mode, backward pass not generated")
|
||||
|
||||
# Tracer may modify qualname, get the qualname mapping before and after tracing.
|
||||
# This qualname mapping is different from the mapping before and after splitting.
|
||||
tracer_qualname_map = Pipe._get_param_buffer_mapping(mod, traced)
|
||||
|
||||
logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
|
||||
|
||||
return Pipe(
|
||||
split,
|
||||
splitter_qualname_map,
|
||||
num_stages,
|
||||
has_loss_and_backward,
|
||||
generated_loss_spec,
|
||||
tracer_qualname_map,
|
||||
)
|
||||
|
||||
def print_readable(self):
|
||||
|
|
@ -1205,44 +1164,6 @@ class Pipe(QualnameMapMixin, torch.nn.Module):
|
|||
)
|
||||
return self.pipe_info
|
||||
|
||||
# TODO: this util comes from pytorch/pytorch#115462, delete it from PiPPy
|
||||
# when PyTorch 2.3 comes with support, or when PiPPy migrates from
|
||||
# `_export_to_torch_ir` to export + unflattener.
|
||||
@staticmethod
|
||||
def _get_param_buffer_mapping(
|
||||
original_module: torch.nn.Module,
|
||||
traced_module: torch.nn.Module,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Returns a mapping of parameter/buffer names from the new module to the
|
||||
original model. This is to help with restoring the FQN for parameter/buffers
|
||||
of a traced module to what the original module contains.
|
||||
"""
|
||||
|
||||
param_lookup: Dict[int, List[str]] = {}
|
||||
buffer_lookup: Dict[int, List[str]] = {}
|
||||
for name, param in original_module.named_parameters(remove_duplicate=False):
|
||||
param_lookup.setdefault(id(param), []).append(name)
|
||||
for name, buffer in original_module.named_buffers(remove_duplicate=False):
|
||||
buffer_lookup.setdefault(id(buffer), []).append(name)
|
||||
|
||||
param_buffer_table: Dict[str, str] = {}
|
||||
for dynamo_name, dynamo_param in traced_module.named_parameters(
|
||||
remove_duplicate=False
|
||||
):
|
||||
assert dynamo_name not in param_buffer_table
|
||||
if id(dynamo_param) in param_lookup:
|
||||
param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop()
|
||||
|
||||
for dynamo_name, dynamo_buffer in traced_module.named_buffers(
|
||||
remove_duplicate=False
|
||||
):
|
||||
assert dynamo_name not in param_buffer_table
|
||||
if id(dynamo_buffer) in buffer_lookup:
|
||||
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()
|
||||
|
||||
return param_buffer_table
|
||||
|
||||
|
||||
class SplitPoint(Enum):
|
||||
BEGINNING = 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
|
@ -88,52 +88,6 @@ def modify_graph_op_device(
|
|||
gm.recompile()
|
||||
|
||||
|
||||
class QualnameMapMixin:
|
||||
"""
|
||||
A mixin class that helps a `Pipe` object to remap its qualnames back to
|
||||
original qualnames.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
splitter_qualname_map: Optional[Dict[str, str]] = None,
|
||||
tracer_qualname_map: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
self.new_to_old_qualname_mapping: Dict[str, str] = splitter_qualname_map or {}
|
||||
self.tracer_qualname_map = tracer_qualname_map
|
||||
|
||||
def remap_qualname(self, qualname: str):
|
||||
# TODO: annoying
|
||||
if qualname.startswith("split_gm."):
|
||||
qualname = qualname[len("split_gm.") :]
|
||||
|
||||
name_before_split = None
|
||||
if qualname in self.new_to_old_qualname_mapping:
|
||||
name_before_split = self.new_to_old_qualname_mapping[qualname]
|
||||
else:
|
||||
# The qualname map does not store recursive items, thus,
|
||||
# when passed a qualname with leaves, we need to perform longest prefix match
|
||||
# Split from the right, one each time
|
||||
split_names = qualname.rsplit(".", 1)
|
||||
leaf = split_names[-1]
|
||||
while len(split_names) > 1:
|
||||
prefix = split_names[0]
|
||||
if prefix in self.new_to_old_qualname_mapping:
|
||||
old_prefix = self.new_to_old_qualname_mapping[prefix]
|
||||
name_before_split = ".".join([old_prefix, leaf])
|
||||
break
|
||||
split_names = prefix.rsplit(".", 1)
|
||||
leaf = ".".join([split_names[-1], leaf])
|
||||
|
||||
if name_before_split is None:
|
||||
raise RuntimeError(f"Could not find mapping for {qualname}")
|
||||
|
||||
if self.tracer_qualname_map is not None:
|
||||
return self.tracer_qualname_map[name_before_split]
|
||||
else:
|
||||
return name_before_split
|
||||
|
||||
|
||||
class PipeliningShapeError(RuntimeError):
|
||||
"""Shape mismatch between configured and runtime values."""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user