Summary: prigoyal sharply noticed a bug in the ResNet models: we have not been checkpointing, nor synchronizing between GPUs, the moving average and variance computed by the SpatialBN ops. The first problem in particular is serious, since models starting from a checkpoint would have started from a null state for SpatialBN. Not synchronizing with the data parallel model is less tragic, since each GPU should see very similar data.

Thus I propose keeping track of "computed params", i.e. params that are computed from data but not optimized. I don't know if there are other examples, but SpatialBN's moving avg and var definitely are one.

- I modified the checkpointing for the xray model to store those blobs + also ensure the synchronization of those blobs
- I modified data parallel model to broadcast those params from gpu0. I first tried averaging, but hit some NCCL deadlocks ... :(

Differential Revision: D4281265
fbshipit-source-id: 933311afeec4b7e9344a13cf2d38aa939c50ac31
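For illustration, a minimal sketch of the broadcast step described above, assuming Caffe2's workspace API and the 'gpu_N/' blob-name prefixes used by data_parallel_model (the helper name and the '_rm'/'_riv' suffixes for SpatialBN's running mean/variance are assumptions here, and device options are omitted for brevity):

    from caffe2.python import workspace

    def broadcast_computed_params(computed_param_names, num_gpus):
        # Copy each computed (i.e. non-optimized) blob from gpu_0 to the
        # other GPUs, e.g. a SpatialBN op's running mean ('_rm') and
        # running variance ('_riv') blobs. Names/suffixes are assumptions.
        for name in computed_param_names:
            value = workspace.FetchBlob('gpu_0/' + name)
            for gpu in range(1, num_gpus):
                workspace.FeedBlob('gpu_%d/' % gpu + name, value)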
269 lines
8.4 KiB
Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

'''
Utility for creating ResNets

See "Deep Residual Learning for Image Recognition" by He, Zhang, Ren,
and Sun (2015)
'''

class ResNetBuilder():
    '''
    Helper class for constructing residual blocks.
    '''
    def __init__(self, model, prev_blob, is_test):
        self.model = model
        # Number of completed high-level components (blocks) so far; used
        # to give each component's blobs a unique name prefix
        self.comp_count = 0
        # Index of the current op within the component being built
        self.comp_idx = 0
        self.prev_blob = prev_blob
        self.is_test = is_test

    def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0):
        self.comp_idx += 1
        self.prev_blob = self.model.Conv(
            self.prev_blob,
            'comp_%d_conv_%d' % (self.comp_count, self.comp_idx),
            in_filters,
            out_filters,
            weight_init=("MSRAFill", {}),
            kernel=kernel,
            stride=stride,
            pad=pad
        )
        return self.prev_blob

    def add_relu(self):
        self.prev_blob = self.model.Relu(
            self.prev_blob,
            'comp_%d_relu_%d' % (self.comp_count, self.comp_idx)
        )
        return self.prev_blob

    def add_spatial_bn(self, num_filters):
        self.prev_blob = self.model.SpatialBN(
            self.prev_blob,
            'comp_%d_spatbn_%d' % (self.comp_count, self.comp_idx),
            num_filters,
            epsilon=1e-3,
            is_test=self.is_test,
        )
        return self.prev_blob
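
    # Blob naming used by the helpers above: each op output is named
    # 'comp_<comp_count>_<optype>_<comp_idx>', e.g. 'comp_0_conv_1' or
    # 'comp_3_spatbn_2', so blobs stay unique across components.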

    def add_bottleneck(
        self,
        input_filters,   # num of feature maps from preceding layer
        base_filters,    # num of filters internally in the component
        output_filters,  # num of feature maps to output
        down_sampling=False,
        spatial_batch_norm=True,
    ):
        '''
        Add a "bottleneck" component as described in He et al.,
        Figure 5 (right)
        '''
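        # Path built by the calls below:
        #   1x1 conv -> BN -> ReLU -> 3x3 conv (stride 2 if down_sampling)
        #   -> BN -> ReLU -> 1x1 conv -> BN; output = ReLU(path + shortcut)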
        self.comp_idx = 0
        shortcut_blob = self.prev_blob

        # 1x1
        self.add_conv(
            input_filters,
            base_filters,
            kernel=1,
            stride=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)

        self.add_relu()

        # 3x3 (note the pad, required for keeping dimensions)
        self.add_conv(
            base_filters,
            base_filters,
            kernel=3,
            stride=(1 if down_sampling is False else 2),
            pad=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)
        self.add_relu()

        # 1x1
        last_conv = self.add_conv(base_filters, output_filters, kernel=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(output_filters)

        # Summation with input signal (shortcut)
        # If we need to increase dimensions (feature maps), need to
        # do a projection for the shortcut
        if (output_filters > input_filters):
            shortcut_blob = self.model.Conv(
                shortcut_blob,
                'shortcut_projection_%d' % self.comp_count,
                input_filters,
                output_filters,
                weight_init=("MSRAFill", {}),
                kernel=1,
                stride=(1 if down_sampling is False else 2)
            )
            if spatial_batch_norm:
                shortcut_blob = self.model.SpatialBN(
                    shortcut_blob,
                    'shortcut_projection_%d_spatbn' % self.comp_count,
                    output_filters,
                    epsilon=1e-3,
                    is_test=self.is_test,
                )

        self.prev_blob = self.model.Sum(
            [shortcut_blob, last_conv],
            'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
        )
        self.comp_idx += 1
        self.add_relu()

        # Keep track of number of high-level components in this ResNetBuilder
        self.comp_count += 1

    def add_simple_block(
        self,
        input_filters,
        num_filters,
        down_sampling=False,
        spatial_batch_norm=True
    ):
        '''
        Add a simple residual block: two 3x3 convolutions joined by
        a shortcut.
        '''
        self.comp_idx = 0
        shortcut_blob = self.prev_blob

        # 3x3
        self.add_conv(
            input_filters,
            num_filters,
            kernel=3,
            stride=(1 if down_sampling is False else 2),
            pad=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(num_filters)
        self.add_relu()

        last_conv = self.add_conv(num_filters, num_filters, kernel=3, pad=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(num_filters)

        # Increase of dimensions, need a projection for the shortcut
        if (num_filters != input_filters):
            shortcut_blob = self.model.Conv(
                shortcut_blob,
                'shortcut_projection_%d' % self.comp_count,
                input_filters,
                num_filters,
                weight_init=("MSRAFill", {}),
                kernel=1,
                stride=(1 if down_sampling is False else 2),
            )
            if spatial_batch_norm:
                shortcut_blob = self.model.SpatialBN(
                    shortcut_blob,
                    'shortcut_projection_%d_spatbn' % self.comp_count,
                    num_filters,
                    epsilon=1e-3,
                    is_test=self.is_test,
                )

        self.prev_blob = self.model.Sum(
            [shortcut_blob, last_conv],
            'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
        )
        self.comp_idx += 1
        self.add_relu()

        # Keep track of number of high-level components in this ResNetBuilder
        self.comp_count += 1

def create_resnet50(model, data, num_input_channels,
                    num_labels, label=None, is_test=False):
    # conv1 + maxpool
    model.Conv(
        data,
        'conv1',
        num_input_channels,
        64,
        weight_init=("MSRAFill", {}),
        kernel=7,
        stride=2,
        pad=3
    )
    model.SpatialBN('conv1', 'conv1_spatbn', 64, epsilon=1e-3, is_test=is_test)
    model.Relu('conv1_spatbn', 'relu1')
    model.MaxPool('relu1', 'pool1', kernel=3, stride=2)

    # Residual blocks...
    builder = ResNetBuilder(model, 'pool1', is_test=is_test)

    # conv2_x (ref Table 1 in He et al. (2015))
    builder.add_bottleneck(64, 64, 256)
    builder.add_bottleneck(256, 64, 256)
    builder.add_bottleneck(256, 64, 256)

    # conv3_x
    builder.add_bottleneck(256, 128, 512, down_sampling=True)
    for i in range(1, 4):
        builder.add_bottleneck(512, 128, 512)

    # conv4_x
    builder.add_bottleneck(512, 256, 1024, down_sampling=True)
    for i in range(1, 6):
        builder.add_bottleneck(1024, 256, 1024)

    # conv5_x
    builder.add_bottleneck(1024, 512, 2048, down_sampling=True)
    builder.add_bottleneck(2048, 512, 2048)
    builder.add_bottleneck(2048, 512, 2048)

    # Final layers: by now the spatial extent of the "image" is reduced
    # to 7x7, so a 7x7 average pool yields one 2048-d feature per image
    model.AveragePool(builder.prev_blob, 'final_avg', kernel=7, stride=1)
    model.FC('final_avg', 'pred', 2048, num_labels)

    # If we create the model for training, use softmax-with-loss
    if (label is not None):
        (softmax, loss) = model.SoftmaxWithLoss(
            ["pred", label],
            ["softmax", "loss"],
        )
        return (softmax, loss)
    else:
        # For inference, we just return softmax
        return model.Softmax("pred", "softmax")
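
# Example usage (an illustrative sketch, not part of the original file;
# assumes Caffe2's CNNModelHelper and an already-populated NCHW 'data' blob):
#
#   from caffe2.python import cnn
#   model = cnn.CNNModelHelper(order="NCHW", name="resnet50")
#   softmax = create_resnet50(model, 'data', num_input_channels=3,
#                             num_labels=1000, is_test=True)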

def create_resnet_32x32(
    model, data, num_input_channels, num_groups, num_labels, is_test=False
):
    '''
    Create residual net for smaller images (sec 4.2 of He et al. (2015))
    num_groups = 'n' in the paper
    '''
    # conv1
    model.Conv(data, 'conv1', num_input_channels, 16, kernel=3, stride=1)
    model.SpatialBN('conv1', 'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test)
    model.Relu('conv1_spatbn', 'relu1')

    # Number of blocks as described in sec 4.2
    filters = [16, 32, 64]

    builder = ResNetBuilder(model, 'relu1', is_test=is_test)
    prev_filters = 16
    for groupidx in range(0, 3):
        for blockidx in range(0, 2 * num_groups):
            builder.add_simple_block(
                prev_filters if blockidx == 0 else filters[groupidx],
                filters[groupidx],
                down_sampling=(True if blockidx == 0 and
                               groupidx > 0 else False))
        prev_filters = filters[groupidx]

    # Final layers
    model.AveragePool(builder.prev_blob, 'final_avg', kernel=8, stride=1)
    model.FC('final_avg', 'pred', 64, num_labels)
    softmax = model.Softmax('pred', 'softmax')
    return softmax
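
# Example usage (an illustrative sketch, not part of the original file;
# a CIFAR-10-style setup with 3x32x32 inputs, reusing the imports above):
#
#   model = cnn.CNNModelHelper(order="NCHW", name="resnet_cifar")
#   softmax = create_resnet_32x32(model, 'data', num_input_channels=3,
#                                 num_groups=5, num_labels=10, is_test=True)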