pytorch/caffe2/python/cnn.py
Yiming Wu 0bb558716a rename model_helpers to brew and lowercase all helper functions
Summary:
Rename model_helpers to brew. This is a big diff; here is what I did:

1. replace model_helpers with brew:

  find . -type f -exec sed -i 's/model_helpers/brew/g' {} +

2. rename model_helpers.py and model_helpers_test.py
3. rename ModelHelpersTest to BrewTest
4. lowercase all the helper functions to distinguish them from single ops
5. run my unittests
6. run converge tests

Reviewed By: salexspb

Differential Revision: D4930465

fbshipit-source-id: f420a1b03238df1cbe9f4426e0b9c43a12119661
2017-04-24 15:52:26 -07:00


## @package cnn
# Module caffe2.python.cnn
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import scope, brew
from caffe2.python.model_helper import ModelHelper
from caffe2.proto import caffe2_pb2


class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit
        if self.order != "NHWC" and self.order != "NCHW":
            raise ValueError(
                "Cannot understand the CNN storage order %s." % self.order
            )

    def GetWeights(self, namescope=None):
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.weights[:]
        else:
            return [w for w in self.weights if w.GetNameScope() == namescope]

    def GetBiases(self, namescope=None):
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.biases[:]
        else:
            return [b for b in self.biases if b.GetNameScope() == namescope]
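
    # Illustrative namescope sketch (scope.NameScope comes from
    # caffe2.python.scope, imported above; the blob names and the
    # trailing-slash namescope string are assumptions):
    #
    #   with scope.NameScope("tower_0"):
    #       conv1 = model.Conv("data", "conv1", dim_in=3, dim_out=16, kernel=3)
    #   tower_weights = model.GetWeights("tower_0/")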

    def ImageInput(
        self, blob_in, blob_out, use_gpu_transform=False, **kwargs
    ):
        """Image Input."""
        if self.order == "NCHW":
            if (use_gpu_transform):
                kwargs['use_gpu_transform'] = 1 if use_gpu_transform else 0
                # GPU transform will handle NHWC -> NCHW
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0], blob_out[1]], **kwargs)
                # data = self.net.Transform(data, blob_out[0], **kwargs)
                pass
            else:
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
                data = self.net.NHWC2NCHW(data, blob_out[0])
        else:
            data, label = self.net.ImageInput(
                blob_in, blob_out, **kwargs)
        return data, label
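
    # Illustrative input-pipeline sketch. The reader creation and the
    # ImageInput arguments (batch_size, is_test) are assumptions, not taken
    # from this file:
    #
    #   reader = model.CreateDB("reader", db="/path/to/train.lmdb",
    #                           db_type="lmdb")
    #   data, label = model.ImageInput(
    #       [reader], ["data", "label"], batch_size=32, is_test=False)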

    def PadImage(self, blob_in, blob_out, **kwargs):
        self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        return brew.dropout(self, *args, **kwargs)

    def LRN(self, *args, **kwargs):
        return brew.lrn(self, *args, **kwargs)

    def Softmax(self, *args, **kwargs):
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        return brew.relu(self, *args, order=self.order,
                         use_cudnn=self.use_cudnn, **kwargs)

    def PRelu(self, *args, **kwargs):
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        return brew.max_pool(self, *args, use_cudnn=self.use_cudnn,
                             order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        return brew.average_pool(self, *args, use_cudnn=self.use_cudnn,
                                 order=self.order, **kwargs)

    @property
    def XavierInit(self):
        return ('XavierFill', {})

    def ConstantInit(self, value):
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        return ('ConstantFill', {})
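
    # These properties return (fill_op_name, kwargs) pairs. Illustrative
    # sketch of feeding them to a helper; the weight_init/bias_init parameter
    # names are assumptions based on the brew helper signatures:
    #
    #   fc1 = model.FC(pool1, "fc1", dim_in=256, dim_out=10,
    #                  weight_init=model.XavierInit,
    #                  bias_init=model.ConstantInit(0.0))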

    def AddWeightDecay(self, weight_decay):
        """Adds a decay to weights in the model.

        This is a form of L2 regularization.

        Args:
            weight_decay: strength of the regularization
        """
        if weight_decay <= 0.0:
            return
        wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
                                              value=weight_decay)
        ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
        for param in self.GetWeights():
            # Equivalent to: grad += wd * param
            grad = self.param_to_grad[param]
            self.net.WeightedSum(
                [grad, ONE, param, wd],
                grad,
            )
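
    # Illustrative call-order sketch: self.param_to_grad is only populated
    # once gradient operators have been added, so weight decay comes after
    # that step. The loss blob name below is an assumption.
    #
    #   model.AddGradientOperators(["loss"])
    #   model.AddWeightDecay(1e-4)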

    @property
    def CPU(self):
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = gpu_id
        return device_option
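
    # Illustrative device-scope sketch: core.DeviceScope comes from
    # caffe2.python.core, which this file does not import, so the import is
    # assumed. Note that because GPU is a property it is always accessed
    # without arguments, so gpu_id stays at its default of 0.
    #
    #   with core.DeviceScope(model.GPU):
    #       pool5 = model.AveragePool(relu5, "pool5", kernel=7, stride=1)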