pytorch/caffe2/python/examples/lmdb_create_example.py
Aapo Kyrola e65eeff665 LMDB example
Summary:
This example writes an LMDB database of image data and (random) labels. It then reads them back using Caffe2's TensorProtosDBInput and validates that the checksums match. This example shows how to coerce image data into TensorProtos and be happy.

Previously, there was no clear example of how to create databases for Caffe2.

Differential Revision: D4263614

fbshipit-source-id: 21e08066899095b4efcc2d23dbc3ede81e75914a
2016-12-05 11:53:26 -08:00

107 lines
3.0 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import numpy as np
import lmdb
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, cnn
'''
Simple example to create an lmdb database of random image data and labels.
This can be used as a skeleton to write your own data import.
It also runs a dummy model with Caffe2 that reads the data back and
validates that the checksum is the same.
'''
def create_db(output_file):
    """Write an LMDB database of random image data and labels.

    Stores 128 entries keyed by their ASCII-encoded index. Each value is a
    serialized ``TensorProtos`` message holding two tensors: a float image
    with dims (3, 64, 32) and a label tensor with a single int32 value
    equal to ``index % 10``.

    Args:
        output_file: Path of the LMDB environment to create or open.

    Returns:
        The checksum of the written data, i.e. the sum over all entries of
        ``sum(image) * label``.
    """
    print(">>> Write database...")
    LMDB_MAP_SIZE = 1 << 40  # MODIFY
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)

    checksum = 0
    try:
        with env.begin(write=True) as txn:
            for j in range(0, 128):
                # MODIFY: add your own data reader / creator
                label = j % 10
                width = 64
                height = 32

                # The image is stored through the single-precision
                # ``float_data`` field, so round to float32 up front: the
                # checksum then describes the values that are really stored.
                img_data = np.random.rand(3, width, height).astype(np.float32)
                # ...

                # Create TensorProtos
                tensor_protos = caffe2_pb2.TensorProtos()

                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = caffe2_pb2.TensorProto.FLOAT
                # tolist() yields plain Python floats, which every protobuf
                # backend accepts for a repeated float field.
                img_tensor.float_data.extend(img_data.ravel().tolist())

                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = caffe2_pb2.TensorProto.INT32
                label_tensor.int32_data.append(label)

                txn.put(
                    '{}'.format(j).encode('ascii'),
                    tensor_protos.SerializeToString()
                )

                # Accumulate in float64 so the checksum itself does not
                # suffer from single-precision summation error.
                checksum += np.sum(img_data, dtype=np.float64) * label
                if (j % 16 == 0):
                    print("Inserted {} rows".format(j))
    finally:
        # Always release the environment so the data is flushed and the
        # handle is not leaked before the database is read back.
        env.close()

    print("Checksum/write: {}".format(int(checksum)))
    return checksum
def read_db_with_caffe2(db_file, expected_checksum):
    """Read the database back through a Caffe2 net and verify its checksum.

    Builds a model whose ``TensorProtosDBInput`` operator fills the "data"
    and "label" blobs from the LMDB at ``db_file``, runs the net for 4
    batches of 32 entries, and accumulates ``sum(image) * label`` over
    every entry read.

    Args:
        db_file: Path of the LMDB database to read.
        expected_checksum: Checksum the read data is compared against.

    Raises:
        AssertionError: If the read checksum differs from
            ``expected_checksum`` by 0.1 or more.
    """
    print(">>> Read database...")
    model = cnn.CNNModelHelper(
        order="NCHW", name="lmdbtest")
    batch_size = 32
    num_batches = 4
    model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")

    checksum = 0

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    net_name = model.net.Proto().name

    for _ in range(num_batches):
        workspace.RunNet(net_name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            # Sum in float64 so accumulated single-precision rounding does
            # not eat into the comparison tolerance below.
            checksum += np.sum(img_datas[j, :], dtype=np.float64) * labels[j]

    print("Checksum/read: {}".format(int(checksum)))
    # The absolute value must wrap the difference only. Wrapping the whole
    # comparison (as before) let any read checksum larger than the expected
    # one pass. Written as "not <" so a NaN checksum fails, and raised
    # explicitly so the check is not stripped under ``python -O``.
    if not np.abs(expected_checksum - checksum) < 0.1:
        raise AssertionError("Read/write checksums dont match")
def main():
    """Parse the command line, write the example database, then verify it.

    Requires ``--output_file``, the path the LMDB database is written to.
    The checksum returned by the write step is handed to the read step,
    which validates that the data read back matches.
    """
    arg_parser = argparse.ArgumentParser(description="Example LMDB creation")
    arg_parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Path to write the database to",
    )
    db_path = arg_parser.parse_args().output_file

    written_checksum = create_db(db_path)

    # For testing reading:
    read_db_with_caffe2(db_path, written_checksum)
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()