pytorch/benchmarks/framework_overhead_benchmark/C2Module.py
Brian Wignall f326045b37 Fix typos, via a Levenshtein-type corrector (#31523)
Summary:
Should be non-semantic.

Uses https://en.wikipedia.org/wiki/Wikipedia:Lists_of_common_misspellings/For_machines to find likely typos, with https://github.com/bwignall/typochecker to help automate the checking.

Uses an updated version of the tool used in https://github.com/pytorch/pytorch/pull/30606 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31523

Differential Revision: D19216749

Pulled By: mrshenli

fbshipit-source-id: 7fd489cb9a77cd7e4950c1046f925d57524960ea
2020-01-17 16:03:19 -08:00

43 lines
1.6 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
from caffe2.python import workspace, core
import numpy as np
from utils import NUM_LOOP_ITERS
# Bring up Caffe2's global state once, as a side effect of importing this
# module (it runs at module level, before any net is constructed).
workspace.GlobalInit(["caffe2"])
def add_blob(ws, blob_name, tensor_size):
    """
    Feed a randomly filled float32 blob named ``blob_name`` into ``ws``.

    Args:
        ws: object exposing ``FeedBlob(name, array)``; in this file the
            Caffe2 ``workspace`` module is passed.
        blob_name: name under which the blob is fed.
        tensor_size: sequence of dimensions, unpacked into
            ``np.random.randn`` (values are standard-normal samples).
    """
    samples = np.random.randn(*tensor_size)
    ws.FeedBlob(blob_name, samples.astype(np.float32))
class C2SimpleNet(object):
    """
    Caffe2 net made of ``NUM_LOOP_ITERS + 1`` chained copies of one operator.

    The workspace is seeded with ``num_inputs`` one-element float32 input
    blobs named ``in_0`` ... ``in_{num_inputs - 1}``. The first op consumes
    all of them. Every following op reuses the same inputs except the last
    slot, which is replaced by the first output of the previously added op,
    so the ops form a sequential chain.

    Provides a ``forward`` method that runs the net ``niters`` times.
    """

    def __init__(self, op_name, num_inputs=1, debug=False):
        """
        Build the chained net and create it in the global workspace.

        Args:
            op_name: Caffe2 operator name, looked up as an attribute of the
                ``core.Net`` instance to obtain the op constructor.
            num_inputs: number of input blobs fed to each op.
            debug: if True, print the generated net definition.
        """
        self.net = core.Net("framework_benchmark_net")
        # External inputs of the net. This list is never mutated afterwards,
        # so it keeps matching what AddExternalInputs received.
        self.input_names = ["in_{}".format(i) for i in range(num_inputs)]
        for name in self.input_names:
            add_blob(workspace, name, [1])
        self.net.AddExternalInputs(self.input_names)

        op_constructor = getattr(self.net, op_name)
        op_constructor(self.input_names)
        # NOTE: captured right after the first op is added (before the chain
        # is built), so this holds that op's outputs, not the final op's.
        self.output_name = self.net._net.op[-1].output
        print("Benchmarking op {}:".format(op_name))

        # Chain through a private copy so self.input_names is not clobbered.
        chain_inputs = list(self.input_names)
        for _ in range(NUM_LOOP_ITERS):
            # Route the previous op's first output into the last input slot.
            chain_inputs[-1] = self.net._net.op[-1].output[0]
            op_constructor(chain_inputs)
        workspace.CreateNet(self.net)
        if debug:
            print(self.net._net)

    def forward(self, niters):
        """Run the created net ``niters`` times via ``workspace.RunNet``."""
        workspace.RunNet(self.net, niters, False)