mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
create tensor based on provided datatype (#22468)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22468 as title Reviewed By: ajauhri Differential Revision: D15744503 fbshipit-source-id: 050b32dd7f135512385fc04f098c376c664211a9
This commit is contained in:
parent
319ef3bcbb
commit
325ec2327f
|
|
@ -24,14 +24,19 @@ class Caffe2BenchmarkBase(object):
|
||||||
self.args = {}
|
self.args = {}
|
||||||
self.user_provided_name = None
|
self.user_provided_name = None
|
||||||
|
|
||||||
# TODO: Add other dtype support
|
def tensor(self, shapes, dtype='float32'):
|
||||||
def tensor(self, *shapes):
|
""" A wapper function to create C2 tensor filled with random data.
|
||||||
""" A wapper function to create tensor (blob in caffe2) filled with random
|
The name/label of the tensor is returned and it is available
|
||||||
value. The name/label of the tensor is returned and it is available
|
|
||||||
throughout the benchmark execution phase.
|
throughout the benchmark execution phase.
|
||||||
|
Args:
|
||||||
|
shapes: int or a sequence of ints to defining the shapes of the tensor
|
||||||
|
dtype: use the dtypes from numpy
|
||||||
|
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
|
||||||
|
Return:
|
||||||
|
C2 tensor of dtype
|
||||||
"""
|
"""
|
||||||
blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index)
|
blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index)
|
||||||
workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*shapes))
|
workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes))
|
||||||
Caffe2BenchmarkBase.tensor_index += 1
|
Caffe2BenchmarkBase.tensor_index += 1
|
||||||
return blob_name
|
return blob_name
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,18 @@ def shape_to_string(shape):
|
||||||
return ', '.join([str(x) for x in shape])
|
return ', '.join([str(x) for x in shape])
|
||||||
|
|
||||||
|
|
||||||
def numpy_random_fp32(*shape):
|
def numpy_random(dtype, *shapes):
|
||||||
"""Return a random numpy tensor of float32 type.
|
""" Return a random numpy tensor of the provided dtype.
|
||||||
|
Args:
|
||||||
|
shapes: int or a sequence of ints to defining the shapes of the tensor
|
||||||
|
dtype: use the dtypes from numpy
|
||||||
|
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
|
||||||
|
Return:
|
||||||
|
numpy tensor of dtype
|
||||||
"""
|
"""
|
||||||
# TODO: consider more complex/custom dynamic ranges for
|
# TODO: consider more complex/custom dynamic ranges for
|
||||||
# comprehensive test coverage.
|
# comprehensive test coverage.
|
||||||
return np.random.rand(*shape).astype(np.float32)
|
return np.random.rand(*shapes).astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
def set_omp_threads(num_threads):
|
def set_omp_threads(num_threads):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user