pytorch/benchmarks/dynamo/benchmarks.py
Aaron Gokaslan 6d725e7d66 [BE]: enable ruff rules PLR1722 and PLW3301 (#109461)
Enables two ruff rules derived from pylint:
* PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in REPL contexts, as it may not always be imported by default. This always uses the version in the sys module, which is better.
* PLW3301 replaces nested min / max calls with simplified versions (i.e. `min(a, min(b, c))` => `min(a, b, c)`). The new version is more idiomatic and more efficient.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461
Approved by: https://github.com/ezyang
2023-09-18 02:07:21 +00:00

104 lines
2.9 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import os
import sys
from typing import Set
# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
    """Collect the model names listed in ``filename``.

    Trailing whitespace is stripped from every line. The line is then split
    on single spaces; if that yields only one field, it is split on commas
    instead. The first field of the result is taken as the model name.

    Lines whose first field is empty (blank or whitespace-only lines, or
    lines that start with the separator) are skipped, so the empty string
    never ends up in the returned set.

    Args:
        filename: Path of the text file to read.

    Returns:
        The set of distinct, non-empty model names found in the file.

    Raises:
        OSError: If ``filename`` cannot be opened for reading.
    """
    names: Set[str] = set()
    with open(filename) as fh:
        # Stream the file line by line; there is no need to hold it all in memory.
        for raw_line in fh:
            line = raw_line.rstrip()
            line_parts = line.split(" ")
            if len(line_parts) == 1:
                # No space-separated fields: fall back to comma-separated.
                line_parts = line.split(",")
            model_name = line_parts[0]
            # An empty name carries no information and would be shared by
            # every list file that happens to contain a blank line.
            if model_name:
                names.add(model_name)
    return names
# The model list files are looked up in the same directory as this script.
_LIST_DIR = os.path.dirname(__file__)

# Each constant below holds the *set of model names* returned by
# model_names(), even where the identifier ends in _FILE_NAME.
TIMM_MODEL_NAMES = model_names(os.path.join(_LIST_DIR, "timm_models_list.txt"))
HF_MODELS_FILE_NAME = model_names(
    os.path.join(_LIST_DIR, "huggingface_models_list.txt")
)
TORCHBENCH_MODELS_FILE_NAME = model_names(
    os.path.join(_LIST_DIR, "all_torchbench_models_list.txt")
)

# A name must belong to at most one suite, otherwise dispatching on the name
# alone would be ambiguous. Check every pair of suites.
# timm vs. HF
assert not (TIMM_MODEL_NAMES & HF_MODELS_FILE_NAME)
# timm vs. torchbench
assert not (TIMM_MODEL_NAMES & TORCHBENCH_MODELS_FILE_NAME)
# torchbench vs. HF
assert not (TORCHBENCH_MODELS_FILE_NAME & HF_MODELS_FILE_NAME)
def parse_args(args=None):
    """Parse the single option this shim understands, ``--only``.

    Args:
        args: Argument strings to parse. ``None`` is passed straight to
            ``parse_known_args``, which then reads ``sys.argv[1:]``.

    Returns:
        The ``(namespace, remaining)`` pair produced by
        ``ArgumentParser.parse_known_args``: the parsed namespace (with an
        ``only`` attribute) and the list of arguments that were not
        recognised.
    """
    # Help text for --only, kept verbatim.
    only_help = """Run just one model from whichever model suite it belongs to. Or
specify the path and class name of the model in format like:
--only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>
Due to the fact that dynamo changes current working directory,
the path should be an absolute path.
The class should have a method get_example_inputs to return the inputs
for the model. An example looks like
```
class LinearModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
def get_example_inputs(self):
return (torch.randn(2, 10),)
```
"""
    cli = argparse.ArgumentParser()
    cli.add_argument("--only", help=only_help)
    # parse_known_args (not parse_args) so unrecognised flags do not abort.
    return cli.parse_known_args(args)
if __name__ == "__main__":
    # `unknown` collects the arguments this shim does not recognise; it is
    # not used any further in this block.
    args, unknown = parse_args()
    name = args.only

    # Suite modules are imported lazily so that only the suites that actually
    # run get imported.
    if not name:
        # No --only given: run every suite, one after another.
        import torchbench

        torchbench.torchbench_main()

        import huggingface

        huggingface.huggingface_main()

        import timm_models

        timm_models.timm_main()
    elif name in TIMM_MODEL_NAMES:
        import timm_models

        timm_models.timm_main()
    elif name in HF_MODELS_FILE_NAME:
        import huggingface

        huggingface.huggingface_main()
    elif name in TORCHBENCH_MODELS_FILE_NAME:
        import torchbench

        torchbench.torchbench_main()
    else:
        # The name is not in any of the three model lists.
        print(f"Illegal model name? {name}")
        sys.exit(-1)