mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25619

Test Plan:
```
[huaminli@devvm2388.ftw3 ~/fbsource/fbcode] buck run mode/dev-nosan caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --operators None --iterations 3
```
Last few lines of output: P108286305

Reviewed By: mingzhe09088

Differential Revision: D17175802

fbshipit-source-id: 46b69fc1895444b15b6dfcec0625b6b9b006712a
37 lines
810 B
Python
37 lines
810 B
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import operator_benchmark as op_bench
|
|
import torch
|
|
|
|
|
|
"""Microbenchmarks for Chunk operator"""
|
|
|
|
|
|
# Short-tagged configurations for the PyTorch chunk benchmark:
# M is 256 or 512, N is 512, and the tensor is split into 2 chunks.
# NOTE(review): presumably expanded into one config per combination of the
# listed values -- confirm against op_bench.cross_product_configs.
chunks_short_configs = op_bench.cross_product_configs(
    M=[256, 512], N=[512], chunks=[2], tags=['short'],
)
|
|
|
|
|
|
class ChunkBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that runs ``torch.chunk`` on a random 2-D tensor."""

    def init(self, M, N, chunks):
        """Build the inputs for one benchmark configuration.

        Args:
            M: size of the first dimension of the input tensor.
            N: size of the second dimension of the input tensor.
            chunks: chunk count forwarded to ``torch.chunk``.
        """
        self.chunks = chunks
        # Input is created once here so that forward() only has to run the
        # operator itself.
        self.input_one = torch.rand(M, N)
        self.set_module_name('chunks')

    def forward(self):
        """Split the stored tensor with ``torch.chunk`` and return the result."""
        # NOTE(review): presumably the method the benchmark harness measures;
        # it is deliberately a single operator call -- confirm in op_bench.
        return torch.chunk(self.input_one, self.chunks)
|
|
|
|
|
|
# Pair the short configuration set with the chunk benchmark class so the
# operator_benchmark framework can generate the corresponding tests.
op_bench.generate_pt_test(chunks_short_configs, ChunkBenchmark)


# Running this file as a script hands control to the benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
|