pytorch/benchmarks/operator_benchmark/pt/unary_test.py

import operator_benchmark as op_bench

import torch

"""Microbenchmarks for pointwise unary operators."""

# Configs for pointwise unary ops
unary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[
        [512, 512],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

unary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024], N=[256, 1024], device=["cpu", "cuda"], tags=["long"]
)
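
# The short config runs a single 512x512 input on both CPU and CUDA; the long
# config sweeps the full cross product of M, N, and device.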


class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device, op_func):
        self.inputs = {"input": torch.rand(M, N, device=device)}
        self.op_func = op_func

    def forward(self, input):
        return self.op_func(input)
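

# Wrappers that call Tensor methods directly (in-place random fills and dtype
# casts), so these ops can be benchmarked through the same op_func interface.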
def bernoulli_(input):
    return input.bernoulli_()


def cauchy_(input):
    return input.cauchy_()


def digamma_(input):
    return input.digamma_()


def exponential_(input):
    return input.exponential_()


def normal_(input):
    return input.normal_()


def random_(input):
    return input.random_()


def sign_(input):
    return input.sign_()


def uniform_(input):
    return input.uniform_()


def half_(input):
    return input.half()


def long_(input):
    return input.long()
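

# clamp takes extra scalar arguments, so pin min/max here to fit the unary
# (single-tensor-input) interface used by this benchmark.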
def clamp(input):
    return torch.clamp(input, min=0.25, max=0.75)


unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["abs", torch.abs],
        ["abs_", torch.abs_],
        ["acos", torch.acos],
        ["acos_", torch.acos_],
        ["argsort", torch.argsort],
        ["asin", torch.asin],
        ["asin_", torch.asin_],
        ["atan", torch.atan],
        ["atan_", torch.atan_],
        ["ceil", torch.ceil],
        ["ceil_", torch.ceil_],
        ["clamp", clamp],
        ["clone", torch.clone],
        ["cos", torch.cos],
        ["cos_", torch.cos_],
        ["cosh", torch.cosh],
        ["digamma", torch.digamma],
        ["erf", torch.erf],
        ["erf_", torch.erf_],
        ["erfc", torch.erfc],
        ["erfc_", torch.erfc_],
        ["erfinv", torch.erfinv],
        ["exp", torch.exp],
        ["exp_", torch.exp_],
        ["expm1", torch.expm1],
        ["expm1_", torch.expm1_],
        ["floor", torch.floor],
        ["floor_", torch.floor_],
        ["frac", torch.frac],
        ["frac_", torch.frac_],
        ["gelu", torch.nn.functional.gelu],
        ["hardshrink", torch.hardshrink],
        ["lgamma", torch.lgamma],
        ["log", torch.log],
        ["log10", torch.log10],
        ["log10_", torch.log10_],
        ["log1p", torch.log1p],
        ["log1p_", torch.log1p_],
        ["log2", torch.log2],
        ["log2_", torch.log2_],
        ["log_", torch.log_],
        ["logit", torch.logit],
        ["logit_", torch.logit_],
        ["neg", torch.neg],
        ["neg_", torch.neg_],
        ["reciprocal", torch.reciprocal],
        ["reciprocal_", torch.reciprocal_],
        ["relu", torch.relu],
        ["relu_", torch.relu_],
        ["round", torch.round],
        ["round_", torch.round_],
        ["rsqrt", torch.rsqrt],
        ["rsqrt_", torch.rsqrt_],
        ["sigmoid", torch.sigmoid],
        ["sigmoid_", torch.sigmoid_],
        ["sign", torch.sign],
        ["sgn", torch.sgn],
        ["sin", torch.sin],
        ["sin_", torch.sin_],
        ["sinh", torch.sinh],
        ["sqrt", torch.sqrt],
        ["sqrt_", torch.sqrt_],
        ["square", torch.square],
        ["square_", torch.square_],
        ["tan", torch.tan],
        ["tan_", torch.tan_],
        ["tanh", torch.tanh],
        ["tanh_", torch.tanh_],
        ["trunc", torch.trunc],
        ["trunc_", torch.trunc_],
["unique", torch.functional._return_output],
["zero_", torch.zero_],
["bernoulli_", bernoulli_],
["cauchy_", cauchy_],
["digamma_", digamma_],
["exponential_", exponential_],
["normal_", normal_],
["random_", random_],
["sign_", sign_],
["uniform_", uniform_],
["half", half_],
["long", long_],
],
)


op_bench.generate_pt_tests_from_op_list(
    unary_ops_list, unary_ops_configs_short + unary_ops_configs_long, UnaryOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
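
# A sketch of how this is typically invoked (from benchmarks/operator_benchmark),
# following the operator_benchmark README conventions; the exact flag names are
# an assumption here and are best confirmed via --help:
#   python -m pt.unary_test --tag_filter short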