pytorch/caffe2/contrib/cuda-convnet2/python_util/data.py
Mickaël Schoentgen 04f5605ba1 Fix several DeprecationWarning: invalid escape sequence (#15733)
Summary:
Hello,

This is a little patch to fix `DeprecationWarning: invalid escape sequence`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15733

Differential Revision: D13587291

Pulled By: soumith

fbshipit-source-id: ce68db2de92ca7eaa42f78ca5ae6fbc1d4d90e05
2019-01-05 08:53:35 -08:00

195 lines
7.6 KiB
Python

# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as n
from numpy.random import randn, rand, random_integers
import os
from threading import Thread
from util import *
# File name, relative to a provider's data directory, that
# DataProvider.get_batch_meta unpickles to obtain the batch metadata.
BATCH_META_FILE = "batches.meta"
class DataLoaderThread(Thread):
    """Thread that unpickles a single file and adds the result to a target list."""

    def __init__(self, path, tgt):
        """Remember which file to load (``path``) and where to put it (``tgt``)."""
        super(DataLoaderThread, self).__init__()
        self.tgt = tgt
        self.path = path

    def run(self):
        """Unpickle ``self.path`` and add the loaded object to ``self.tgt``."""
        loaded = unpickle(self.path)
        # In-place concatenation: the caller's list object receives the result.
        self.tgt += [loaded]
class DataProvider:
    """Serves batches stored under a data directory, cycling over a batch range.

    Batches are entries of ``data_dir`` named ``data_batch_<num>``.  An entry
    is either a single file that is unpickled directly, or a directory whose
    files (sub-batches) are unpickled in parallel.  Batch metadata is read
    from ``BATCH_META_FILE`` in the same directory.

    NOTE(review): ``unpickle`` and ``alphanum_key`` come from ``util``;
    ``unpickle`` is presumably pickle-based, so only point a provider at
    trusted data -- confirm against ``util``.
    """

    # Matches entry names such as "data_batch_3" or "data_batch_3.1";
    # group 1 is the batch number.
    BATCH_REGEX = re.compile(r'^data_batch_(\d+)(\.\d+)?$')

    def __init__(self, data_dir, batch_range=None, init_epoch=1, init_batchnum=None, dp_params=None, test=False):
        """Set up batch cycling state and load the batch metadata.

        Args:
            data_dir: directory holding the batches and the metadata file.
            batch_range: batch numbers to cycle over; when None, every batch
                number found in ``data_dir`` is used.
            init_epoch: epoch counter to start from.
            init_batchnum: batch number to start from; falls back to the first
                entry of ``batch_range`` when None or not part of the range.
            dp_params: provider-specific parameters; a fresh dict is created
                when omitted so instances never share one default object.
            test: stored on the instance as ``self.test``.
        """
        if batch_range is None:
            batch_range = DataProvider.get_batch_nums(data_dir)
        if init_batchnum is None or init_batchnum not in batch_range:
            init_batchnum = batch_range[0]
        if dp_params is None:
            dp_params = {}

        self.data_dir = data_dir
        self.batch_range = batch_range
        self.curr_epoch = init_epoch
        self.curr_batchnum = init_batchnum
        self.dp_params = dp_params
        self.batch_meta = self.get_batch_meta(data_dir)
        self.data_dic = None
        self.test = test
        self.batch_idx = batch_range.index(init_batchnum)

    def get_next_batch(self):
        """Return ``(epoch, batchnum, data)`` for the current batch, then advance.

        The returned epoch and batch number describe the batch being returned,
        not the one the provider has advanced to.
        """
        # With a single-batch range the data loaded earlier is simply reused.
        if self.data_dic is None or len(self.batch_range) > 1:
            self.data_dic = self.get_batch(self.curr_batchnum)
        epoch, batchnum = self.curr_epoch, self.curr_batchnum
        self.advance_batch()
        return epoch, batchnum, self.data_dic

    def get_batch(self, batch_num):
        """Load batch ``batch_num``.

        Returns the unpickled object for a single-file batch, or a list with
        one unpickled object per sub-batch (in ``alphanum_key`` order) when
        the batch is a directory.
        """
        fname = self.get_data_file_name(batch_num)
        if not os.path.isdir(fname):
            return unpickle(fname)

        # Batch is split into sub-batches: load each one on its own thread.
        sub_batches = sorted(os.listdir(fname), key=alphanum_key)
        tgts = [[] for _ in sub_batches]
        threads = [DataLoaderThread(os.path.join(fname, s), tgt)
                   for (s, tgt) in zip(sub_batches, tgts)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return [t[0] for t in tgts]

    def get_data_dims(self, idx=0):
        """Return ``batch_meta['num_vis']`` for ``idx == 0`` and 1 otherwise."""
        return self.batch_meta['num_vis'] if idx == 0 else 1

    def advance_batch(self):
        """Move to the next batch in the range, bumping the epoch on wrap-around."""
        self.batch_idx = self.get_next_batch_idx()
        self.curr_batchnum = self.batch_range[self.batch_idx]
        if self.batch_idx == 0:  # we wrapped
            self.curr_epoch += 1

    def get_next_batch_idx(self):
        """Return the index into ``batch_range`` that follows the current one."""
        return (self.batch_idx + 1) % len(self.batch_range)

    def get_next_batch_num(self):
        """Return the batch number that follows the current one."""
        return self.batch_range[self.get_next_batch_idx()]

    def get_data_file_name(self, batchnum=None):
        """Return the path of batch ``batchnum`` (the current batch when None)."""
        if batchnum is None:
            batchnum = self.curr_batchnum
        return os.path.join(self.data_dir, 'data_batch_%d' % batchnum)

    @classmethod
    def get_instance(cls, data_dir, batch_range=None, init_epoch=1, init_batchnum=None, type="default", dp_params=None, test=False):
        """Instantiate the registered data provider named by ``type``.

        A falsy ``type`` lets the data decide: the ``'dp_type'`` entry of the
        batch metadata is used instead.  Names of the form ``dummy-...-<dims>``
        select the provider registered as ``dummy-...-n`` and construct it
        with ``int(<dims>)`` as its only argument.

        Raises:
            DataProviderException: if no matching provider is registered.
        """
        # The registries (dp_types / dp_classes) are module-level dicts that
        # are looked up here at call time, not when this class is defined.
        dp_params = {} if dp_params is None else dp_params
        # Use a local name rather than rebinding the ``type`` parameter, which
        # already shadows the builtin.
        dp_type = type or DataProvider.get_batch_meta(data_dir)['dp_type']

        if dp_type.startswith("dummy-"):
            # Swap the trailing dimension for "n" to get the registry key.
            prefix, _, dims = dp_type.rpartition('-')
            name = prefix + "-n"
            if name not in dp_types:
                raise DataProviderException("No such data provider: %s" % dp_type)
            return dp_classes[name](int(dims))

        if dp_type in dp_types:
            return dp_classes[dp_type](data_dir, batch_range, init_epoch, init_batchnum, dp_params, test)

        raise DataProviderException("No such data provider: %s" % dp_type)

    @classmethod
    def register_data_provider(cls, name, desc, _class):
        """Register ``_class`` under ``name`` with description ``desc``.

        Raises:
            DataProviderException: if ``name`` is already registered.
        """
        if name in dp_types:
            raise DataProviderException("Data provider %s already registered" % name)
        dp_types[name] = desc
        dp_classes[name] = _class

    @staticmethod
    def get_batch_meta(data_dir):
        """Unpickle and return the metadata file found in ``data_dir``."""
        return unpickle(os.path.join(data_dir, BATCH_META_FILE))

    @staticmethod
    def get_batch_filenames(srcdir):
        """Return the entries of ``srcdir`` that match ``BATCH_REGEX``, sorted by ``alphanum_key``."""
        return sorted([f for f in os.listdir(srcdir) if DataProvider.BATCH_REGEX.match(f)], key=alphanum_key)

    @staticmethod
    def get_batch_nums(srcdir):
        """Return the sorted, de-duplicated batch numbers present in ``srcdir``."""
        names = DataProvider.get_batch_filenames(srcdir)
        return sorted(set(int(DataProvider.BATCH_REGEX.match(name).group(1)) for name in names))

    @staticmethod
    def get_num_batches(srcdir):
        """Return how many distinct batch numbers ``srcdir`` contains."""
        return len(DataProvider.get_batch_nums(srcdir))
class DummyDataProvider(DataProvider):
    """Data provider that fabricates random data instead of reading batches."""

    def __init__(self, data_dim):
        """Set up a single-batch provider for ``data_dim``-dimensional data.

        DataProvider.__init__ is not called; only the attributes used by the
        batch-cycling helpers and by get_data_dims are initialised here.
        """
        self.batch_range = [1]
        self.batch_meta = {'num_vis': data_dim, 'data_in_rows': True}
        self.curr_epoch = 1
        self.curr_batchnum = 1
        self.batch_idx = 0

    def get_next_batch(self):
        """Return ``(epoch, batchnum, {'data': array})`` and advance.

        The array holds 512 rows of uniform random single-precision values,
        one column per data dimension.  As in DataProvider.get_next_batch, the
        epoch and batch number describe the batch being returned, i.e. the
        values captured before advancing.
        """
        epoch, batchnum = self.curr_epoch, self.curr_batchnum
        self.advance_batch()
        data = rand(512, self.get_data_dims()).astype(n.single)
        return epoch, batchnum, {'data': data}
class LabeledDataProvider(DataProvider):
    """DataProvider whose batch metadata carries a ``'label_names'`` entry."""

    def __init__(self, data_dir, batch_range=None, init_epoch=1, init_batchnum=None, dp_params=None, test=False):
        """Forward all arguments to DataProvider.__init__.

        ``dp_params`` defaults to a fresh dict per call so that instances do
        not end up sharing (and mutating) one default dictionary.
        """
        if dp_params is None:
            dp_params = {}
        DataProvider.__init__(self, data_dir, batch_range, init_epoch, init_batchnum, dp_params, test)

    def get_num_classes(self):
        """Return the number of entries in ``batch_meta['label_names']``."""
        return len(self.batch_meta['label_names'])
class LabeledDummyDataProvider(DummyDataProvider):
    """Dummy provider that fabricates random data together with random labels."""

    def __init__(self, data_dim, num_classes=10, num_cases=7):
        """Set up a single-batch provider.

        Args:
            data_dim: number of dimensions of each fabricated case.
            num_classes: labels are drawn from ``0 .. num_classes - 1``.
            num_cases: number of cases in the fabricated batch.
        """
        self.batch_range = [1]
        self.batch_meta = {'num_vis': data_dim,
                           'label_names': [str(x) for x in range(num_classes)],
                           'data_in_rows': True}
        self.num_cases = num_cases
        self.num_classes = num_classes
        self.curr_epoch = 1
        self.curr_batchnum = 1
        self.batch_idx = 0
        self.data = None

    def get_num_classes(self):
        """Return the number of label classes."""
        return self.num_classes

    def get_next_batch(self):
        """Return ``(epoch, batchnum, [data.T, labels.T])`` and advance.

        Data and labels are generated on the first call and reused afterwards.
        As in DataProvider.get_next_batch, the epoch and batch number describe
        the batch being returned, i.e. the values captured before advancing.
        """
        epoch, batchnum = self.curr_epoch, self.curr_batchnum
        self.advance_batch()
        if self.data is None:
            data = rand(self.num_cases, self.get_data_dims()).astype(n.single)
            # randint's upper bound is exclusive, so this draws from
            # 0 .. num_classes - 1 like the deprecated
            # random_integers(0, num_classes - 1, num_cases) did.
            raw_labels = n.random.randint(0, self.num_classes, self.num_cases)
            # n.c_ turns the 1-D label vector into a single column.
            labels = n.require(n.c_[raw_labels], requirements='C', dtype=n.single)
            self.data, self.labels = data, labels
        else:
            data, labels = self.data, self.labels
        return epoch, batchnum, [data.T, labels.T]
# Registry of data provider names -> human-readable descriptions.
# DataProvider.register_data_provider adds further entries.
dp_types = {
    "dummy-n": "Dummy data provider for n-dimensional data",
    "dummy-labeled-n": "Labeled dummy data provider for n-dimensional data",
}

# Registry of data provider names -> implementing classes; keyed like dp_types.
dp_classes = {
    "dummy-n": DummyDataProvider,
    "dummy-labeled-n": LabeledDummyDataProvider,
}
class DataProviderException(Exception):
    """Raised when a data provider name is unknown or already registered."""