# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
from torch import fx
from torch.export.unflatten import InterpreterModule


logger = logging.getLogger(__name__)


def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from the computational graph.
    """
    flat_detached_args = []

    def extract_tensor_args(a):
        nonlocal flat_detached_args
        if isinstance(a, torch.Tensor):
            val = a.detach().requires_grad_(a.requires_grad)
            flat_detached_args.append(val)
            return val
        else:
            flat_detached_args.append(a)
            return a

    new_args = fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return new_args, flat_detached_args
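

# Usage sketch (illustrative only, not part of this module's API):
# `flatten_args_detach` recurses through nested containers via
# `fx.node.map_aggregate`, so tensors buried in tuples/dicts are detached
# while the overall structure is preserved.
#
#     t = torch.randn(2, requires_grad=True)
#     new_args, flat = flatten_args_detach((t, {"scale": 2.0}))
#     # new_args keeps the (tensor, dict) structure; flat holds the leaves,
#     # where the tensor leaf is detached but still has requires_grad=True.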


def flatten_args(args):
    """
    Flatten the args into a list form.
    """
    flat_args = []

    def extract_tensor_args(a):
        nonlocal flat_args
        flat_args.append(a)
        return a

    fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return flat_args


def modify_graph_op_device(
    gm: torch.fx.GraphModule,
    new_device: torch.device,
):
    """
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device, in particular for
    generator ops like torch.ones.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if "device" in node.kwargs and node.kwargs["device"] != new_device:
                logger.debug(
                    f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}"  # noqa: G004
                )
                node.update_kwarg("device", new_device)
                modified = True
        elif node.op == "call_module":
            # Recursively modify "device" in submodules
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule):
                modify_graph_op_device(submod, new_device)
            elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, access its graph module via `.graph_module`
                modify_graph_op_device(submod.graph_module, new_device)
            else:
                logger.warning(
                    f"Skipping device modification for submodule {node.target} because it is a {type(submod)}"  # noqa: G004
                )

    if modified:
        gm.recompile()
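

# Usage sketch (illustrative only, not part of this module's API): after
# symbolically tracing a module whose forward hard-codes a device in a
# generator op, the baked-in `device` kwarg can be retargeted in place.
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x + torch.ones(2, device=torch.device("cpu"))
#
#     gm = torch.fx.symbolic_trace(M())
#     modify_graph_op_device(gm, torch.device("cuda:0"))
#     # the `device` kwarg of the torch.ones call_function node now points at cuda:0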


class PipeliningShapeError(RuntimeError):
    """Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given):
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
    desc,
    expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
    actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
):
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    for i in range(len(expected_tensors)):
        validate_tensor_metadata(
            f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
        )
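

# Usage sketch (illustrative only): the validators compare shape, dtype, and
# stride of runtime tensors against the configured ones, raising
# PipeliningShapeError on the first mismatch.
#
#     expected = [torch.empty(4, 8)]
#     actual = [torch.empty(4, 4)]
#     validate_tensors_metadata("stage 0 inputs", expected, actual)
#     # -> PipeliningShapeError: stage 0 inputs: value 0 has a shape mismatch: ...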


@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    graph: fx.Graph
    num_stages: int
    has_loss_and_backward: bool