[torchbench] fix dynamic_shapes spec for moco (#148772)

Fixes https://github.com/pytorch/pytorch/issues/148333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148772
Approved by: https://github.com/yushangdi, https://github.com/desertfire
This commit is contained in:
Pian Pawakapan 2025-03-18 18:16:51 +00:00 committed by PyTorch MergeBot
parent dbea13ed45
commit e0e8639a10

View File

@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
def load(cls, model, example_inputs): def load(cls, model, example_inputs):
import torch._inductor import torch._inductor
import torch.export._trace import torch.export._trace
from torch.export.dynamic_shapes import _tree_map_with_path from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
key = weakref.ref(model) key = weakref.ref(model)
if key not in cls.cache: if key not in cls.cache:
@ -1428,7 +1428,7 @@ class AOTInductorModelCache:
else: else:
_register_dataclass_output_as_pytree(example_outputs) _register_dataclass_output_as_pytree(example_outputs)
combined_args = tuple(example_args) + tuple(example_kwargs.values()) combined_args = _combine_args(model, example_args, example_kwargs)
dynamic_shapes = _tree_map_with_path( dynamic_shapes = _tree_map_with_path(
_produce_dynamic_shapes_for_export, combined_args _produce_dynamic_shapes_for_export, combined_args
) )
@ -1449,13 +1449,13 @@ class AOTInductorModelCache:
def export(model, example_inputs): def export(model, example_inputs):
from torch.export.dynamic_shapes import _tree_map_with_path from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
example_args, example_kwargs = _normalize_bench_inputs(example_inputs) example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
example_outputs = model(*example_args, **example_kwargs) example_outputs = model(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs) _register_dataclass_output_as_pytree(example_outputs)
combined_args = tuple(example_args) + tuple(example_kwargs.values()) combined_args = _combine_args(model, example_args, example_kwargs)
dynamic_shapes = _tree_map_with_path( dynamic_shapes = _tree_map_with_path(
_produce_dynamic_shapes_for_export, combined_args _produce_dynamic_shapes_for_export, combined_args
) )