pytorch/caffe2/python/helpers/tools.py
Yiming Wu 2c8b41e3f3 Adding add_weight_decay and image_input to brew module
Summary:
Adding add_weight_decay and image_input to the brew module, and removing `getWeights` and `getBias` from CNNModelHelper.

Searching for `useWeights` with fbgs shows that no one except add_weight_decay uses this function. I checked with the Oculus people; their getWeights is a different function.

kennyhorror: please check whether this is going to affect you :)

Reviewed By: salexspb

Differential Revision: D4945392

fbshipit-source-id: 4ef350fd81dd40a91847e9f3ebc5421eb564df32
2017-04-25 16:03:58 -07:00


## @package tools
# Module caffe2.python.helpers.tools
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


def image_input(
    model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
):
    """Add an ImageInput op to model.net and return the (data, label) blobs.

    blob_out is a pair of output blob names: [data_blob, label_blob].
    """
    if order == "NCHW":
        if use_gpu_transform:
            kwargs['use_gpu_transform'] = 1
            # GPU transform will handle NHWC -> NCHW
            data, label = model.net.ImageInput(
                blob_in, [blob_out[0], blob_out[1]], **kwargs
            )
        else:
            # Without the GPU transform, read into a temporary '<name>_nhwc'
            # blob and transpose it into the requested NCHW blob.
            data, label = model.net.ImageInput(
                blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs
            )
            data = model.net.NHWC2NCHW(data, blob_out[0])
    else:
        # Any other order (e.g. NHWC) uses ImageInput's output directly.
        data, label = model.net.ImageInput(blob_in, blob_out, **kwargs)
    return data, label
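To make the branching concrete, here is a minimal sketch that runs `image_input` against a stand-in model. `RecordingNet` and `FakeModel` are hypothetical helpers, not part of caffe2: they mimic only the `ImageInput` and `NHWC2NCHW` calls the helper makes and record each op instead of building a real graph. The function body is copied from above so the block runs without caffe2 installed.

```python
# Hypothetical stand-ins (not caffe2 classes): they record the ops that
# image_input would add instead of building a real caffe2 net.
class RecordingNet(object):
    def __init__(self):
        self.ops = []  # (op_type, inputs, outputs, kwargs) in call order

    def ImageInput(self, blob_in, blob_out, **kwargs):
        self.ops.append(("ImageInput", blob_in, list(blob_out), dict(kwargs)))
        return tuple(blob_out)  # caffe2 would return BlobReferences here

    def NHWC2NCHW(self, blob_in, blob_out):
        self.ops.append(("NHWC2NCHW", blob_in, blob_out, {}))
        return blob_out


class FakeModel(object):
    def __init__(self):
        self.net = RecordingNet()


# Copy of image_input from tools.py so this sketch runs standalone.
def image_input(model, blob_in, blob_out, order="NCHW",
                use_gpu_transform=False, **kwargs):
    if order == "NCHW":
        if use_gpu_transform:
            # GPU transform will handle NHWC -> NCHW
            kwargs['use_gpu_transform'] = 1
            data, label = model.net.ImageInput(
                blob_in, [blob_out[0], blob_out[1]], **kwargs)
        else:
            data, label = model.net.ImageInput(
                blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
            data = model.net.NHWC2NCHW(data, blob_out[0])
    else:
        data, label = model.net.ImageInput(blob_in, blob_out, **kwargs)
    return data, label


for order, gpu in [("NCHW", False), ("NCHW", True), ("NHWC", False)]:
    model = FakeModel()
    data, label = image_input(model, "reader", ["data", "label"],
                              order=order, use_gpu_transform=gpu)
    print(order, gpu, [op[0] for op in model.net.ops], data, label)
# NCHW False ['ImageInput', 'NHWC2NCHW'] data label
# NCHW True ['ImageInput'] data label
# NHWC False ['ImageInput'] data label
```

With `order="NCHW"` and no GPU transform, two ops are emitted: the read into a temporary `data_nhwc` blob and the transpose into `data`. With the GPU transform, or with any other order, a single `ImageInput` op writes directly to the requested blobs.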