pytorch/benchmarks/dynamo/timm_models.py
Bin Bao 57b36bf353 Bring back TIMM model inductor CI test (#87730)
Summary: https://github.com/pytorch/pytorch/pull/87588 has solved the
inductor compilation speed regression, so we can try to run TIMM models
with fewer shards and also enable pretrained model downloading, which
should resolve the flakiness we have seen previously.

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87730
Approved by: https://github.com/anijain2305
2022-10-26 00:15:35 +00:00


#!/usr/bin/env python3
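"""TorchDynamo benchmark runner for the TIMM (pytorch-image-models) suite.

Model names and per-model batch sizes are read from timm_models_list.txt,
pretrained weights are downloaded via timm.create_model(pretrained=True),
and each model is driven through the shared benchmark harness in common.py.
"""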
import importlib
import logging
import os
import re
import subprocess
import sys
import time
import warnings

import torch

from common import BenchmarkRunner, main

from torch._dynamo.testing import collect_results
from torch._dynamo.utils import clone_inputs

def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = dict()
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename, "r") as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)
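
# Each line of timm_models_list.txt is "<model_name> <batch_size>",
# e.g. "beit_base_patch16_224 128" (the value shown here is illustrative).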

# TODO - Figure out the reason for the cold start memory spike
BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "cait_m36_384": 4,
    "convit_base": 4,
    "convmixer_768_32": 2,
    "convnext_base": 4,
    "crossvit_9_240": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "dla102": 2,
    "dpn107": 2,
    "eca_botnext26ts_256": 2,
    "eca_halonext26ts": 2,
    "gluon_senet154": 2,
    "gluon_xception65": 2,
    "gmixer_24_224": 2,
    "gmlp_s16_224": 2,
    "hrnet_w18": 64,
    "jx_nest_base": 4,
    "mixer_b16_224": 2,
    "mixnet_l": 2,
    "mobilevit_s": 4,
    "nfnet_l0": 2,
    "pit_b_224": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "res2net101_26w_4s": 2,
    "res2net50_14w_8s": 64,
    "res2next50": 64,
    "resnest101e": 4,
    "sebotnet33ts_256": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "tf_mixnet_l": 2,
    "tnt_s_patch16_224": 2,
    "twins_pcpvt_base": 4,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "xcit_large_24_p8_224": 4,
}
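
# The divisors above shrink the recorded batch size for memory-hungry models:
# load_model() computes max(int(recorded_batch_size / divisor), 1).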

REQUIRE_HIGHER_TOLERANCE = set()

SKIP = {
    # Unusual training setup
    "levit_128",
}

def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn, "r") as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue
                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]
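
    # For illustration: get_family_name("swsl_resnext101_32x16d") == "resnext"
    # (substring match), while get_family_name("vit_base_patch16_224") == "vit"
    # (falls through to the first "_"-separated token).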

    def populate_family(models):
        family = dict()
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    # print(docs_models_family.keys())
    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    for value in docs_models_family.values():
        chosen_models.add(value[0])

    for key, value in all_models_family.items():
        chosen_models.add(value[0])

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename

    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")
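
# Note: refresh_model_names() is a maintenance helper rather than part of the
# benchmark run. It assumes a pytorch-image-models checkout next to this repo
# (see the TODO inside read_models_from_docs) and rewrites timm_models_list.txt
# with one representative model per family.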

class TimmRunnner(BenchmarkRunner):
    def __init__(self):
        super(TimmRunnner, self).__init__()
        self.suite_name = "timm_models"

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        # _, model_dtype, data_dtype = self.resolve_precision()
        channels_last = self._args.channels_last

        # Retry pretrained-weight downloads a few times with a backoff; network
        # hiccups are the common failure mode here.
        retries = 1
        success = False
        while not success and retries < 4:
            try:
                model = create_model(
                    model_name,
                    in_chans=3,
                    scriptable=False,
                    num_classes=None,
                    drop_rate=0.0,
                    drop_path_rate=None,
                    drop_block_rate=None,
                    pretrained=True,
                    # global_pool=kwargs.pop('gp', 'fast'),
                    # num_classes=kwargs.pop('num_classes', None),
                    # drop_rate=kwargs.pop('drop', 0.),
                    # drop_path_rate=kwargs.pop('drop_path', None),
                    # drop_block_rate=kwargs.pop('drop_block', None),
                )
                success = True
            except Exception:
                wait = retries * 30
                time.sleep(wait)
                retries += 1
        if not success:
            # Without this guard, `model.to(...)` below would raise a confusing
            # NameError when every create_model() attempt failed.
            raise RuntimeError(f"Failed to create model {model_name}")

        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            self._args, model=model, use_test_size=not is_training
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]
        recorded_batch_size = max(
            int(recorded_batch_size / BATCH_SIZE_DIVISORS.get(model_name, 1)), 1
        )
        batch_size = batch_size or recorded_batch_size

        # example_inputs = torch.randn(
        #     (batch_size,) + input_size, device=device, dtype=data_dtype
        # )

        # Synthesize inputs instead of loading a dataset: integer "pixels"
        # cast to float32 and normalized to zero mean / unit std.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.init_optimizer(device, model.parameters())

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3
        if is_training:
            # Membership test, not set truthiness: an empty
            # REQUIRE_HIGHER_TOLERANCE must not bump the tolerance.
            if name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 2 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        # return torch.ones((batch_size,) + (), device=device, dtype=torch.long)
        return torch.empty(
            (batch_size,) + (), device=device, dtype=torch.long
        ).random_(self.num_classes)

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return self.loss(pred, self.target) / 10.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad()
        with self.autocast():
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
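
# Entry point: main() (defined in common.py) parses the benchmark CLI and
# drives the runner. Flag names are assumed here from the attributes read
# above (args.training, args.use_eval_mode, args.cosine, args.filter,
# args.exclude, _args.channels_last); common.py is the authoritative list.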

if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunnner())