mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Check canary_models for models too in torchbench.py (#101081)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/101081 Approved by: https://github.com/desertfire
This commit is contained in:
parent
4eaaa08623
commit
ad070b6dfa
|
|
@ -270,10 +270,17 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
||||||
is_training = self.args.training
|
is_training = self.args.training
|
||||||
use_eval_mode = self.args.use_eval_mode
|
use_eval_mode = self.args.use_eval_mode
|
||||||
dynamic_shapes = self.args.dynamic_shapes
|
dynamic_shapes = self.args.dynamic_shapes
|
||||||
|
candidates = [
|
||||||
|
f"torchbenchmark.models.{model_name}",
|
||||||
|
f"torchbenchmark.canary_models.{model_name}",
|
||||||
|
f"torchbenchmark.models.fb.{model_name}",
|
||||||
|
]
|
||||||
|
for c in candidates:
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(f"torchbenchmark.models.{model_name}")
|
module = importlib.import_module(c)
|
||||||
|
break
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}")
|
pass
|
||||||
benchmark_cls = getattr(module, "Model", None)
|
benchmark_cls = getattr(module, "Model", None)
|
||||||
if not hasattr(benchmark_cls, "name"):
|
if not hasattr(benchmark_cls, "name"):
|
||||||
benchmark_cls.name = model_name
|
benchmark_cls.name = model_name
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user