mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add Algebra and train helpers and proxy them to CNNMH
Summary: Add Algebra and train helpers and proxy them to CNNMH Reviewed By: salexspb Differential Revision: D4855040 fbshipit-source-id: d948ea913f674a6e47c4b72629a2d33253cb3130
This commit is contained in:
parent
b2e94a7bcb
commit
8de1ce57d2
|
|
@ -386,6 +386,20 @@ class CNNModelHelper(ModelHelperBase):
|
|||
def DepthConcat(self, *args, **kwargs):
|
||||
return model_helpers.DepthConcat(self, *args, **kwargs)
|
||||
|
||||
def Sum(self, *args, **kwargs):
|
||||
return model_helpers.Sum(self, *args, **kwargs)
|
||||
|
||||
def Transpose(self, *args, **kwargs):
|
||||
return model_helpers.Transpose(self, *args, order=self.order,
|
||||
use_cudnn=self.use_cudnn,
|
||||
**kwargs)
|
||||
|
||||
def Iter(self, *args, **kwargs):
|
||||
return model_helpers.Iter(self, *args, **kwargs)
|
||||
|
||||
def Accuracy(self, *args, **kwargs):
|
||||
return model_helpers.Accuracy(self, *args, **kwargs)
|
||||
|
||||
def MaxPool(self, *args, **kwargs):
|
||||
return model_helpers.MaxPool(self, *args, use_cudnn=self.use_cudnn,
|
||||
**kwargs)
|
||||
|
|
@ -394,43 +408,6 @@ class CNNModelHelper(ModelHelperBase):
|
|||
return model_helpers.AveragePool(self, *args, use_cudnn=self.use_cudnn,
|
||||
**kwargs)
|
||||
|
||||
def Transpose(self, blob_in, blob_out, **kwargs):
|
||||
"""Transpose."""
|
||||
if self.use_cudnn:
|
||||
kwargs['engine'] = 'CUDNN'
|
||||
return self.net.Transpose(blob_in, blob_out, **kwargs)
|
||||
|
||||
def Sum(self, blob_in, blob_out, **kwargs):
|
||||
"""Sum"""
|
||||
return self.net.Sum(blob_in, blob_out, **kwargs)
|
||||
|
||||
def Iter(self, blob_out, **kwargs):
|
||||
if 'device_option' in kwargs:
|
||||
del kwargs['device_option']
|
||||
self.param_init_net.ConstantFill(
|
||||
[], blob_out, shape=[1], value=0, dtype=core.DataType.INT64,
|
||||
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
|
||||
**kwargs)
|
||||
return self.net.Iter(blob_out, blob_out, **kwargs)
|
||||
|
||||
def Accuracy(self, blob_in, blob_out, **kwargs):
|
||||
dev = kwargs['device_option'] if 'device_option' in kwargs \
|
||||
else scope.CurrentDeviceScope()
|
||||
is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU
|
||||
|
||||
# We support top_k > 1 only on CPU
|
||||
if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1:
|
||||
pred_host = self.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
|
||||
label_host = self.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")
|
||||
|
||||
# Now use the Host version of the accuracy op
|
||||
self.net.Accuracy([pred_host, label_host],
|
||||
blob_out,
|
||||
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
|
||||
**kwargs)
|
||||
else:
|
||||
self.net.Accuracy(blob_in, blob_out)
|
||||
|
||||
def PadImage(
|
||||
self, blob_in, blob_out, **kwargs
|
||||
):
|
||||
|
|
|
|||
21
caffe2/python/helpers/algebra.py
Normal file
21
caffe2/python/helpers/algebra.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
## @package algebra
|
||||
# Module caffe2.python.helpers.algebra
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
__all__ = [
|
||||
'Transpose',
|
||||
'Sum',
|
||||
]
|
||||
|
||||
|
||||
def Transpose(model, blob_in, blob_out, **kwargs):
|
||||
"""Transpose."""
|
||||
return model.net.Transpose(blob_in, blob_out, **kwargs)
|
||||
|
||||
|
||||
def Sum(model, blob_in, blob_out, **kwargs):
|
||||
"""Sum"""
|
||||
return model.net.Sum(blob_in, blob_out, **kwargs)
|
||||
43
caffe2/python/helpers/train.py
Normal file
43
caffe2/python/helpers/train.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
## @package train
|
||||
# Module caffe2.python.helpers.train
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, scope
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
__all__ = [
|
||||
'Iter',
|
||||
'Accuracy',
|
||||
]
|
||||
|
||||
|
||||
def Iter(model, blob_out, **kwargs):
|
||||
if 'device_option' in kwargs:
|
||||
del kwargs['device_option']
|
||||
model.param_init_net.ConstantFill(
|
||||
[], blob_out, shape=[1], value=0, dtype=core.DataType.INT64,
|
||||
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
|
||||
**kwargs)
|
||||
return model.net.Iter(blob_out, blob_out, **kwargs)
|
||||
|
||||
|
||||
def Accuracy(model, blob_in, blob_out, **kwargs):
|
||||
dev = kwargs['device_option'] if 'device_option' in kwargs \
|
||||
else scope.CurrentDeviceScope()
|
||||
is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU
|
||||
|
||||
# We support top_k > 1 only on CPU
|
||||
if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1:
|
||||
pred_host = model.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
|
||||
label_host = model.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")
|
||||
|
||||
# Now use the Host version of the accuracy op
|
||||
model.net.Accuracy([pred_host, label_host],
|
||||
blob_out,
|
||||
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
|
||||
**kwargs)
|
||||
else:
|
||||
model.net.Accuracy(blob_in, blob_out)
|
||||
|
|
@ -13,3 +13,5 @@ from caffe2.python.helpers.pooling import *
|
|||
from caffe2.python.helpers.normalization import *
|
||||
from caffe2.python.helpers.nonlinearity import *
|
||||
from caffe2.python.helpers.array_helpers import *
|
||||
from caffe2.python.helpers.algebra import *
|
||||
from caffe2.python.helpers.train import *
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user