# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x
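

# A minimal usage sketch (illustration only, not part of the test zoo): the
# pipe_split() call in ExampleCode.forward marks a stage boundary, so the
# pipeline() tracing frontend from torch.distributed.pipelining should be able
# to split the model into two stages from an example micro-batch. The helper
# name, d_hid default, and batch size of 4 are arbitrary assumptions here.
def _example_code_pipeline_sketch(d_hid: int = 16):
    from torch.distributed.pipelining import pipeline

    model = ExampleCode(d_hid)
    x = torch.randn(4, d_hid)
    # One pipe_split() marker in forward() -> a 2-stage pipe.
    pipe = pipeline(model, mb_args=(x,))
    return pipe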


class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in practice this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
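

# A minimal usage sketch (illustration only): MultiMLP carries a ready-made
# split_spec, so tracing it into per-layer stages could look like the
# hypothetical helper below, assuming pipeline() accepts the mapping via its
# split_spec argument. The helper name and the sizes used are arbitrary.
def _multi_mlp_pipeline_sketch(d_hid: int = 16, n_layers: int = 2):
    from torch.distributed.pipelining import pipeline

    model = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(4, d_hid)
    # SplitPoint.BEGINNING before layers.1 .. layers.{n_layers-1} -> n_layers stages.
    pipe = pipeline(model, mb_args=(x,), split_spec=model.split_spec)
    return pipe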


# Linear autograd Function that computes only the input gradient (dX) in
# backward; grad_output and the input are cached on the owning module so the
# weight gradient (dW) can be computed later via MLPModuleWithDw.compute_dW().
class CustomLinearDx(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, _ = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


# Reference linear autograd Function whose backward computes the input
# gradient (dX) and the weight/bias gradients (dW, db) in a single pass.
class CustomLinearDxDw(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, _ = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class MLPModuleWithDw(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    # Deferred weight-gradient computation: consumes the inputs and
    # grad_outputs cached by CustomLinearDx.backward and accumulates the
    # resulting dW/db into the parameters' .grad fields.
    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic


# Multi-MLP model with deferred dW computation
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in practice this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
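

# A minimal usage sketch (illustration only) of the split input/weight backward
# above: with custom logic toggled on, backward() through MultiMLPWithDw only
# produces input gradients via CustomLinearDx, and the weight/bias gradients
# are filled in later by compute_dW() from the cached activations and
# grad_outputs. The helper name and tensor shapes are arbitrary assumptions.
def _split_dw_sketch(d_hid: int = 8):
    model = MultiMLPWithDw(d_hid)
    model.toggle()  # enable CustomLinearDx and deferred dW accumulation
    x = torch.randn(4, d_hid, requires_grad=True)
    out = model(x)
    out.sum().backward()  # computes dX only; caches context for dW
    model.compute_dW()  # populates .grad on the fc weights and biases
    return model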