mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Check canary_models for models too in torchbench.py (#101081)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/101081 Approved by: https://github.com/desertfire
This commit is contained in:
parent
4eaaa08623
commit
ad070b6dfa
|
|
@ -270,10 +270,17 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
is_training = self.args.training
|
||||
use_eval_mode = self.args.use_eval_mode
|
||||
dynamic_shapes = self.args.dynamic_shapes
|
||||
try:
|
||||
module = importlib.import_module(f"torchbenchmark.models.{model_name}")
|
||||
except ModuleNotFoundError:
|
||||
module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}")
|
||||
candidates = [
|
||||
f"torchbenchmark.models.{model_name}",
|
||||
f"torchbenchmark.canary_models.{model_name}",
|
||||
f"torchbenchmark.models.fb.{model_name}",
|
||||
]
|
||||
for c in candidates:
|
||||
try:
|
||||
module = importlib.import_module(c)
|
||||
break
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
benchmark_cls = getattr(module, "Model", None)
|
||||
if not hasattr(benchmark_cls, "name"):
|
||||
benchmark_cls.name = model_name
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user