"""Microbenchmarks for topk operator"""

import operator_benchmark as op_bench

import torch


# "short" configurations: a small, explicit list of (shape, k, dim) cases,
# restricted to CPU and float32.
topk_configs_short = op_bench.config_list(
    attr_names=["shape", "k", "dim"],
    attrs=[
        [(16, 4), 4, 1],
        [(1024 * 1024,), 16, 0],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

# "long" configurations: a sweep over shape, k, device and dtype.
# Every shape has at least 32 elements along dim 0, so the largest k listed
# here never exceeds the size of the reduced dimension.
topk_configs_long = op_bench.cross_product_configs(
    shape=[(64, 2), (1024 * 1024,), (128,)],
    k=[1, 2, 4, 16, 32],
    dim=[0],
    device=["cpu", "cuda"],
    dtype=[torch.float, torch.bfloat16],
    tags=["long"],
)


class TopkBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that runs ``torch.topk`` on a random input tensor."""

    def init(self, shape, k, dim, dtype, device):
        """Build the inputs for one benchmark configuration.

        Args:
            shape: Shape of the random input tensor.
            k: Number of top elements requested from ``torch.topk``.
            dim: Dimension along which ``torch.topk`` operates.
            dtype: Data type of the input tensor.
            device: Device on which the input tensor is allocated.
        """
        # The keys here match the parameter names of ``forward`` below.
        self.inputs = {
            "input": torch.randn(shape, device=device, dtype=dtype),
            "k": k,
            "dim": dim,
        }
        self.set_module_name("topk")

    def forward(self, input, k, dim):
        """Run the operator under test and return the ``torch.topk`` result."""
        return torch.topk(input, k=k, dim=dim)


op_bench.generate_pt_test(topk_configs_short + topk_configs_long, TopkBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()