mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19518 The previous design required running the op benchmarks from the PyTorch root directory, which could lead to a `module not found` error in the OSS environment. This diff fixes that issue by launching the benchmarks from the `benchmarks` folder. Reviewed By: ilia-cher Differential Revision: D15020787 fbshipit-source-id: eb09814a33432a66cc857702bc86538cd17bea3b
26 lines
796 B
Python
26 lines
796 B
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from operator_benchmark import benchmark_core, benchmark_utils
|
|
|
|
import torch
|
|
|
|
"""PyTorch performance microbenchmarks.
|
|
|
|
This module contains PyTorch-specific functionality for performance
|
|
microbenchmarks.
|
|
"""
|
|
|
|
|
|
def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
    """Register a PyTorch benchmark tester for one operator test case.

    Args:
        test_name: Name forwarded to ``benchmark_core.add_benchmark_tester``.
        op_type: Callable under test. It is invoked with one tensor per
            entry of ``input_shapes`` followed by the run count.
        input_shapes: Iterable of argument sequences; each entry is unpacked
            into ``benchmark_utils.numpy_random_fp32`` to produce one input.
        op_args: Forwarded unchanged to the benchmark core.
        run_mode: Forwarded unchanged to the benchmark core.

    Returns:
        None. The tester is registered with ``benchmark_core`` as a side
        effect.
    """
    # Build the input tensors once, up front, so the registered closure
    # reuses the same tensors on every invocation.
    tensors = []
    for shape in input_shapes:
        random_array = benchmark_utils.numpy_random_fp32(*shape)
        tensors.append(torch.from_numpy(random_array))

    def benchmark_func(num_runs):
        # The run count is passed as the final positional argument.
        call_args = tensors + [num_runs]
        op_type(*call_args)

    benchmark_core.add_benchmark_tester(
        "PyTorch",
        test_name,
        input_shapes,
        op_args,
        run_mode,
        benchmark_func,
    )
|