mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
convnet benchmark minor change
This commit is contained in:
parent
8c1bbaa2ab
commit
809d54ee50
|
|
@ -2,14 +2,23 @@
|
|||
Benchmark for common convnets.
|
||||
|
||||
Speed on Titan X, with 10 warmup steps and 10 main steps and with different
|
||||
versions of cudnn, are as follows:
|
||||
versions of cuDNN, is as follows (time reported below is per-batch time,
|
||||
forward / forward+backward):
|
||||
|
||||
V3 v4
|
||||
CuDNN v3 CuDNN v4
|
||||
AlexNet 32.5 / 108.0 27.4 / 90.1
|
||||
OverFeat 113.0 / 342.3 91.7 / 276.5
|
||||
Inception 134.5 / 485.8 125.7 / 450.6
|
||||
VGG (batch 64) 200.8 / 650.0 164.1 / 551.7
|
||||
|
||||
Speed on Inception with varied batch sizes and CuDNN v4 is as follows:
|
||||
|
||||
Batch Size Time per batch Time per image
|
||||
16 22.8 / 72.7 1.43 / 4.54
|
||||
32 38.0 / 127.5 1.19 / 3.98
|
||||
64 67.2 / 233.6 1.05 / 3.65
|
||||
128 125.7 / 450.6 0.98 / 3.52
|
||||
|
||||
(Note that these numbers involve a "full" backprop, i.e. the gradient
|
||||
with respect to the input image is also computed.)
|
||||
|
||||
|
|
@ -28,6 +37,13 @@ PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
|||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size 64 --model VGGA
|
||||
|
||||
for BS in 16 32 64 128; do
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size $BS --model Inception --forward_only True
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size $BS --model Inception
|
||||
done
|
||||
|
||||
Note that VGG needs to be run at batch size 64 due to the memory limit on the backward
|
||||
pass.
|
||||
"""
|
||||
|
|
@ -236,7 +252,19 @@ def Inception(order):
|
|||
|
||||
def Benchmark(model_gen, arg):
|
||||
model, input_size = model_gen(arg.order)
|
||||
for op in model.net._net.op:
|
||||
|
||||
# In order to be able to run everything without feeding more stuff, let's
|
||||
# add the data and label blobs to the parameter initialization net as well.
|
||||
if arg.order == "NCHW":
|
||||
input_shape = [arg.batch_size, 3, input_size, input_size]
|
||||
else:
|
||||
input_shape = [arg.batch_size, input_size, input_size, 3]
|
||||
model.param_init_net.GaussianFill([], "data", shape=input_shape,
|
||||
mean=0.0, std=1.0)
|
||||
model.param_init_net.UniformIntFill([], "label", shape=[arg.batch_size,],
|
||||
min=0, max=999)
|
||||
|
||||
for op in model.net.Proto().op:
|
||||
if op.type == 'Conv':
|
||||
op.engine = 'CUDNN'
|
||||
#op.arg.add().CopyFrom(utils.MakeArgument('ws_nbytes_limit', arg.cudnn_limit))
|
||||
|
|
@ -256,17 +284,6 @@ def Benchmark(model_gen, arg):
|
|||
model.param_init_net.RunAllOnGPU()
|
||||
model.net.RunAllOnGPU()
|
||||
|
||||
workspace.ResetWorkspace()
|
||||
if arg.order == 'NCHW':
|
||||
data_shape = (arg.batch_size, 3, input_size, input_size)
|
||||
else:
|
||||
data_shape = (arg.batch_size, input_size, input_size, 3)
|
||||
device_option = model.net.Proto().device_option
|
||||
workspace.FeedBlob("data", np.random.randn(*data_shape).astype(np.float32),
|
||||
device_option)
|
||||
workspace.FeedBlob("label", np.asarray(range(arg.batch_size)).astype(np.int32),
|
||||
device_option)
|
||||
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
workspace.CreateNet(model.net)
|
||||
for i in range(arg.warmup_iterations):
|
||||
|
|
@ -280,7 +297,13 @@ def Benchmark(model_gen, arg):
|
|||
print 'Layer-wise benchmark.'
|
||||
workspace.BenchmarkNet(
|
||||
model.net.Proto().name, 1, arg.iterations, True)
|
||||
|
||||
# Writes out the pbtxt for benchmarks on e.g. Android
|
||||
with open("{0}_init_batch_{1}.pbtxt".format(arg.model, arg.batch_size),
|
||||
"w") as fid:
|
||||
fid.write(str(model.param_init_net.Proto()))
|
||||
with open("{0}.pbtxt".format(arg.model, arg.batch_size),
|
||||
"w") as fid:
|
||||
fid.write(str(model.net.Proto()))
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Caffe2 benchmark.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user