pytorch/caffe2/python/helpers/tools.py
Kevin Wilfong 60cb55461e Caffe2: Support additional outputs in ImageInputOp
Summary: This allows users to add an arbitrary number of additional outputs to ImageInputOp.  These are populated by reading additional TensorProto values from the TensorProtos produced by the DBReader, and converting them into Tensors.  As with labels, only ints and floats are supported, and multiple values per output are supported.

Reviewed By: panshen1

Differential Revision: D5502019

fbshipit-source-id: 5a8b61b3a8549272a112e8e02cd613d8f9a271ba
2017-08-01 14:36:05 -07:00

33 lines
1.1 KiB
Python

## @package tools
# Module caffe2.python.helpers.tools
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
def image_input(
    model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
):
    """Add an ImageInput op (plus layout conversion if needed) to model.net.

    Args:
        model: model helper whose ``net`` receives the new op(s).
        blob_in: input blob(s) forwarded unchanged to ImageInput
            (presumably the DB reader -- confirm against callers).
        blob_out: sequence (list or tuple) of output blob names. The first
            entry is the image data blob; the remaining entries (labels and
            any additional outputs) are forwarded unchanged.
        order: when "NCHW" (the default), the data output is delivered in
            NCHW layout; any other value calls ImageInput directly with no
            conversion step.
        use_gpu_transform: only consulted when ``order == "NCHW"``. If true,
            ``use_gpu_transform=1`` is passed to ImageInput so the op does
            the NHWC -> NCHW transform itself; otherwise an explicit
            NHWC2NCHW op is appended. Ignored for other orders.
        **kwargs: extra arguments forwarded to ImageInput.

    Returns:
        The outputs of ImageInput. In the CPU NCHW path this is a tuple whose
        first element is the output of the NHWC2NCHW op.
    """
    if order != "NCHW":
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    if use_gpu_transform:
        # GPU transform will handle NHWC -> NCHW inside the op itself.
        kwargs['use_gpu_transform'] = 1
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    # CPU path: ImageInput writes NHWC data into a temporary '<name>_nhwc'
    # blob, which is then transposed into the caller's requested data blob.
    # list(...) lets blob_out be any sequence (e.g. a tuple); concatenating
    # a list with a tuple slice would raise TypeError.
    nhwc_outputs = [blob_out[0] + '_nhwc'] + list(blob_out[1:])
    outputs = list(model.net.ImageInput(blob_in, nhwc_outputs, **kwargs))
    outputs[0] = model.net.NHWC2NCHW(outputs[0], blob_out[0])
    return tuple(outputs)
def video_input(model, blob_in, blob_out, **kwargs):
    """Add a VideoInput op to model.net and return its two outputs.

    ``blob_in``, ``blob_out`` and any extra keyword arguments are forwarded
    unchanged to ``model.net.VideoInput``. The op's result is unpacked into
    exactly two values (named data and label here), so a result of any
    other length raises ValueError.

    Returns:
        A ``(data, label)`` tuple.
    """
    video_data, video_label = model.net.VideoInput(
        blob_in, blob_out, **kwargs
    )
    return (video_data, video_label)