Check canary_models for models too in torchbench.py (#101081)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101081
Approved by: https://github.com/desertfire
This commit is contained in:
Edward Z. Yang 2023-05-11 00:34:16 +00:00 committed by PyTorch MergeBot
parent 4eaaa08623
commit ad070b6dfa

View File

@ -270,10 +270,17 @@ class TorchBenchmarkRunner(BenchmarkRunner):
is_training = self.args.training
use_eval_mode = self.args.use_eval_mode
dynamic_shapes = self.args.dynamic_shapes
try:
module = importlib.import_module(f"torchbenchmark.models.{model_name}")
except ModuleNotFoundError:
module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}")
candidates = [
f"torchbenchmark.models.{model_name}",
f"torchbenchmark.canary_models.{model_name}",
f"torchbenchmark.models.fb.{model_name}",
]
for c in candidates:
try:
module = importlib.import_module(c)
break
except ModuleNotFoundError:
pass
benchmark_cls = getattr(module, "Model", None)
if not hasattr(benchmark_cls, "name"):
benchmark_cls.name = model_name