mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This allows users to add an arbitrary number of additional outputs to ImageInputOp. These are populated by reading additional TensorProto values from the TensorProtos returned by the DBReader, and converting them into Tensors. As with labels, only ints and floats are supported, and multiple values are supported. Reviewed By: panshen1 Differential Revision: D5502019 fbshipit-source-id: 5a8b61b3a8549272a112e8e02cd613d8f9a271ba
33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
## @package tools
|
|
# Module caffe2.python.helpers.tools
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
def image_input(
    model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
):
    """Add an ImageInput operator to ``model.net`` and return its outputs.

    Args:
        model: model helper whose ``net`` receives the new operator(s).
        blob_in: input blob(s), forwarded unchanged to ImageInput.
        blob_out: sequence (list or tuple) of output blob names. The first
            entry is the image data output. Any remaining entries (e.g. the
            label and additional outputs) are passed to ImageInput as-is.
        order: requested layout of the data output. Only "NCHW" gets special
            handling; any other value calls ImageInput directly.
        use_gpu_transform: only consulted when ``order == "NCHW"``. When
            truthy, ImageInput is called with ``use_gpu_transform=1`` and
            handles NHWC -> NCHW itself. Otherwise an explicit NHWC2NCHW
            operator is appended after ImageInput.
        **kwargs: extra arguments forwarded to ImageInput.

    Returns:
        The outputs of ImageInput, in ``blob_out`` order. On the explicit
        NHWC2NCHW path, this is a tuple whose first element is the output of
        the NHWC2NCHW operator (named ``blob_out[0]``).
    """
    if order != "NCHW":
        # NOTE(review): use_gpu_transform is intentionally not forwarded for
        # non-NCHW orders; this preserves the original behavior.
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    if use_gpu_transform:
        # GPU transform will handle NHWC -> NCHW
        kwargs['use_gpu_transform'] = 1
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    # ImageInput writes NHWC data into a temporary '<name>_nhwc' blob, which
    # is then transposed into the requested data blob. list() lets blob_out
    # be a tuple as well as a list.
    nhwc_outputs = [blob_out[0] + '_nhwc'] + list(blob_out[1:])
    outputs = list(model.net.ImageInput(blob_in, nhwc_outputs, **kwargs))
    outputs[0] = model.net.NHWC2NCHW(outputs[0], blob_out[0])
    return tuple(outputs)
|
|
|
|
|
|
def video_input(model, blob_in, blob_out, **kwargs):
    """Add a VideoInput operator to ``model.net`` and return its outputs.

    Args:
        model: model helper whose ``net`` receives the VideoInput operator.
        blob_in: input blob(s), forwarded unchanged to VideoInput.
        blob_out: output blob name(s), forwarded unchanged to VideoInput.
        **kwargs: extra arguments forwarded to VideoInput.

    Returns:
        A 2-tuple ``(data, label)`` of VideoInput's outputs. Unpacking raises
        if VideoInput does not yield exactly two outputs.
    """
    make_video_op = model.net.VideoInput
    clip_blob, label_blob = make_video_op(blob_in, blob_out, **kwargs)
    return (clip_blob, label_blob)
|