mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
The list of added operators: add_, addcmul, arange, baddbmm…, bmm, clamp, div, div_, gelu, index_add, logical_and, mul_, sub_, topk, where This pull request is the same as a previous one: https://github.com/pytorch/pytorch/pull/145121 which inadvertently got deleted while merging. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145625 Approved by: https://github.com/jeffdaily
175 lines
4.1 KiB
Python
175 lines
4.1 KiB
Python
import operator_benchmark as op_bench
|
|
|
|
import torch
|
|
|
|
|
|
"""Microbenchmarks for point-wise unary operator."""


# Configs for pointwise unary ops.
# Short run: one fixed 512 x 512 shape, crossed with the two device names.
unary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[[512, 512]],
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["short"],
)

# Long run: M and N each drawn from {256, 1024}, crossed with the device names.
# NOTE: keep the keyword order M, N, device as-is; the helper presumably
# derives the attribute order of the generated configs from it.
unary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    device=["cpu", "cuda"],
    tags=["long"],
)
|
|
|
|
|
|
class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a single point-wise unary op applied to one random (M, N) tensor.

    The op itself arrives as ``op_func``. It is presumably filled in from the
    ``op_func`` column of the op list this class is registered with, so one
    class serves every operator.
    """

    def init(self, M, N, device, op_func):
        # Keep these parameter names: config values are presumably passed in
        # by keyword (M, N, device, op_func) — confirm against op_bench.
        self.op_func = op_func
        # torch.rand samples uniformly from [0, 1). The "input" key must match
        # forward()'s parameter name (the dict is presumably unpacked into it).
        self.inputs = {"input": torch.rand(M, N, device=device)}

    def forward(self, input):
        # NOTE(review): in-place ops (acos_, exp_, log_, ...) overwrite the
        # stored input, so later iterations may run on transformed/NaN/inf data.
        op = self.op_func
        return op(input)
|
|
|
|
|
|
def bernoulli_(input):
    """Overwrite ``input`` with Bernoulli samples in place and return it.

    No probability is passed, so the library default applies.
    """
    return torch.Tensor.bernoulli_(input)
|
|
|
|
|
|
def cauchy_(input):
    """Overwrite ``input`` with Cauchy-distributed samples in place and return it.

    Uses the library's default location/scale since no arguments are given.
    """
    return torch.Tensor.cauchy_(input)
|
|
|
|
|
|
def digamma_(input):
    """Replace every element of ``input`` with its digamma value, in place.

    Returns the same (mutated) tensor.
    """
    return torch.Tensor.digamma_(input)
|
|
|
|
|
|
def exponential_(input):
    """Overwrite ``input`` with exponential-distribution samples in place.

    Uses the library's default rate and returns the same tensor.
    """
    return torch.Tensor.exponential_(input)
|
|
|
|
|
|
def normal_(input):
    """Overwrite ``input`` with normal-distribution samples in place.

    Uses the library's default mean/std and returns the same tensor.
    """
    return torch.Tensor.normal_(input)
|
|
|
|
|
|
def random_(input):
    """Overwrite ``input`` in place with random integer-valued samples.

    No bounds are given, so the default range for the tensor's dtype is used.
    Returns the same tensor.
    """
    return torch.Tensor.random_(input)
|
|
|
|
|
|
def sign_(input):
    """Replace each element of ``input`` with its sign (-1, 0 or 1), in place.

    Returns the same (mutated) tensor.
    """
    return torch.Tensor.sign_(input)
|
|
|
|
|
|
def uniform_(input):
    """Overwrite ``input`` with uniform samples in place.

    Uses the library's default bounds and returns the same tensor.
    """
    return torch.Tensor.uniform_(input)
|
|
|
|
|
|
def half_(input):
    """Return ``input`` converted to float16.

    Despite the trailing underscore this is NOT in-place. A new tensor is
    produced unless ``input`` is already float16.
    """
    return input.to(torch.float16)
|
|
|
|
|
|
def long_(input):
    """Return ``input`` converted to int64.

    Float values are truncated toward zero. Despite the trailing underscore
    this is NOT in-place.
    """
    return input.to(torch.int64)
|
|
|
|
|
|
def clamp(input):
    """Return a copy of ``input`` with values limited to [0.25, 0.75].

    Fixed bounds sit inside torch.rand's [0, 1) range, so the benchmark
    exercises both the lower and the upper clip.
    """
    lower, upper = 0.25, 0.75
    return torch.clamp(input, min=lower, max=upper)
|
|
|
|
|
|
# Every (op_name, op_func) pair to benchmark. Order is preserved exactly,
# since it presumably determines the order in which tests are generated.
unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        # Torch functions and their in-place variants.
        ("abs", torch.abs),
        ("abs_", torch.abs_),
        ("acos", torch.acos),
        ("acos_", torch.acos_),
        ("argsort", torch.argsort),
        ("asin", torch.asin),
        ("asin_", torch.asin_),
        ("atan", torch.atan),
        ("atan_", torch.atan_),
        ("ceil", torch.ceil),
        ("ceil_", torch.ceil_),
        ("clamp", clamp),
        ("clone", torch.clone),
        ("cos", torch.cos),
        ("cos_", torch.cos_),
        ("cosh", torch.cosh),
        ("digamma", torch.digamma),
        ("erf", torch.erf),
        ("erf_", torch.erf_),
        ("erfc", torch.erfc),
        ("erfc_", torch.erfc_),
        ("erfinv", torch.erfinv),
        ("exp", torch.exp),
        ("exp_", torch.exp_),
        ("expm1", torch.expm1),
        ("expm1_", torch.expm1_),
        ("floor", torch.floor),
        ("floor_", torch.floor_),
        ("frac", torch.frac),
        ("frac_", torch.frac_),
        ("gelu", torch.nn.functional.gelu),
        ("hardshrink", torch.hardshrink),
        ("lgamma", torch.lgamma),
        ("log", torch.log),
        ("log10", torch.log10),
        ("log10_", torch.log10_),
        ("log1p", torch.log1p),
        ("log1p_", torch.log1p_),
        ("log2", torch.log2),
        ("log2_", torch.log2_),
        ("log_", torch.log_),
        ("logit", torch.logit),
        ("logit_", torch.logit_),
        ("neg", torch.neg),
        ("neg_", torch.neg_),
        ("reciprocal", torch.reciprocal),
        ("reciprocal_", torch.reciprocal_),
        ("relu", torch.relu),
        ("relu_", torch.relu_),
        ("round", torch.round),
        ("round_", torch.round_),
        ("rsqrt", torch.rsqrt),
        ("rsqrt_", torch.rsqrt_),
        ("sigmoid", torch.sigmoid),
        ("sigmoid_", torch.sigmoid_),
        ("sign", torch.sign),
        ("sgn", torch.sgn),
        ("sin", torch.sin),
        ("sin_", torch.sin_),
        ("sinh", torch.sinh),
        ("sqrt", torch.sqrt),
        ("sqrt_", torch.sqrt_),
        ("square", torch.square),
        ("square_", torch.square_),
        ("tan", torch.tan),
        ("tan_", torch.tan_),
        ("tanh", torch.tanh),
        ("tanh_", torch.tanh_),
        ("trunc", torch.trunc),
        ("trunc_", torch.trunc_),
        # NOTE(review): private helper from torch.functional; it may be
        # renamed or removed in a future torch release.
        ("unique", torch.functional._return_output),
        ("zero_", torch.zero_),
        # Tensor-method wrappers defined in this module.
        ("bernoulli_", bernoulli_),
        ("cauchy_", cauchy_),
        ("digamma_", digamma_),
        ("exponential_", exponential_),
        ("normal_", normal_),
        ("random_", random_),
        ("sign_", sign_),
        ("uniform_", uniform_),
        # dtype conversions (not in-place despite the wrapper names).
        ("half", half_),
        ("long", long_),
    ],
)
|
|
|
|
|
|
# Presumably registers one benchmark per (op, config) pair, drawing configs
# from the short and long sets concatenated together.
op_bench.generate_pt_tests_from_op_list(
    unary_ops_list,
    unary_ops_configs_short + unary_ops_configs_long,
    UnaryOpBenchmark,
)


if __name__ == "__main__":
    # Hand control to the operator_benchmark command-line runner.
    op_bench.benchmark_runner.main()
|