mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/105928 Approved by: https://github.com/albanD
67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
import torch
|
|
|
|
import operator_benchmark as op_bench
|
|
|
|
"""Microbenchmarks for the quantized interpolate op.
|
|
|
|
Note: We are not benchmarking `upsample` as it is being depricated, and calls
|
|
the `interpolate` anyway.
|
|
"""
|
|
|
|
qinterpolate_long_configs = op_bench.config_list(
|
|
attr_names=["M", "N", "K"],
|
|
attrs=[
|
|
[512, 512, 512],
|
|
],
|
|
cross_product_configs={
|
|
"dtype": [torch.quint8, torch.qint8, torch.qint32],
|
|
"mode": ["nearest", "bilinear"],
|
|
"scale": [0.5, 1.0, 2.0],
|
|
"contig": [True], # TODO: Add `False` after #29435
|
|
},
|
|
tags=["long"],
|
|
)
|
|
|
|
|
|
qinterpolate_short_configs = op_bench.config_list(
|
|
attr_names=["M", "N", "K", "dtype", "mode", "scale", "contig"],
|
|
attrs=[
|
|
[32, 32, 32, torch.quint8, "nearest", 0.5, True], # Downsample
|
|
[32, 32, 32, torch.quint8, "bilinear", 0.5, True], # Downsample
|
|
[32, 32, 32, torch.quint8, "nearest", 2.0, True], # Upsample
|
|
[32, 32, 32, torch.quint8, "bilinear", 2.0, True], # Upsample
|
|
[3, 720, 1280, torch.quint8, "bilinear", 0.83333, True], # Downsample
|
|
],
|
|
tags=["short"],
|
|
)
|
|
|
|
|
|
class QInterpolateBenchmark(op_bench.TorchBenchmarkBase):
|
|
def init(self, M, N, K, dtype, mode, scale, contig):
|
|
f_input = (torch.rand(1, M, N, K) - 0.5) * 256
|
|
scale = 0.1
|
|
zero_point = 42
|
|
self.q_input = torch.quantize_per_tensor(
|
|
f_input, scale=scale, zero_point=zero_point, dtype=dtype
|
|
)
|
|
if not contig:
|
|
permute_dims = list(range(self.q_input.ndim))[::-1]
|
|
self.q_input = self.q_input.permute(permute_dims)
|
|
|
|
self.inputs = {"q_input": self.q_input, "scale_factor": scale, "mode": mode}
|
|
self.set_module_name("q_interpolate")
|
|
|
|
def forward(self, q_input, scale_factor: float, mode: str):
|
|
return torch.nn.functional.interpolate(
|
|
q_input, scale_factor=scale_factor, mode=mode
|
|
)
|
|
|
|
|
|
op_bench.generate_pt_test(
|
|
qinterpolate_short_configs + qinterpolate_long_configs, QInterpolateBenchmark
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
op_bench.benchmark_runner.main()
|