mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchbench] fix dynamic_shapes spec for moco (#148772)
Fixes https://github.com/pytorch/pytorch/issues/148333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148772 Approved by: https://github.com/yushangdi, https://github.com/desertfire
This commit is contained in:
parent
dbea13ed45
commit
e0e8639a10
|
|
@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
|
||||||
def load(cls, model, example_inputs):
|
def load(cls, model, example_inputs):
|
||||||
import torch._inductor
|
import torch._inductor
|
||||||
import torch.export._trace
|
import torch.export._trace
|
||||||
from torch.export.dynamic_shapes import _tree_map_with_path
|
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
|
||||||
|
|
||||||
key = weakref.ref(model)
|
key = weakref.ref(model)
|
||||||
if key not in cls.cache:
|
if key not in cls.cache:
|
||||||
|
|
@ -1428,7 +1428,7 @@ class AOTInductorModelCache:
|
||||||
else:
|
else:
|
||||||
_register_dataclass_output_as_pytree(example_outputs)
|
_register_dataclass_output_as_pytree(example_outputs)
|
||||||
|
|
||||||
combined_args = tuple(example_args) + tuple(example_kwargs.values())
|
combined_args = _combine_args(model, example_args, example_kwargs)
|
||||||
dynamic_shapes = _tree_map_with_path(
|
dynamic_shapes = _tree_map_with_path(
|
||||||
_produce_dynamic_shapes_for_export, combined_args
|
_produce_dynamic_shapes_for_export, combined_args
|
||||||
)
|
)
|
||||||
|
|
@ -1449,13 +1449,13 @@ class AOTInductorModelCache:
|
||||||
|
|
||||||
|
|
||||||
def export(model, example_inputs):
|
def export(model, example_inputs):
|
||||||
from torch.export.dynamic_shapes import _tree_map_with_path
|
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
|
||||||
|
|
||||||
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
|
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
|
||||||
example_outputs = model(*example_args, **example_kwargs)
|
example_outputs = model(*example_args, **example_kwargs)
|
||||||
_register_dataclass_output_as_pytree(example_outputs)
|
_register_dataclass_output_as_pytree(example_outputs)
|
||||||
|
|
||||||
combined_args = tuple(example_args) + tuple(example_kwargs.values())
|
combined_args = _combine_args(model, example_args, example_kwargs)
|
||||||
dynamic_shapes = _tree_map_with_path(
|
dynamic_shapes = _tree_map_with_path(
|
||||||
_produce_dynamic_shapes_for_export, combined_args
|
_produce_dynamic_shapes_for_export, combined_args
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user