Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00).
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735 Reviewed By: mruberry Differential Revision: D23731306 Pulled By: ezyang fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
from caffe2.python import workspace, core
|
|
import numpy as np
|
|
|
|
from utils import NUM_LOOP_ITERS
|
|
|
|
# Import-time side effect: initialize Caffe2 global state once, before any
# net defined below is constructed. The ['caffe2'] list is presumably an
# argv-style list (program name only, no flags) -- NOTE(review): confirm
# against the Caffe2 workspace API.
workspace.GlobalInit(['caffe2'])
|
|
|
|
def add_blob(ws, blob_name, tensor_size):
    """Store a random float32 tensor in ``ws`` under ``blob_name``.

    The tensor holds standard-normal samples drawn with ``np.random.randn``
    and cast to ``np.float32``.

    Args:
        ws: Object exposing ``FeedBlob(name, array)``; this file passes the
            ``caffe2.python.workspace`` module.
        blob_name: Name the tensor is stored under.
        tensor_size: Sequence of dimensions, unpacked into ``np.random.randn``.
    """
    ws.FeedBlob(blob_name, np.random.randn(*tensor_size).astype(np.float32))
|
|
|
|
class C2SimpleNet(object):
    """Caffe2 net built from a chain of identical ``op_name`` operators.

    The constructor feeds ``num_inputs`` single-element random float32 blobs
    named ``in_0`` .. ``in_<num_inputs - 1>`` into ``workspace`` and adds one
    ``op_name`` operator consuming them. It then appends ``NUM_LOOP_ITERS``
    more copies of that operator. Each copy reads the previous operator's
    first output in place of its last input, so the operators form a
    dependent chain. Finally the net is created in ``workspace``; ``forward``
    runs it repeatedly.

    Attributes:
        net: The ``core.Net`` being benchmarked.
        input_names: Input blob names used by the most recently added op.
        output_name: Output field of the final op in the chain.

    Args:
        op_name: Operator name; resolved with ``getattr`` on the ``core.Net``
            instance and called with the list of input blob names.
        num_inputs: Number of input blobs each operator takes. Must be >= 1,
            because the last input slot is reused to chain operators.
        debug: If true, print the underlying ``self.net._net`` definition
            after the net is created.

    Raises:
        ValueError: If ``num_inputs`` is less than 1.
    """

    def __init__(self, op_name, num_inputs=1, debug=False):
        # Chaining overwrites input_names[-1]; with no inputs that would fail
        # later with an opaque IndexError, so reject it up front.
        if num_inputs < 1:
            raise ValueError(
                "num_inputs must be >= 1 to chain operators, got {}".format(
                    num_inputs
                )
            )

        self.net = core.Net("framework_benchmark_net")
        self.input_names = ["in_{}".format(i) for i in range(num_inputs)]
        for name in self.input_names:
            add_blob(workspace, name, [1])
        self.net.AddExternalInputs(self.input_names)

        op_constructor = getattr(self.net, op_name)
        op_constructor(self.input_names)
        print("Benchmarking op {}:".format(op_name))
        for _ in range(NUM_LOOP_ITERS):
            # Replace the last input with the newest op's first output so each
            # op depends on its predecessor. The list length never changes.
            self.input_names[-1] = self.net._net.op[-1].output[0]
            op_constructor(self.input_names)

        # Read the outputs only after the chain is complete. Capturing them
        # before the loop would record the first op's outputs, not the last.
        self.output_name = self.net._net.op[-1].output

        workspace.CreateNet(self.net)
        if debug:
            print(self.net._net)

    def forward(self, niters):
        """Run the created net ``niters`` times via ``workspace.RunNet``.

        The third positional argument is passed as ``False``; it is presumably
        ``allow_fail``. NOTE(review): confirm against the Caffe2 API.
        """
        workspace.RunNet(self.net, niters, False)
|