mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Should be non-semantic. Uses https://en.wikipedia.org/wiki/Wikipedia:Lists_of_common_misspellings/For_machines to find likely typos, with https://github.com/bwignall/typochecker to help automate the checking. Uses an updated version of the tool used in https://github.com/pytorch/pytorch/pull/30606 . Pull Request resolved: https://github.com/pytorch/pytorch/pull/31523 Differential Revision: D19216749 Pulled By: mrshenli fbshipit-source-id: 7fd489cb9a77cd7e4950c1046f925d57524960ea
43 lines
1.6 KiB
Python
43 lines
1.6 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
from caffe2.python import workspace, core
|
|
import numpy as np
|
|
|
|
from utils import NUM_LOOP_ITERS
|
|
|
|
# One-time Caffe2 global initialization, executed when this module is imported
# (i.e. before any C2SimpleNet below is constructed). The argument list looks
# argv-style (program name only, no extra flags) — NOTE(review): confirm.
workspace.GlobalInit(['caffe2'])
|
|
|
|
def add_blob(ws, blob_name, tensor_size):
    """Feed a random float32 tensor into a workspace.

    Args:
        ws: object exposing ``FeedBlob(name, value)``; in this file it is the
            caffe2 ``workspace`` module.
        blob_name: name under which the tensor is fed.
        tensor_size: sequence of dimensions; the tensor is filled with
            standard-normal samples from ``np.random.randn``.
    """
    random_values = np.random.randn(*tensor_size)
    ws.FeedBlob(blob_name, random_values.astype(np.float32))
|
|
|
|
class C2SimpleNet(object):
    """
    This module constructs a net with the 'op_name' operator. The net consists
    of a series of such operators chained together.
    It initializes the workspace with one input blob per operator input
    (``num_inputs`` of them), each a random float32 tensor of shape [1].
    Provides a forward method to run the net ``niters`` times.
    """

    def __init__(self, op_name, num_inputs=1, debug=False):
        """Build, chain and create the benchmark net.

        Args:
            op_name: name of the operator to benchmark; it is looked up as a
                method on the ``core.Net`` instance via ``getattr``.
            num_inputs: number of input blobs passed to each operator.
            debug: if True, print the net's underlying protobuf after the net
                has been created.
        """
        self.net = core.Net("framework_benchmark_net")
        self.input_names = ["in_{}".format(i) for i in range(num_inputs)]
        for name in self.input_names:
            add_blob(workspace, name, [1])
        self.net.AddExternalInputs(self.input_names)

        op_constructor = getattr(self.net, op_name)
        op_constructor(self.input_names)
        # NOTE(review): this is the output field of the *first* operator. It is
        # captured before the chained copies below are appended, so it does not
        # name the final op's output. Kept unchanged for existing callers.
        self.output_name = self.net._net.op[-1].output
        print("Benchmarking op {}:".format(op_name))

        # Append NUM_LOOP_ITERS more copies of the operator. Each new op reads
        # the previous op's first output in place of its last input, so every
        # op depends on the one before it.
        for _ in range(NUM_LOOP_ITERS):
            prev_outputs = self.net._net.op[-1].output
            self.input_names[-1] = prev_outputs[0]
            op_constructor(self.input_names)

        workspace.CreateNet(self.net)
        if debug:
            print(self.net._net)

    def forward(self, niters):
        """Run the created net ``niters`` times via ``workspace.RunNet``.

        The trailing ``False`` is presumably RunNet's ``allow_fail`` flag —
        NOTE(review): confirm against the caffe2 workspace API.
        """
        workspace.RunNet(self.net, niters, False)
|