Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Added optional support for sharing activation blobs as well.

Making this change revealed a non-optimal implementation in the blob sharing: when reusing free blobs, we need to prefer blobs that are already shared by many other blobs. Otherwise, memory usage can increase as the pool of "free blobs" grows. Also, my first version passed "free blobs" (i.e., blobs in the recycling pool) only down the first branch when operators forked. Now the blobs that were not used by the first branch are passed down the second branch, and so on.

Also added support for blob size information in the heuristic, using the shape inference mechanism. This required some small tweaks:
- use the Sum() operator as a way to match shapes of blobs that otherwise had unknown shapes. This is related to the Sum() operator that is added to combine multiple incoming gradient inputs (with _autosplit gradients).
- a couple of random shape inference fixes

This reduces Resnet-50 memory usage at batch size 64 from 9.45 GiB to 8.5 GiB. At batch size 32, memory usage is 4330 MiB, down from 4800 MB, compared to Torch's 6856 MiB (thanks prigoyal for checking this for me). This is unfortunately quite a lot to review...

Reviewed By: asaadaldien

Differential Revision: D4393909

fbshipit-source-id: 9c7c94125f96512bea80463ebcb63c215ef95ff9
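Below is a minimal sketch of the reuse preference described above; pick_free_blob, share_counts, and blob_sizes are illustrative names for this sketch, not the actual memonger API. When several recycled blobs could back a new tensor, it favors the blob that already backs the most other blobs, and breaks ties by how closely the blob's known size matches the requested size, so the pool of free blobs stays small:

    def pick_free_blob(free_blobs, share_counts, blob_sizes, needed_size):
        # Prefer blobs that are already shared by many others; tie-break
        # on how closely the blob's known size matches the request.
        def preference(blob):
            return (-share_counts.get(blob, 0),
                    abs(blob_sizes.get(blob, 0) - needed_size))
        return min(free_blobs, key=preference) if free_blobs else None

    # For example, with two equally sized candidates the more-shared one wins:
    # pick_free_blob(["a", "b"], {"a": 3, "b": 1}, {"a": 64, "b": 64}, 64) -> "a"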
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from caffe2.python import workspace, cnn, memonger, core
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
from hypothesis import given


def has_blob(proto, needle):
    """Return True if any op in the net proto reads or writes `needle`."""
    for op in proto.op:
        for inp in op.input:
            if inp == needle:
                return True
        for outp in op.output:
            if outp == needle:
                return True
    return False


def count_blobs(proto):
    """Count the distinct blobs that appear as op inputs or outputs."""
    blobs = set()
    for op in proto.op:
        blobs = blobs.union(set(op.input)).union(set(op.output))
    return len(blobs)


class MemongerTest(hu.HypothesisTestCase):
    @given(input_dim=st.integers(min_value=1, max_value=10),
           output_dim=st.integers(min_value=1, max_value=10),
           batch_size=st.integers(min_value=1, max_value=10),
           do=st.sampled_from(hu.device_options),
           algo=st.sampled_from(memonger.AssignmentAlgorithm))
    def test_simple_memonger(self, input_dim, output_dim, batch_size, do, algo):
        m = cnn.CNNModelHelper()
        fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
        fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
        fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)

        fc3.Relu([], fc3)\
            .Softmax([], "pred") \
            .LabelCrossEntropy(["label"], ["xent"]) \
            .AveragedLoss([], "loss")
        input_to_grad = m.AddGradientOperators(["loss"])
        m.net.Proto().device_option.CopyFrom(do)
        m.param_init_net.Proto().device_option.CopyFrom(do)
        static_blobs = \
            [o for op in m.param_init_net.Proto().op for o in op.output] + \
            ["data", "label", "loss", input_to_grad["fc1_w"]]

        optimization = memonger.optimize_interference(
            m.Proto(), static_blobs, algo=algo)
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("data", data, device_option=do)
        workspace.FeedBlob("label", label, device_option=do)
        workspace.RunNetOnce(m.net)
        loss = workspace.FetchBlob("loss")
        grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        workspace.RunNetOnce(optimization.net)
        optimized_loss = workspace.FetchBlob("loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)
        stats = memonger.compute_statistics(optimization.assignments)
        self.assertLess(stats.optimized_nbytes, stats.baseline_nbytes)

        # run with blob sizes
        blob_sizes = memonger.collect_blob_sizes(m.Proto())
        optimization1 = memonger.optimize_interference(
            m.Proto(), static_blobs, blob_sizes=blob_sizes, algo=algo)
        workspace.RunNetOnce(optimization1.net)
        optimized_loss = workspace.FetchBlob("loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)
        stats = memonger.compute_statistics(optimization1.assignments)
        self.assertLessEqual(stats.optimized_nbytes, stats.baseline_nbytes)

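    # The gradient tests below run memonger.share_grad_blobs with and
    # without activation sharing and verify that the optimized nets
    # produce the same losses and gradients as the original net.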
    @given(input_dim=st.integers(min_value=1, max_value=4),
           output_dim=st.integers(min_value=1, max_value=4),
           batch_size=st.integers(min_value=1, max_value=4))
    def test_gradient_optim(self, input_dim, output_dim, batch_size):
        m = cnn.CNNModelHelper()
        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)
            fc5.Relu([], fc5)\
                .Softmax([], "pred") \
                .LabelCrossEntropy(["label"], ["xent"]) \
                .AveragedLoss([], "loss")
        input_to_grad = m.AddGradientOperators(["name_x/loss"])

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss"],
            set(m.param_to_grad.values()),
            "name_x/",
            share_activations=False,
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)

        optim_proto_wacts = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss"],
            set(m.param_to_grad.values()),
            "name_x/",
            share_activations=True,
        )
        blobs_wact_optim = count_blobs(optim_proto_wacts)
        self.assertLessEqual(blobs_wact_optim, blobs_after)

        # Check that the last activations are not shared
        self.assertTrue(has_blob(optim_proto, "name_x/fc5"))
        self.assertTrue(
            has_blob(optim_proto_wacts, "name_x/fc5"),
            "Don't remap final activation",
        )

        # Test networks produce exactly same gradients
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss = workspace.FetchBlob("name_x/loss")
        grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        workspace.RunNetOnce(optim_proto)
        optimized_loss = workspace.FetchBlob("name_x/loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)

        # Run with the forward optimization
        workspace.RunNetOnce(optim_proto_wacts)
        optimized_loss = workspace.FetchBlob("name_x/loss")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss, optimized_loss)
        np.testing.assert_almost_equal(grad, optimized_grad)

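    # Same check on a net with two losses; dont_share_blobs pins fc5/fc6,
    # so those activations must survive the optimization untouched.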
    @given(input_dim=st.integers(min_value=4, max_value=4),
           output_dim=st.integers(min_value=4, max_value=4),
           batch_size=st.integers(min_value=4, max_value=4))
    def test_gradient_optim_tree(self, input_dim, output_dim, batch_size):
        m = cnn.CNNModelHelper()
        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)
            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)
            fc5.Relu([], fc5) \
                .Softmax([], "pred1") \
                .LabelCrossEntropy(["label"], ["xent1"]) \
                .AveragedLoss([], "loss1")
            fc6 = m.FC(fc5, "fc6", dim_in=output_dim, dim_out=output_dim)
            fc6.Relu([], fc6) \
                .Softmax([], "pred2") \
                .LabelCrossEntropy(["label"], ["xent2"]) \
                .AveragedLoss([], "loss2")
        input_to_grad = m.AddGradientOperators(["name_x/loss1", "name_x/loss2"])

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.share_grad_blobs(
            m.net,
            ["name_x/loss1", "name_x/loss2"],
            set(m.param_to_grad.values()),
            "name_x",  # "name_x//shared_gradinp_0_shared" if using "name_x/"
            share_activations=True,
            dont_share_blobs=set(['name_x/fc6', 'name_x/fc5']),
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)
        self.assertTrue(has_blob(optim_proto, "name_x/fc6"))

        # Test networks produce exactly same gradients
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss1 = workspace.FetchBlob("name_x/loss1")
        loss2 = workspace.FetchBlob("name_x/loss2")
        grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        workspace.RunNetOnce(optim_proto)
        optimized_loss1 = workspace.FetchBlob("name_x/loss1")
        optimized_loss2 = workspace.FetchBlob("name_x/loss2")
        optimized_grad = workspace.FetchBlob(str(input_to_grad["name_x/fc1_w"]))
        np.testing.assert_almost_equal(loss1, optimized_loss1)
        np.testing.assert_almost_equal(loss2, optimized_loss2)
        np.testing.assert_almost_equal(grad, optimized_grad)

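    # The test below builds a net that forks at fc2 and rejoins via Sum,
    # exercising the inference memonger on a DAG rather than a plain chain.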
    @given(input_dim=st.integers(min_value=4, max_value=4),
           output_dim=st.integers(min_value=4, max_value=4),
           batch_size=st.integers(min_value=4, max_value=4))
    def test_forward_optim_tree_daggy(self, input_dim, output_dim, batch_size):
        m = cnn.CNNModelHelper()
        m.Proto().type = "dag"
        m.Proto().num_workers = 4

        with core.NameScope("name_x"):
            fc1 = m.FC("data", "fc1", dim_in=input_dim, dim_out=output_dim)
            fc2 = m.FC(fc1, "fc2", dim_in=output_dim, dim_out=output_dim)

            fc3 = m.FC(fc2, "fc3", dim_in=output_dim, dim_out=output_dim)
            fc4 = m.FC(fc3, "fc4", dim_in=output_dim, dim_out=output_dim)
            fc5 = m.FC(fc4, "fc5", dim_in=output_dim, dim_out=output_dim)

            # Branch
            fc3b = m.FC(fc2, "fc3b", dim_in=output_dim, dim_out=output_dim)
            fc4b = m.FC(fc3b, "fc4b", dim_in=output_dim, dim_out=output_dim)
            fc5b = m.FC(fc4b, "fc5b", dim_in=output_dim, dim_out=output_dim)

            fc5sum = m.Sum([fc5, fc5b], "fc5sum")

            fc5.Relu([], fc5sum) \
                .Softmax([], "pred1") \
                .LabelCrossEntropy(["label"], ["xent1"]) \
                .AveragedLoss([], "loss1")
            fc6 = m.FC(fc5, "fc6", dim_in=output_dim, dim_out=output_dim)
            fc6.Relu([], fc6) \
                .Softmax([], "pred2") \
                .LabelCrossEntropy(["label"], ["xent2"]) \
                .AveragedLoss([], "loss2")

        blobs_before = count_blobs(m.net.Proto())
        optim_proto = memonger.optimize_inference_for_dag(
            m.net, ["name_x/data"], "name_x"
        )
        blobs_after = count_blobs(optim_proto)
        self.assertLess(blobs_after, blobs_before)

        # Test networks produce exactly same results
        data = np.random.randn(batch_size, input_dim).astype(np.float32)
        label = np.random.randint(
            low=0, high=output_dim, size=(batch_size,)).astype(np.int32)
        workspace.RunNetOnce(m.param_init_net)
        workspace.FeedBlob("name_x/data", data)
        workspace.FeedBlob("name_x/label", label)
        workspace.RunNetOnce(m.net)
        loss1 = workspace.FetchBlob("name_x/loss1")
        loss2 = workspace.FetchBlob("name_x/loss2")
        workspace.RunNetOnce(optim_proto)
        optimized_loss1 = workspace.FetchBlob("name_x/loss1")
        optimized_loss2 = workspace.FetchBlob("name_x/loss2")
        np.testing.assert_almost_equal(loss1, optimized_loss1)
        np.testing.assert_almost_equal(loss2, optimized_loss2)

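    # The topological-sort tests below index ops 0..N in creation order
    # (see the numbered comments) and check that the longest-path traversal
    # schedules the long conv chain ahead of unrelated short paths.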
    def test_topological_sort_longest_path(self):
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")
        # 1
        conv0 = m.Conv("data", "conv0", 32, 32, 4)
        # 2
        m.Copy("conv2_w", "conv2_w")
        # 3
        m.Conv(conv0, "conv2", 16, 32, 4)

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [2, 0, 1, 3]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # longer path is in front of the shorter one
        orders_gt = [0, 1, 2, 3]
        self.assertEqual(orders_gt, orders)

    def test_topological_sort_longest_path_multi_target(self):
        # two outputs: conv2 and data3
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")
        # 1
        conv0 = m.Conv("data", "conv0", 32, 32, 4)
        # 2
        m.Copy("conv2_w", "conv2_w")
        # 3
        m.Conv(conv0, "conv2", 16, 32, 4)
        # 4
        m.Copy("data1", "data2")
        # 5
        m.Copy("data2", "data3")

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [4, 5, 2, 0, 1, 3]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # longer path is in front of the shorter one
        orders_gt = [0, 1, 2, 3, 4, 5]
        self.assertEqual(orders_gt, orders)

    def test_topological_sort_longest_path_single_node(self):
        # single node
        m = cnn.CNNModelHelper()
        # 0
        m.Copy("conv0_w_comp", "conv0_w")

        g = memonger.compute_interference_graph(m.net.Proto().op)

        orders_org = memonger.topological_sort_traversal(g)
        orders_gt_org = [0]
        self.assertEqual(orders_gt_org, orders_org)

        orders = memonger.topological_sort_traversal_longest_path(g)
        # longer path is in front of the shorter one
        orders_gt = [0]
        self.assertEqual(orders_gt, orders)

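    # The assignment tests below use synthetic live ranges, constructed as
    # LiveRange(range start, range end, blob size). Blobs whose ranges do
    # not overlap may share one allocation: in the first two tests b1
    # (size 10) pairs with b4 (size 10) and b2 (size 1) with b3 (size 1),
    # so the expected total memory usage is 10 + 1 = 11.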
    def test_compute_assignments_greedy(self):
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 3, 10)),
            ('b2', LiveRange(3, 4, 1)),
            ('b3', LiveRange(5, 6, 1)),
            ('b4', LiveRange(5, 7, 10)),
        ]
        assignment_gt = [
            [ranges_sorted[0], ranges_sorted[3]],
            [ranges_sorted[1], ranges_sorted[2]],
        ]

        best = memonger.compute_assignments_greedy(ranges_sorted, None)
        self.assertEqual(memonger.get_memory_usage(best), 11)
        self.assertEqual(best, assignment_gt)

    def test_compute_assignments_dp(self):
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 3, 10)),
            ('b2', LiveRange(3, 4, 1)),
            ('b3', LiveRange(5, 6, 1)),
            ('b4', LiveRange(5, 7, 10)),
        ]

        best = memonger.compute_assignments_dp(ranges_sorted, None)
        self.assertEqual(memonger.get_memory_usage(best), 11)

    def test_compute_assignments_dp1(self):
        LiveRange = memonger.LiveRange
        ranges_sorted = [
            ('b1', LiveRange(1, 2, 10)),
            ('b2', LiveRange(4, 6, 1)),
            ('b3', LiveRange(5, 6, 10)),
        ]

        best = memonger.compute_assignments_dp(ranges_sorted, [])
        self.assertEqual(memonger.get_memory_usage(best), 11)