mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[operator_benchmark] Added channels last 3d option to interpolate test (#53117)
Summary: Description: - Added channels last 3d option to interpolate test - split config non-4d into two : 3d and 5d Pull Request resolved: https://github.com/pytorch/pytorch/pull/53117 Reviewed By: NicolasHug Differential Revision: D26754243 Pulled By: fmassa fbshipit-source-id: 49bbab3bb47de27790e39537d0fbeca0f01782c4
This commit is contained in:
parent
62d1cdd725
commit
cb1596a193
|
|
@ -10,7 +10,14 @@ class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
|
|||
input_image = torch.randint(0, 256, size=input_size, dtype=torch.float, device='cpu',
|
||||
requires_grad=self.auto_set())
|
||||
if channels_last:
|
||||
input_image = input_image.contiguous(memory_format=torch.channels_last)
|
||||
if input_image.ndim == 4:
|
||||
input_image = input_image.contiguous(memory_format=torch.channels_last)
|
||||
elif input_image.ndim == 5:
|
||||
input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Can not set channels_last to the input of {input_image.ndim} dims"
|
||||
)
|
||||
|
||||
ndim_to_mode = {
|
||||
3: 'linear',
|
||||
|
|
@ -61,13 +68,10 @@ config_long = op_bench.config_list(
|
|||
)
|
||||
|
||||
|
||||
config_not_4d = op_bench.config_list(
|
||||
# no channels_last as it's only valid for 4D tensors
|
||||
config_3d = op_bench.config_list(
|
||||
# no channels_last for 3D tensors
|
||||
attr_names=["input_size", "output_size"],
|
||||
attrs=[
|
||||
[(1, 3, 16, 320, 320), (8, 256, 256)],
|
||||
[(1, 3, 16, 320, 320), (32, 512, 512)],
|
||||
|
||||
[(4, 512, 320), (256,)],
|
||||
[(4, 512, 320), (512,)],
|
||||
],
|
||||
|
|
@ -75,7 +79,20 @@ config_not_4d = op_bench.config_list(
|
|||
)
|
||||
|
||||
|
||||
for config in (config_short, config_long, config_not_4d):
|
||||
config_5d = op_bench.config_list(
|
||||
attr_names=["input_size", "output_size"],
|
||||
attrs=[
|
||||
[(1, 3, 16, 320, 320), (8, 256, 256)],
|
||||
[(1, 3, 16, 320, 320), (32, 512, 512)],
|
||||
],
|
||||
cross_product_configs={
|
||||
'channels_last': [True, False],
|
||||
},
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
||||
for config in (config_short, config_long, config_3d, config_5d):
|
||||
op_bench.generate_pt_test(config, InterpolateBenchmark)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user