mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchbench] fix dynamic_shapes spec for moco (#148772)
Fixes https://github.com/pytorch/pytorch/issues/148333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148772 Approved by: https://github.com/yushangdi, https://github.com/desertfire
This commit is contained in:
parent
dbea13ed45
commit
e0e8639a10
|
|
@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
|
|||
def load(cls, model, example_inputs):
|
||||
import torch._inductor
|
||||
import torch.export._trace
|
||||
from torch.export.dynamic_shapes import _tree_map_with_path
|
||||
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
|
||||
|
||||
key = weakref.ref(model)
|
||||
if key not in cls.cache:
|
||||
|
|
@ -1428,7 +1428,7 @@ class AOTInductorModelCache:
|
|||
else:
|
||||
_register_dataclass_output_as_pytree(example_outputs)
|
||||
|
||||
combined_args = tuple(example_args) + tuple(example_kwargs.values())
|
||||
combined_args = _combine_args(model, example_args, example_kwargs)
|
||||
dynamic_shapes = _tree_map_with_path(
|
||||
_produce_dynamic_shapes_for_export, combined_args
|
||||
)
|
||||
|
|
@ -1449,13 +1449,13 @@ class AOTInductorModelCache:
|
|||
|
||||
|
||||
def export(model, example_inputs):
|
||||
from torch.export.dynamic_shapes import _tree_map_with_path
|
||||
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
|
||||
|
||||
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
|
||||
example_outputs = model(*example_args, **example_kwargs)
|
||||
_register_dataclass_output_as_pytree(example_outputs)
|
||||
|
||||
combined_args = tuple(example_args) + tuple(example_kwargs.values())
|
||||
combined_args = _combine_args(model, example_args, example_kwargs)
|
||||
dynamic_shapes = _tree_map_with_path(
|
||||
_produce_dynamic_shapes_for_export, combined_args
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user