From cb1596a19355d4a977882144b0aaf35437e48846 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 2 Mar 2021 11:48:57 -0800 Subject: [PATCH] [operator_benchmark] Added channels last 3d option to interpolate test (#53117) Summary: Description: - Added channels last 3d option to interpolate test - split config non-4d into two : 3d and 5d Pull Request resolved: https://github.com/pytorch/pytorch/pull/53117 Reviewed By: NicolasHug Differential Revision: D26754243 Pulled By: fmassa fbshipit-source-id: 49bbab3bb47de27790e39537d0fbeca0f01782c4 --- .../operator_benchmark/pt/interpolate_test.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/benchmarks/operator_benchmark/pt/interpolate_test.py b/benchmarks/operator_benchmark/pt/interpolate_test.py index 694e5cd19ca..b33de76c01e 100644 --- a/benchmarks/operator_benchmark/pt/interpolate_test.py +++ b/benchmarks/operator_benchmark/pt/interpolate_test.py @@ -10,7 +10,14 @@ class InterpolateBenchmark(op_bench.TorchBenchmarkBase): input_image = torch.randint(0, 256, size=input_size, dtype=torch.float, device='cpu', requires_grad=self.auto_set()) if channels_last: - input_image = input_image.contiguous(memory_format=torch.channels_last) + if input_image.ndim == 4: + input_image = input_image.contiguous(memory_format=torch.channels_last) + elif input_image.ndim == 5: + input_image = input_image.contiguous(memory_format=torch.channels_last_3d) + else: + raise ValueError( + f"Can not set channels_last to the input of {input_image.ndim} dims" + ) ndim_to_mode = { 3: 'linear', @@ -61,13 +68,10 @@ config_long = op_bench.config_list( ) -config_not_4d = op_bench.config_list( - # no channels_last as it's only valid for 4D tensors +config_3d = op_bench.config_list( + # no channels_last for 3D tensors attr_names=["input_size", "output_size"], attrs=[ - [(1, 3, 16, 320, 320), (8, 256, 256)], - [(1, 3, 16, 320, 320), (32, 512, 512)], - [(4, 512, 320), (256,)], [(4, 512, 320), (512,)], ], @@ -75,7 +79,20 @@ config_not_4d = op_bench.config_list( ) -for config in (config_short, config_long, config_not_4d): +config_5d = op_bench.config_list( + attr_names=["input_size", "output_size"], + attrs=[ + [(1, 3, 16, 320, 320), (8, 256, 256)], + [(1, 3, 16, 320, 320), (32, 512, 512)], + ], + cross_product_configs={ + 'channels_last': [True, False], + }, + tags=["long"], +) + + +for config in (config_short, config_long, config_3d, config_5d): op_bench.generate_pt_test(config, InterpolateBenchmark)