create tensor based on provided datatype (#22468)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22468

as title

Reviewed By: ajauhri

Differential Revision: D15744503

fbshipit-source-id: 050b32dd7f135512385fc04f098c376c664211a9
This commit is contained in:
Mingzhe Li 2019-07-03 16:43:44 -07:00 committed by Facebook Github Bot
parent 319ef3bcbb
commit 325ec2327f
2 changed files with 19 additions and 8 deletions

View File

@ -24,14 +24,19 @@ class Caffe2BenchmarkBase(object):
self.args = {} self.args = {}
self.user_provided_name = None self.user_provided_name = None
# TODO: Add other dtype support def tensor(self, shapes, dtype='float32'):
def tensor(self, *shapes): """ A wapper function to create C2 tensor filled with random data.
""" A wapper function to create tensor (blob in caffe2) filled with random The name/label of the tensor is returned and it is available
value. The name/label of the tensor is returned and it is available
throughout the benchmark execution phase. throughout the benchmark execution phase.
Args:
shapes: int or a sequence of ints to defining the shapes of the tensor
dtype: use the dtypes from numpy
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
Return:
C2 tensor of dtype
""" """
blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index) blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index)
workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*shapes)) workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes))
Caffe2BenchmarkBase.tensor_index += 1 Caffe2BenchmarkBase.tensor_index += 1
return blob_name return blob_name

View File

@ -19,12 +19,18 @@ def shape_to_string(shape):
return ', '.join([str(x) for x in shape]) return ', '.join([str(x) for x in shape])
def numpy_random_fp32(*shape): def numpy_random(dtype, *shapes):
"""Return a random numpy tensor of float32 type. """ Return a random numpy tensor of the provided dtype.
Args:
shapes: int or a sequence of ints to defining the shapes of the tensor
dtype: use the dtypes from numpy
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
Return:
numpy tensor of dtype
""" """
# TODO: consider more complex/custom dynamic ranges for # TODO: consider more complex/custom dynamic ranges for
# comprehensive test coverage. # comprehensive test coverage.
return np.random.rand(*shape).astype(np.float32) return np.random.rand(*shapes).astype(dtype)
def set_omp_threads(num_threads): def set_omp_threads(num_threads):