mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19518 Previous design needs to run the op benchmarks from PyTorch root directory which could lead to `module not found` error in OSS environment. This diff fixes that issue by making the benchmark to be launched in the `benchmarks` folder. Reviewed By: ilia-cher Differential Revision: D15020787 fbshipit-source-id: eb09814a33432a66cc857702bc86538cd17bea3b
61 lines
1.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from operator_benchmark import benchmark_core, benchmark_runner
|
|
from operator_benchmark.benchmark_test_generator import *
|
|
|
|
import torch
|
|
|
|
|
|
"""Microbenchmarks for MatMul operator. Supports both Caffe2/PyTorch."""
|
|
# Long config: randomized matrix sizes for the exhaustive benchmark sweep.
# Each dimension gets its own independent draw of 2 random sizes in [1, 128];
# cross_product enumerates every combination (shapes x transpose flags).
_long_M = get_n_rand_nums(min_val=1, max_val=128, n=2)
_long_N = get_n_rand_nums(min_val=1, max_val=128, n=2)
_long_K = get_n_rand_nums(min_val=1, max_val=128, n=2)

long_config = generate_configs(
    M=_long_M,
    N=_long_N,
    K=_long_K,
    trans_a=[False, True],
    trans_b=[True, False],
    mode=['long'],
    sample_func=cross_product
)
|
|
|
|
# Short config: small fixed shapes for quick smoke-test runs; cross_product
# still enumerates every shape / transpose-flag combination.
_short_M = [8, 16]
_short_N = [32, 64]
_short_K = [64, 128]

short_config = generate_configs(
    M=_short_M,
    N=_short_N,
    K=_short_K,
    trans_a=[False, True],
    trans_b=[True, False],
    mode=['short'],
    sample_func=cross_product
)
|
|
|
|
|
|
@torch.jit.script
def torch_matmul(a, b, iterations):
    # type: (Tensor, Tensor, int) -> Tensor
    # Benchmark kernel: run torch.matmul(a, b) `iterations` times and return
    # the last product so the JIT cannot dead-code-eliminate the loop body.
    #
    # Fixes vs. the original: the type comment omitted the `-> Tensor` return
    # type, and `result` was seeded with torch.jit.annotate(torch.Tensor, None),
    # which assigns None to a Tensor-typed value — rejected by the TorchScript
    # compiler. Computing the first product unconditionally and looping
    # `iterations - 1` more times keeps behavior identical for iterations >= 1
    # (the only case the harness uses) without the None-typed initializer.
    result = torch.matmul(a, b)
    for _ in range(iterations - 1):
        result = torch.matmul(a, b)
    return result
|
|
|
|
|
|
@benchmark_core.register_test
def test_matmul():
    """Register MatMul benchmarks for both backends over every config.

    Hooks the PyTorch variant (the scripted `torch_matmul` kernel) and the
    Caffe2 variant (the `MatMul` operator) into the shared benchmark core.
    """
    configs = [long_config, short_config]
    generate_pt_test(
        configs,
        map_pt_config_matmul,
        [('matmul', torch_matmul)]
    )
    generate_c2_test(
        configs,
        map_c2_config_matmul,
        [('matmul', 'MatMul')],
    )
|
|
|
|
|
|
# Script entry point: hand control to the shared benchmark runner, which
# parses CLI flags and executes every benchmark registered via register_test.
if __name__ == "__main__":
    benchmark_runner.main()