pytorch/caffe2/python/memonger_test.py
Aapo Kyrola f82a510be6 share forward activation blobs + pass unused free blobs down all branches + use shape inference
Summary:
Added optional support for sharing activation blobs as well. Making this change revealed a non-optimal implementation in the blob sharing: when reusing free blobs, we need to prefer those blobs that are already shared by many other blobs. Otherwise the memory usage can increase when the pool of 'free blobs' grows.

Also, my first version only passed "free blobs" (i.e., blobs in the recycling pool) down the first branch when operators forked. Now we pass those blobs that were not used by the first branch down the second branch, and so on.

Also added support for blob size information in the heuristic. This uses the shape inference mechanism.

I had to also do some small tweaks:
- use Sum() operator as a way to match shapes of blobs that had otherwise unknown shapes. This is related to the Sum() operator that is added to combine multiple incoming gradient inputs (with _autosplit gradients).
- a couple of random shape inference fixes

This reduces the Resnet-50 memory usage on 64 batch from 9.45 Gig to 8.5 Gig.
For a 32 batch, the memory usage is 4330 MiB, down from 4800 MB, compared to Torch's 6856MiB (thanks prigoyal  for checking this for me).

This is unfortunately quite a bunch to review...

Reviewed By: asaadaldien

Differential Revision: D4393909

fbshipit-source-id: 9c7c94125f96512bea80463ebcb63c215ef95ff9
2017-04-25 14:23:25 -07:00

369 lines
15 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from caffe2.python import workspace, cnn, memonger, core
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
from hypothesis import given
def has_blob(proto, needle):
    """Return True if any operator in ``proto`` reads or writes ``needle``.

    Args:
        proto: object exposing an iterable ``op``; every op exposes
            iterables ``input`` and ``output`` of blob names.
        needle: blob name to search for.

    Returns:
        True when ``needle`` equals at least one input or output name of
        any op, False otherwise (including when there are no ops).
    """
    return any(
        blob == needle
        for op in proto.op
        for blobs in (op.input, op.output)
        for blob in blobs
    )
def count_blobs(proto):
    """Return the number of distinct blob names used by the ops of ``proto``.

    Args:
        proto: object exposing an iterable ``op``; every op exposes
            iterables ``input`` and ``output`` of blob names.

    Returns:
        The size of the union of all input and output names over all ops.
    """
    blobs = set()
    for op in proto.op:
        # Grow the set in place instead of building new sets for every op.
        blobs.update(op.input, op.output)
    return len(blobs)
class MemongerTest(hu.HypothesisTestCase):
    """Tests for the memonger blob-sharing optimizations and their helpers."""

    @given(input_dim=st.integers(min_value=1, max_value=10),
           output_dim=st.integers(min_value=1, max_value=10),
           batch_size=st.integers(min_value=1, max_value=10),
           do=st.sampled_from(hu.device_options),
           algo=st.sampled_from(memonger.AssignmentAlgorithm))
    def test_simple_memonger(self, input_dim, output_dim, batch_size, do, algo):
        """optimize_interference must keep the loss and the fc1_w gradient
        unchanged, both without and with blob size information."""
        m = cnn.CNNModelHelper()
        fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
        fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
        fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)

        fc3.Relu([], fc3)\
            .Softmax([], "pred") \
            .LabelCrossEntropy(["label"], ["xent"]) \
            .AveragedLoss([], "loss")
        input_to_grad = m.AddGradientOperators(["loss"])
        m.net.Proto().device_option.CopyFrom(do)
        m.param_init_net.Proto().device_option.CopyFrom(do)
        # Blobs excluded from sharing: parameters, the net inputs, the loss
        # and the gradient that is compared below.
        static_blobs = \
            [o for op in m.param_init_net.Proto().op for o in op.output] + \
            ["data", "label", "loss", input_to_grad["fc1_w"]]

        optimization = memonger.optimize_interference(
            m.Proto(), static_blobs, algo=algo)
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("data", data, device_option=do)
        workspace.FeedBlob("label", label, device_option=do)
        # Reference results from the unoptimized net.
        workspace.RunNetOnce(m.net)
        loss = workspace.FetchBlob("loss")
        grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        workspace.RunNetOnce(optimization.net)
        optimized_loss = workspace.FetchBlob("loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)
        stats = memonger.compute_statistics(optimization.assignments)
        self.assertLess(stats.optimized_nbytes, stats.baseline_nbytes)

        # run with blob sizes
        blob_sizes = memonger.collect_blob_sizes(m.Proto())
        optimization1 = memonger.optimize_interference(
            m.Proto(), static_blobs, blob_sizes=blob_sizes, algo=algo)
        workspace.RunNetOnce(optimization1.net)
        optimized_loss = workspace.FetchBlob("loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)
        stats = memonger.compute_statistics(optimization1.assignments)
        self.assertLessEqual(stats.optimized_nbytes, stats.baseline_nbytes)

    @given(input_dim=st.integers(min_value=1, max_value=4),
           output_dim=st.integers(min_value=1, max_value=4),
           batch_size=st.integers(min_value=1, max_value=4))
    def test_gradient_optim(self, input_dim, output_dim, batch_size):
        """share_grad_blobs must reduce the blob count of a linear FC chain
        while producing the same loss and fc1_w gradient, with and without
        activation sharing."""
        m = cnn.CNNModelHelper()
        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)
            fc5.Relu([], fc5)\
                .Softmax([], "pred") \
                .LabelCrossEntropy(["label"], ["xent"]) \
                .AveragedLoss([], "loss")
        input_to_grad = m.AddGradientOperators(["name_x/loss"])

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss"],
            set(m.param_to_grad.values()),
            "name_x/",
            share_activations=False,
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)

        optim_proto_wacts = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss"],
            set(m.param_to_grad.values()),
            "name_x/",
            share_activations=True,
        )
        blobs_wact_optim = count_blobs(optim_proto_wacts)
        self.assertLessEqual(blobs_wact_optim, blobs_after)

        # Check that the last activations are not shared
        self.assertTrue(has_blob(optim_proto, "name_x/fc5"))
        self.assertTrue(
            has_blob(optim_proto_wacts, "name_x/fc5"),
            "Dont remap final activation",
        )

        # Test networks produce exactly same gradients
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss = workspace.FetchBlob("name_x/loss")
        grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        workspace.RunNetOnce(optim_proto)
        optimized_loss = workspace.FetchBlob("name_x/loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)

        # Run with the forward optimization
        workspace.RunNetOnce(optim_proto_wacts)
        optimized_loss = workspace.FetchBlob("name_x/loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)

    @given(input_dim=st.integers(min_value=4, max_value=4),
           output_dim=st.integers(min_value=4, max_value=4),
           batch_size=st.integers(min_value=4, max_value=4))
    def test_gradient_optim_tree(self, input_dim, output_dim, batch_size):
        """share_grad_blobs on a net with two losses must reduce the blob
        count, keep name_x/fc6, and reproduce both losses and the fc1_w
        gradient."""
        m = cnn.CNNModelHelper()
        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)
            fc5.Relu([], fc5) \
                .Softmax([], "pred1") \
                .LabelCrossEntropy(["label"], ["xent1"]) \
                .AveragedLoss([], "loss1")
            fc6 = m.FC(fc5, "fc6", dim_in=output_dim, dim_out=output_dim)
            fc6.Relu([], fc6) \
                .Softmax([], "pred2") \
                .LabelCrossEntropy(["label"], ["xent2"]) \
                .AveragedLoss([], "loss2")
        input_to_grad = m.AddGradientOperators(["name_x/loss1", "name_x/loss2"])

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss1", "name_x/loss2"],
            set(m.param_to_grad.values()),
            "name_x",  # "name_x//shared_gradinp_0_shared" if using "name_x/"
            share_activations=True,
            dont_share_blobs=set(['name_x/fc6', 'name_x/fc5']),
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)
        self.assertTrue(has_blob(optim_proto, "name_x/fc6"))

        # Test networks produce exactly same gradients
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss1 = workspace.FetchBlob("name_x/loss1")
        loss2 = workspace.FetchBlob("name_x/loss2")
        grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        workspace.RunNetOnce(optim_proto)
        optimized_loss1 = workspace.FetchBlob("name_x/loss1")
        optimized_loss2 = workspace.FetchBlob("name_x/loss2")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss1, optimized_loss1)
        np.testing.assert_almost_equal(loss2, optimized_loss2)
        np.testing.assert_almost_equal(grad, optimized_grad)

    @given(input_dim=st.integers(min_value=4, max_value=4),
           output_dim=st.integers(min_value=4, max_value=4),
           batch_size=st.integers(min_value=4, max_value=4))
    def test_forward_optim_tree_daggy(self, input_dim, output_dim, batch_size):
        """optimize_inference_for_dag on a forward-only "dag" net with a
        side branch must reduce the blob count and reproduce both losses."""
        m = cnn.CNNModelHelper()
        m.Proto().type = "dag"
        m.Proto().num_workers = 4

        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)

            # Branch
            fc3b = m.FC(fc2, "fc3b", dim_in=output_dim, dim_out=output_dim)
            fc4b = m.FC(fc3b, "fc4b", dim_in=output_dim, dim_out=output_dim)
            fc5b = m.FC(fc4b, "fc5b", dim_in=output_dim, dim_out=output_dim)

            fc5sum = m.Sum([fc5, fc5b], "fc5sum")

            # NOTE(review): this Relu reads fc5 and writes into fc5sum, which
            # overwrites the Sum output created just above; presumably
            # fc5sum.Relu([], fc5sum) was intended -- confirm before changing,
            # since it alters the network under test.
            fc5.Relu([], fc5sum) \
                .Softmax([], "pred1") \
                .LabelCrossEntropy(["label"], ["xent1"]) \
                .AveragedLoss([], "loss1")
            fc6 = m.FC(fc5, "fc6", dim_in=output_dim, dim_out=output_dim)
            fc6.Relu([], fc6) \
                .Softmax([], "pred2") \
                .LabelCrossEntropy(["label"], ["xent2"]) \
                .AveragedLoss([], "loss2")

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.optimize_inference_for_dag(
            m.net, ["name_x/data"], "name_x"
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)

        # Test networks produce exactly same results
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss1 = workspace.FetchBlob("name_x/loss1")
        loss2 = workspace.FetchBlob("name_x/loss2")
        workspace.RunNetOnce(optim_proto)
        optimized_loss1 = workspace.FetchBlob("name_x/loss1")
        optimized_loss2 = workspace.FetchBlob("name_x/loss2")
        np.testing.assert_almost_equal(loss1, optimized_loss1)
        np.testing.assert_almost_equal(loss2, optimized_loss2)

    def test_topological_sort_longest_path(self):
        """Compare the plain and the longest-path topological orders on a
        four-op net. The numbered comments are the op indices in the net."""
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")
        # 1
        conv0 = m.Conv("data", "conv0", 32, 32, 4)
        # 2
        m.Copy("conv2_w", "conv2_w")
        # 3
        m.Conv(conv0, "conv2", 16, 32, 4)

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [2, 0, 1, 3]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # longer path is in front of the shorter one
        orders_gt = [0, 1, 2, 3]
        self.assertEqual(orders_gt, orders)

    def test_topological_sort_longest_path_multi_target(self):
        """Same comparison as above on a net with an extra, independent
        Copy chain, i.e. more than one terminal output."""
        # two outputs: conv2 and data4
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")
        # 1
        conv0 = m.Conv("data", "conv0", 32, 32, 4)
        # 2
        m.Copy("conv2_w", "conv2_w")
        # 3
        m.Conv(conv0, "conv2", 16, 32, 4)
        # 4
        m.Copy("data1", "data2")
        # 5
        m.Copy("data2", "data3")

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [4, 5, 2, 0, 1, 3]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # longer path is in front of the shorter one
        orders_gt = [0, 1, 2, 3, 4, 5]
        self.assertEqual(orders_gt, orders)

    def test_topological_sort_longest_path_single_node(self):
        """Both traversals must return [0] for a net with a single op."""
        # single node
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [0]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # with a single node there is only one possible order
        orders_gt = [0]
        self.assertEqual(orders_gt, orders)

    def test_compute_assignments_greedy(self):
        """The greedy assignment must group b1 with b4 and b2 with b3, for
        a reported memory usage of 11."""
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 3, 10)),
            ('b2', LiveRange(3, 4, 1)),
            ('b3', LiveRange(5, 6, 1)),
            ('b4', LiveRange(5, 7, 10)),
        ]
        assignment_gt = [
            [ranges_sorted[0], ranges_sorted[3]],
            [ranges_sorted[1], ranges_sorted[2]],
        ]

        best = memonger.compute_assignments_greedy(ranges_sorted, None)
        self.assertEqual(memonger.get_memory_usage(best), 11)
        self.assertEqual(best, assignment_gt)

    def test_compute_assignments_dp(self):
        """The DP assignment must reach a memory usage of 11 on the same
        ranges used by the greedy test."""
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 3, 10)),
            ('b2', LiveRange(3, 4, 1)),
            ('b3', LiveRange(5, 6, 1)),
            ('b4', LiveRange(5, 7, 10)),
        ]

        best = memonger.compute_assignments_dp(ranges_sorted, None)
        self.assertEqual(memonger.get_memory_usage(best), 11)

    def test_compute_assignments_dp1(self):
        """The DP assignment, called with an empty list as second argument,
        must reach a memory usage of 11 on three ranges."""
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 2, 10)),
            ('b2', LiveRange(4, 6, 1)),
            ('b3', LiveRange(5, 6, 10)),
        ]

        best = memonger.compute_assignments_dp(ranges_sorted, [])
        self.assertEqual(memonger.get_memory_usage(best), 11)