# pytorch/test/distributed/pipelining/model_registry.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


# Basic example: a single pipe_split() call in forward() marks the boundary
# between two pipeline stages; `cval` is a constant buffer passed across it.
class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x
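

# Illustrative usage sketch (hypothetical helper, not referenced by the tests):
# a model whose forward() calls pipe_split() once can be traced into a two-stage
# Pipe via the torch.distributed.pipelining pipeline() frontend. The exact
# frontend call is an assumption based on recent PyTorch releases.
def _example_code_pipeline_sketch(d_hid: int = 512, batch_size: int = 256):
    from torch.distributed.pipelining import pipeline

    model = ExampleCode(d_hid)
    x = torch.randn(batch_size, d_hid)
    # One pipe_split() boundary in forward() -> two pipeline stages after tracing
    return pipeline(model, mb_args=(x,))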


# Model whose forward() takes a keyword argument with a tensor default, used to
# exercise kwarg handling in the pipeline frontend.
class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x
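

# Illustrative usage sketch (hypothetical helper, not referenced by the tests):
# the keyword argument `y` is supplied per-microbatch through mb_kwargs when
# tracing with the pipeline() frontend (an assumption based on recent PyTorch).
def _model_with_kwargs_pipeline_sketch():
    from torch.distributed.pipelining import pipeline

    d_hid = ModelWithKwargs.DEFAULT_DHID
    batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE
    model = ModelWithKwargs(d_hid)
    x = torch.randn(batch_size, d_hid)
    y = torch.randn(batch_size, d_hid)
    return pipeline(model, mb_args=(x,), mb_kwargs={"y": y})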


# Model in which a parameter and a submodule are each registered under two
# attribute names, to exercise handling of aliased parameters/modules across
# the split.
class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(
            torch.randn(d_hid, d_hid)
        )
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in practice the split spec is defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
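

# Illustrative usage sketch (hypothetical helper, not referenced by the tests):
# MultiMLP has no pipe_split() calls; instead its split_spec names the submodules
# at which to cut. Passing it to the pipeline() frontend is an assumption based
# on recent PyTorch releases.
def _multi_mlp_pipeline_sketch(d_hid: int = 16, n_layers: int = 4):
    from torch.distributed.pipelining import pipeline

    model = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(8, d_hid)
    # SplitPoint.BEGINNING before layers.1 .. layers.{n_layers-1} -> n_layers stages
    return pipeline(model, mb_args=(x,), split_spec=model.split_spec)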


# Custom linear whose backward computes only the input gradient (dx); the
# grad_output and input needed for the weight gradient are cached on the owning
# module so that dW can be computed later (see MLPModuleWithDw.compute_dW).
class CustomLinearDx(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, _ = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        # Cache what is needed for dW/db later instead of returning them now
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


# Custom linear whose backward computes input, weight, and bias gradients in
# one pass (the standard dx + dW behavior), used as the reference path.
class CustomLinearDxDw(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, _ = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


# MLP layer that can switch between the standard backward (CustomLinearDxDw)
# and a split backward in which dW/db are computed separately via compute_dW().
class MLPModuleWithDw(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(
                self.hidden, self.fc2_weight, self.fc2_bias
            )
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
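

# Illustrative usage sketch (hypothetical helper, not referenced by the tests):
# with the custom logic toggled on, backward() produces only input gradients and
# the weight/bias gradients are filled in afterwards by compute_dW().
def _mlp_with_dw_sketch(d_hid: int = 8, batch_size: int = 4):
    m = MLPModuleWithDw(d_hid)
    m.toggle()  # switch to CustomLinearDx, which defers dW/db
    out = m(torch.randn(batch_size, d_hid))
    out.sum().backward()  # dx only; grad_outputs and inputs are cached on m
    m.compute_dW()  # now populates .grad on fc1/fc2 weights and biases
    return m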


# Multi-MLP model with split dW computation
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in practice the split spec is defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError(
                "Need to call toggle() to enable custom backward and dW"
            )
        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
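

# Illustrative usage sketch (hypothetical helper, not referenced by the tests):
# the multi-layer variant must be toggled before compute_dW(), otherwise it
# raises; compute_dW() then visits the layers in reverse order.
def _multi_mlp_with_dw_sketch(d_hid: int = 8, n_layers: int = 2):
    model = MultiMLPWithDw(d_hid, n_layers=n_layers)
    model.toggle()  # enable CustomLinearDx in every layer
    out = model(torch.randn(4, d_hid))
    out.sum().backward()
    model.compute_dW()  # weight and bias .grad populated layer by layer
    return model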