pytorch/benchmarks/operator_benchmark/benchmark_pytorch.py
Mingzhe Li 26f12af537 Fix op benchmarks error in OSS environment (#19518)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19518

The previous design required running the op benchmarks from the PyTorch root directory, which could lead to a `module not found` error in the OSS environment. This diff fixes that issue by launching the benchmarks from the `benchmarks` folder instead.

Reviewed By: ilia-cher

Differential Revision: D15020787

fbshipit-source-id: eb09814a33432a66cc857702bc86538cd17bea3b
2019-04-19 16:25:16 -07:00
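The fix relies on ordinary Python module resolution. The module below imports `operator_benchmark` as a top-level package, and that package lives under `benchmarks/`. So `benchmarks`, not the repository root, has to be on `sys.path`. Launching with `python -m` is one way this happens, because `-m` puts the current directory at the front of `sys.path`. The standard-library sketch below illustrates the mechanism and is not the benchmark's actual launcher. It builds a throwaway tree containing a hypothetical package `toy_operator_benchmark`, then checks whether that package can be found when each directory is placed first on `sys.path`.

```python
import importlib
import importlib.util
import os
import sys
import tempfile

# Hypothetical package name, chosen so it cannot collide with a real install.
PKG = "toy_operator_benchmark"


def is_importable_from(search_dir, module_name):
    """Return True if `module_name` is found with `search_dir` first on sys.path."""
    saved_path = list(sys.path)
    try:
        # Mimic a launch whose working directory is `search_dir`.
        sys.path.insert(0, search_dir)
        importlib.invalidate_caches()
        return importlib.util.find_spec(module_name) is not None
    finally:
        # Restore the interpreter's original search path.
        sys.path[:] = saved_path
        importlib.invalidate_caches()


def make_layout():
    """Create <root>/benchmarks/<PKG>/__init__.py and return (root, benchmarks)."""
    root = tempfile.mkdtemp()
    benchmarks = os.path.join(root, "benchmarks")
    os.makedirs(os.path.join(benchmarks, PKG))
    open(os.path.join(benchmarks, PKG, "__init__.py"), "w").close()
    return root, benchmarks


root, benchmarks = make_layout()
print(is_importable_from(root, PKG))        # False: package is one level down
print(is_importable_from(benchmarks, PKG))  # True: package is a direct child
```

From the root, the package sits one directory too deep to be found as a top-level name. From `benchmarks`, it is a direct child and resolves.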


"""PyTorch performance microbenchmarks.

This module contains PyTorch-specific functionality for performance
microbenchmarks.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from operator_benchmark import benchmark_core, benchmark_utils
import torch

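def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
    """Benchmark tester function for the PyTorch framework.

    Creates a random fp32 tensor for each shape in `input_shapes` and
    registers a benchmark function that calls `op_type` on those tensors.
    """
    # Build the inputs once, outside the function that gets timed.
    inputs = [torch.from_numpy(benchmark_utils.numpy_random_fp32(*shape))
              for shape in input_shapes]

    def benchmark_func(num_runs):
        # The operator receives the input tensors followed by the run count.
        op_type(*(inputs + [num_runs]))

    benchmark_core.add_benchmark_tester("PyTorch", test_name, input_shapes, op_args, run_mode, benchmark_func)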