mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
create tensor based on provided datatype (#22468)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22468 as title Reviewed By: ajauhri Differential Revision: D15744503 fbshipit-source-id: 050b32dd7f135512385fc04f098c376c664211a9
This commit is contained in:
parent
319ef3bcbb
commit
325ec2327f
|
|
@ -24,14 +24,19 @@ class Caffe2BenchmarkBase(object):
|
|||
self.args = {}
|
||||
self.user_provided_name = None
|
||||
|
||||
# TODO: Add other dtype support
|
||||
def tensor(self, *shapes):
|
||||
""" A wapper function to create tensor (blob in caffe2) filled with random
|
||||
value. The name/label of the tensor is returned and it is available
|
||||
def tensor(self, shapes, dtype='float32'):
|
||||
""" A wapper function to create C2 tensor filled with random data.
|
||||
The name/label of the tensor is returned and it is available
|
||||
throughout the benchmark execution phase.
|
||||
Args:
|
||||
shapes: int or a sequence of ints to defining the shapes of the tensor
|
||||
dtype: use the dtypes from numpy
|
||||
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
|
||||
Return:
|
||||
C2 tensor of dtype
|
||||
"""
|
||||
blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index)
|
||||
workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*shapes))
|
||||
workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes))
|
||||
Caffe2BenchmarkBase.tensor_index += 1
|
||||
return blob_name
|
||||
|
||||
|
|
|
|||
|
|
@ -19,12 +19,18 @@ def shape_to_string(shape):
|
|||
return ', '.join([str(x) for x in shape])
|
||||
|
||||
|
||||
def numpy_random_fp32(*shape):
|
||||
"""Return a random numpy tensor of float32 type.
|
||||
def numpy_random(dtype, *shapes):
|
||||
""" Return a random numpy tensor of the provided dtype.
|
||||
Args:
|
||||
shapes: int or a sequence of ints to defining the shapes of the tensor
|
||||
dtype: use the dtypes from numpy
|
||||
(https://docs.scipy.org/doc/numpy/user/basics.types.html)
|
||||
Return:
|
||||
numpy tensor of dtype
|
||||
"""
|
||||
# TODO: consider more complex/custom dynamic ranges for
|
||||
# comprehensive test coverage.
|
||||
return np.random.rand(*shape).astype(np.float32)
|
||||
return np.random.rand(*shapes).astype(dtype)
|
||||
|
||||
|
||||
def set_omp_threads(num_threads):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user