mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Rename model_helpers to brew. This is a large diff; I did the following:
1. replace model_helpers with brew:
find . -type f -exec sed -i 's/model_helpers/brew/g' {} +
2. rename model_helpers.py and model_helpers_test.py
3. rename ModelHelpersTest to BrewTest
4. lowercase all the helper functions to distinguish them from single op
5. run my unittests
6. run converge tests
Reviewed By: salexspb
Differential Revision: D4930465
fbshipit-source-id: f420a1b03238df1cbe9f4426e0b9c43a12119661
253 lines
8.2 KiB
Python
253 lines
8.2 KiB
Python
## @package cnn
|
|
# Module caffe2.python.cnn
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import scope, brew
|
|
from caffe2.python.model_helper import ModelHelper
|
|
from caffe2.proto import caffe2_pb2
|
|
|
|
|
|
class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Thin convenience wrapper around ModelHelper: most methods forward to the
    corresponding lowercase helper in caffe2.python.brew, filling in the
    model-wide settings (storage order, cuDNN options) captured at
    construction time.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create a CNN helper model.

        Args:
            order: image tensor storage order, "NCHW" or "NHWC".
            name: model name; defaults to "CNN" when None.
            use_cudnn: forwarded to the conv/pool/etc. brew helpers.
            cudnn_exhaustive_search: forwarded to the conv brew helpers.
            ws_nbytes_limit: cuDNN workspace byte limit forwarded to the conv
                brew helpers (None means no explicit limit).
            init_params: whether parameter-initialization ops are created
                (forwarded to ModelHelper).
            skip_sparse_optim: forwarded to ModelHelper.
            param_model: optional model to share parameters with
                (forwarded to ModelHelper).

        Raises:
            ValueError: if `order` is neither "NCHW" nor "NHWC".
        """
        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit
        if self.order != "NHWC" and self.order != "NCHW":
            raise ValueError(
                "Cannot understand the CNN storage order %s." % self.order
            )

    def GetWeights(self, namescope=None):
        """Return weight blobs, optionally filtered to one name scope.

        With an empty scope (explicit '' or an empty current scope) a shallow
        copy of all weights is returned.
        """
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.weights[:]
        else:
            return [w for w in self.weights if w.GetNameScope() == namescope]

    def GetBiases(self, namescope=None):
        """Return bias blobs, optionally filtered to one name scope.

        With an empty scope (explicit '' or an empty current scope) a shallow
        copy of all biases is returned.
        """
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.biases[:]
        else:
            return [b for b in self.biases if b.GetNameScope() == namescope]

    def ImageInput(
        self, blob_in, blob_out, use_gpu_transform=False, **kwargs
    ):
        """Image Input.

        Reads a (data, label) pair from `blob_in`. For "NCHW" order the
        decoded image is brought into NCHW layout either by the op itself on
        the GPU (`use_gpu_transform`) or by an explicit NHWC2NCHW op.

        Returns:
            (data, label) blob pair.
        """
        if self.order == "NCHW":
            if use_gpu_transform:
                # The guard guarantees the flag is truthy, so the op arg is
                # always 1 here (the op expects an int, not a bool).
                kwargs['use_gpu_transform'] = 1
                # GPU transform will handle NHWC -> NCHW.
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0], blob_out[1]], **kwargs)
            else:
                # Read as NHWC into a temporary, then transpose to NCHW.
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
                data = self.net.NHWC2NCHW(data, blob_out[0])
        else:
            data, label = self.net.ImageInput(
                blob_in, blob_out, **kwargs)
        return data, label

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Pad an image blob via the PadImage op and return its output blob."""
        # Return the op result for consistency with the other wrappers
        # (previously the output blob reference was silently discarded).
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """N-dimensional convolution; injects the model-wide cuDNN/order settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """2D convolution; injects the model-wide cuDNN/order settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Transposed convolution; injects the model-wide cuDNN/order settings."""
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Grouped convolution; injects the model-wide cuDNN/order settings."""
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Deprecated grouped convolution; prefer GroupConv."""
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Fully connected layer (brew.fc)."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Packed fully connected layer (brew.packed_fc)."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Fully connected layer with pruning (brew.fc_prune)."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Fully connected layer with decomposition (brew.fc_decomp)."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Sparse fully connected layer (brew.fc_sparse)."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Dropout (brew.dropout)."""
        return brew.dropout(self, *args, **kwargs)

    def LRN(self, *args, **kwargs):
        """Local response normalization (brew.lrn)."""
        return brew.lrn(self, *args, **kwargs)

    def Softmax(self, *args, **kwargs):
        """Softmax; injects the model-wide cuDNN setting."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn,
                            **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Spatial batch normalization; injects the model-wide storage order."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Instance normalization; injects the model-wide storage order."""
        return brew.instance_norm(self, *args, order=self.order,
                                  **kwargs)

    def Relu(self, *args, **kwargs):
        """ReLU activation; injects the model-wide order/cuDNN settings."""
        return brew.relu(self, *args, order=self.order,
                         use_cudnn=self.use_cudnn, **kwargs)

    def PRelu(self, *args, **kwargs):
        """Parametric ReLU activation (brew.prelu)."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Concatenation; injects the model-wide storage order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Elementwise sum (brew.sum)."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Transpose; injects the model-wide cuDNN setting."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn,
                              **kwargs)

    def Iter(self, *args, **kwargs):
        """Iteration counter (brew.iter)."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Accuracy metric (brew.accuracy)."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Max pooling; injects the model-wide cuDNN/order settings."""
        return brew.max_pool(self, *args, use_cudnn=self.use_cudnn,
                             order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Average pooling; injects the model-wide cuDNN/order settings."""
        return brew.average_pool(self, *args, use_cudnn=self.use_cudnn,
                                 order=self.order, **kwargs)

    @property
    def XavierInit(self):
        """(op_name, kwargs) pair selecting Xavier weight initialization."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """(op_name, kwargs) pair selecting constant initialization to `value`."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """(op_name, kwargs) pair selecting MSRA weight initialization."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """(op_name, kwargs) pair selecting zero initialization."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Adds a decay to weights in the model.

        This is a form of L2 regularization.

        Args:
            weight_decay: strength of the regularization; values <= 0.0 are
                a no-op.
        """
        if weight_decay <= 0.0:
            return
        wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
                                              value=weight_decay)
        ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
        for param in self.GetWeights():
            # Equivalent to: grad += wd * param
            grad = self.param_to_grad[param]
            self.net.WeightedSum(
                [grad, ONE, param, wd],
                grad,
            )

    @property
    def CPU(self):
        """DeviceOption selecting the CPU device."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """DeviceOption selecting CUDA GPU 0.

        NOTE(review): this is a @property, so callers can never supply
        `gpu_id` — it is always 0. Kept as-is for interface compatibility;
        build a DeviceOption directly to target another GPU.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = gpu_id
        return device_option