pytorch/test/test_static_runtime.py
Hao Lu 996f444c00 [pt][static_runtime] Memory model (#46896)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46896

The idea of the memory model is quite similar to that of BlackBoxPredictor; however, it's more complicated in PyTorch because of 1) tensor views that share storage (the StorageImpl refcount is bumped) but have different TensorImpls, 2) tensors that share the same TensorImpl and the same storage with no refcount bump on the StorageImpl, 3) data types such as TensorList and Tuple that contain Tensors, and 4) the need to support a mix of non-out and out variants while we move the aten ops to out variants.
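
For illustration (not part of this diff), points 1 and 2 are the reason dedup has to happen at the StorageImpl level rather than at the Tensor level:

```
import torch

# A view gets its own TensorImpl but shares the base tensor's StorageImpl,
# so counting Tensors would double-count the underlying buffer.
base = torch.randn(4, 4)
view = base.view(16)
assert view is not base                          # distinct TensorImpls
assert view.data_ptr() == base.data_ptr()        # same underlying storage
assert view.storage().data_ptr() == base.storage().data_ptr()

# Point 3: containers such as tuples/lists can hold Tensors that the
# memory planner also has to account for.
pair = (base, view)
```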

As a result, I have to make the following adjustments:
1) remove Tensors in output Tuples from the internal blob list;
2) for memory allocation/deallocation, get candidate Tensors from the outputs of ops with an out variant, extract the StorageImpls from those Tensors, dedup them, remove the StorageImpls that back graph outputs, and use what remains as the final list of blobs for memory planning (see the sketch after this list);
3) during the clean_up_memory pass, clean up memory held by the StorageImpls, as well as by Tensors/Lists/Tuples in IValues that don't participate in memory planning, to reduce overall memory usage.
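
Step 2 boils down to roughly the following (a Python sketch only, with hypothetical names; the actual logic is implemented in C++ inside the static runtime's memory planner):

```
def collect_managed_storages(out_variant_values, graph_outputs):
    # Sketch: gather candidate storages for memory planning (hypothetical helper).
    candidates = []
    for value in out_variant_values:  # outputs of ops that have an out variant
        if isinstance(value, torch.Tensor):
            candidates.append(value.storage())
        elif isinstance(value, (list, tuple)):  # TensorLists / Tuples holding Tensors
            candidates.extend(v.storage() for v in value if isinstance(v, torch.Tensor))

    # Dedup at the storage level, since views share a StorageImpl
    # across distinct TensorImpls.
    unique = {s.data_ptr(): s for s in candidates}

    # Exclude storages backing graph outputs; those must outlive the run
    # and cannot be reused by the planner.
    output_ptrs = {t.storage().data_ptr() for t in graph_outputs}
    return [s for ptr, s in unique.items() if ptr not in output_ptrs]
```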

Risk:
The PyTorch team is planning to deprecate the current resize_output API, which we rely on. This is a pretty big risk.

https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/caffe2/aten/src/ATen/native/Resize.cpp?commit=6457b329847607553d34e788a3a7092f41f38895&lines=9-23
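
For context, out-variant ops resize their preallocated output when its shape doesn't match the result, which is the behavior the memory planner relies on to reuse buffers across iterations. A Python-level illustration (the C++ entry point is resize_output in Resize.cpp, linked above):

```
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)
out = torch.empty(0)      # preallocated (here empty) output buffer
torch.add(a, b, out=out)  # the out variant resizes `out` to (4, 4)
assert out.shape == (4, 4)
```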

Test Plan:
```
buck test //caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```
Benchmarks:
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 13 \
buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
--scripted_model=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/traced_precomputation.pt \
--pt_inputs=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/container_precomputation_bs1.pt \
--iters=1000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true \
--pt_cleanup_activations=true --pt_enable_out_variant=false
```

| pt_cleanup_activations | pt_enable_out_variant | old ms/iter | new ms/iter |
| --- | --- | --- | --- |
| 0 | 0 | 0.31873 | 0.30228 |
| 0 | 1 | 0.30018 | 0.29184 |
| 1 | 0 | 0.35246 | 0.31895 |
| 1 | 1 | 0.35742 | 0.30417 |

Reviewed By: bwasti, raziel

Differential Revision: D24471854

fbshipit-source-id: 4ac37dca7d2a0c362120a7f02fd3995460c9a55c
2020-11-03 23:47:59 -08:00

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests


class StaticRuntime:
    def __init__(self, scripted):
        # this is an nn.Module
        if hasattr(scripted, "_c"):
            self.static_runtime = torch._C._jit_to_static_runtime(scripted._c)
        else:
            self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)

    def __call__(self, *args, **kwargs):
        if not kwargs:
            return self.static_runtime.run(args)
        else:
            return self.static_runtime.run(args, kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_runtime.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )


def linear_shim(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    ret = output
    return ret


torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]
        LL = nn.Linear(int(n), int(m), bias=True)
        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)
        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


class TestStaticRuntime(TestCase):
    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticRuntime(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)
        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_allclose(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticRuntime(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticRuntime(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)

        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)[0]
        torch.testing.assert_allclose(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)[0]
        torch.testing.assert_allclose(acc_top, ref_top)

        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)[0]
            torch.testing.assert_allclose(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)[0]
            torch.testing.assert_allclose(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticRuntime(tg)
        o_test = tg_a(s, s, s)[0]
        torch.testing.assert_allclose(o_ref, o_test)


if __name__ == "__main__":
    run_tests()