mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enables two ruff rules derived from pylint: * PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in REPL contexts, as it may not always be imported by default. This always uses the version in the sys module, which is better. * PLW3301 replaces nested min / max calls with simplified versions (i.e. `min(a, min(b, c))` => `min(a, b, c)`). The new version is more idiomatic and more efficient. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461 Approved by: https://github.com/ezyang
104 lines
2.9 KiB
Python
Executable File
104 lines
2.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
from typing import Set
|
|
|
|
|
|
# Note - hf and timm have their own version of this, torchbench does not
|
|
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
|
|
def model_names(filename: str) -> Set[str]:
    """Return the set of model names listed in ``filename``.

    Each non-blank line contributes its first field. Fields may be separated
    by whitespace and/or commas, so ``name,8``, ``name 8`` and ``name, 8``
    all yield ``name``.

    Blank lines, and lines that contain only separators, are skipped. The
    original implementation added ``""`` for those, which could make the
    suite sets appear to overlap.

    Args:
        filename: Path to a model list file.

    Returns:
        The set of model names found in the file (empty for an empty file).
    """
    names = set()
    with open(filename, encoding="utf-8") as fh:
        for line in fh:
            # Treat commas as whitespace so "name, 8" does not keep a
            # trailing comma; str.split() also drops leading/trailing space.
            fields = line.replace(",", " ").split()
            if fields:
                names.add(fields[0])
    return names
|
|
|
|
|
|
# Directory containing this script; the per-suite model list files are looked
# up next to it.
_MODEL_LIST_DIR = os.path.dirname(__file__)

# Model-name sets for each benchmark suite. NOTE: despite the ``*_FILE_NAME``
# suffix, HF_MODELS_FILE_NAME and TORCHBENCH_MODELS_FILE_NAME hold sets of
# model names, not file names. They are left unrenamed because other modules
# may import them (not visible from here).
TIMM_MODEL_NAMES = model_names(os.path.join(_MODEL_LIST_DIR, "timm_models_list.txt"))
HF_MODELS_FILE_NAME = model_names(
    os.path.join(_MODEL_LIST_DIR, "huggingface_models_list.txt")
)
TORCHBENCH_MODELS_FILE_NAME = model_names(
    os.path.join(_MODEL_LIST_DIR, "all_torchbench_models_list.txt")
)

# The suites must not share model names: the ``--only`` dispatch in
# ``__main__`` checks timm, then huggingface, then torchbench, so a shared
# name would silently run under the first matching suite only. The messages
# list the offending names to make a failure actionable.
# timm <> HF disjoint
assert TIMM_MODEL_NAMES.isdisjoint(HF_MODELS_FILE_NAME), (
    "models listed in both timm and huggingface: "
    f"{sorted(TIMM_MODEL_NAMES & HF_MODELS_FILE_NAME)}"
)
# timm <> torch disjoint
assert TIMM_MODEL_NAMES.isdisjoint(TORCHBENCH_MODELS_FILE_NAME), (
    "models listed in both timm and torchbench: "
    f"{sorted(TIMM_MODEL_NAMES & TORCHBENCH_MODELS_FILE_NAME)}"
)
# torch <> hf disjoint
assert TORCHBENCH_MODELS_FILE_NAME.isdisjoint(HF_MODELS_FILE_NAME), (
    "models listed in both torchbench and huggingface: "
    f"{sorted(TORCHBENCH_MODELS_FILE_NAME & HF_MODELS_FILE_NAME)}"
)
|
|
|
|
|
|
def parse_args(args=None):
    """Parse the command-line options understood by this dispatcher.

    Only ``--only`` is declared here. Any other arguments are not rejected;
    they are returned alongside the namespace.

    Args:
        args: Argument list to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        The ``(namespace, remaining_args)`` tuple produced by
        ``ArgumentParser.parse_known_args``.
    """
    only_help = """Run just one model from whichever model suite it belongs to. Or
        specify the path and class name of the model in format like:
        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

        Due to the fact that dynamo changes current working directory,
        the path should be an absolute path.

        The class should have a method get_example_inputs to return the inputs
        for the model. An example looks like
        ```
        class LinearModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

            def get_example_inputs(self):
                return (torch.randn(2, 10),)
        ```
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--only", help=only_help)
    return cli.parse_known_args(args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Dispatch to the suite runner that owns the model named by --only, or run
    # every suite in turn when --only is not given. The suite modules are
    # imported lazily so only the selected suite's module is loaded.
    # NOTE(review): the suite mains are called with no arguments; presumably
    # they re-parse sys.argv themselves (``unknown`` is unused here) - confirm.
    args, unknown = parse_args()
    if args.only:
        name = args.only
        if name in TIMM_MODEL_NAMES:
            import timm_models

            timm_models.timm_main()
        elif name in HF_MODELS_FILE_NAME:
            import huggingface

            huggingface.huggingface_main()
        elif name in TORCHBENCH_MODELS_FILE_NAME:
            import torchbench

            torchbench.torchbench_main()
        else:
            # NOTE(review): parse_args advertises a
            # --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME> form, but such a
            # spec is in none of the model sets and is rejected here; confirm
            # which suite runner should handle it.
            # Diagnostics go to stderr so they are not mixed into benchmark
            # output on stdout.
            print(f"Illegal model name? {name}", file=sys.stderr)
            sys.exit(-1)
    else:
        # No model selected: run torchbench, then huggingface, then timm.
        import torchbench

        torchbench.torchbench_main()

        import huggingface

        huggingface.huggingface_main()

        import timm_models

        timm_models.timm_main()
|