This PR adds some models to the `cpu_skip_list` which are currently run in the TIMM and Torchbench suites. The suggested models take so long that the [benchmark runs for aarch64](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly-aarch64.yml) time out.

* The issue stems from unoptimized groupwise convolution kernels (BF16/F16 dtype) on aarch64 platforms, which significantly slow down execution and cause the timeout.

**Action:**

* An optimized BF16 groupwise convolution kernel is currently being developed in oneDNN, targeted for release in Q4 2025. To maintain dashboard consistency and signal clarity, I've skipped the affected tests in:
  * timm benchmarks
  * torchbench benchmarks

As suggested, the skip is applied at the CPU-arch level, explicitly branching for aarch64 and listing the models that need to be skipped. This keeps the logic clean.

* An alternative considered was increasing the shard count for the aarch64 runners, but given the known performance bottleneck, skipping avoids wasted compute cycles. Suggestions around this are appreciated.

The benchmark no longer times out after the suggested change: https://github.com/pytorch/pytorch/actions/runs/16447200138

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158842
Approved by: https://github.com/malfet
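For context, the per-arch skip list lives in each suite's YAML config and is surfaced through the runner's `skip_models_for_cpu_aarch64` property (visible in the `TimmRunner` code below, which reads `skip -> device -> cpu_aarch64` from `timm_models.yaml`). A minimal sketch of how such an arch-level branch can be wired up — the `is_aarch64` helper, `SKIP_CONFIG` dict, and the model name shown are illustrative assumptions, not the exact entries from the PR:

```python
import platform

# Hypothetical in-memory view of the skip config after loading the suite YAML
# (the real keys in timm_models.yaml are skip -> device -> cpu / cpu_aarch64).
SKIP_CONFIG = {
    "cpu": [],  # models skipped on every CPU
    "cpu_aarch64": [
        "some_grouped_conv_model",  # illustrative entry; see the PR for the real list
    ],
}


def is_aarch64() -> bool:
    # platform.machine() reports "aarch64" on Linux Arm64 hosts
    return platform.machine() == "aarch64"


def models_to_skip_on_cpu() -> list[str]:
    skip = list(SKIP_CONFIG["cpu"])
    if is_aarch64():
        # Branch explicitly for aarch64, mirroring the PR's approach
        skip += SKIP_CONFIG["cpu_aarch64"]
    return skip
```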
454 lines · 13 KiB · Python · Executable File
#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings

try:
    from .common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
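# Each line of timm_models_list.txt is "<model_name> <batch_size>".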
with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason for the cold start memory spike
BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}

REQUIRE_HIGHER_TOLERANCE = {
    "crossvit_9_240",
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
    "convnext_base",
    "cait_m36_384",
}

REQUIRE_HIGHER_TOLERANCE_AMP = {
    "poolformer_m36",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
    "gluon_inception_v3",
    "cait_m36_384",
}


def refresh_model_names():
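    # Regenerate the model list, keeping one representative model per family.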
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue

                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())
    chosen_models.update(value[0] for key, value in all_models_family.items())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")


class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def _config(self):
        return load_yaml_file("timm_models.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

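    # Skipped on aarch64 CPUs because groupwise convolution (BF16/F16) kernels
    # are not yet optimized in oneDNN there; see pytorch/pytorch#158842.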
    @property
    def skip_models_for_cpu_aarch64(self):
        return self._skip["device"]["cpu_aarch64"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
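        # Standardize the random image batch to zero mean and unit variance.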
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3
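        # The tight default above is relaxed below for freezing and training runs.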

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical difference. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE or (
                self.args.amp and name in REQUIRE_HIGHER_TOLERANCE_AMP
            ):
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need to be zoomed out further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
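        # Run the training step on cloned inputs so the originals stay unmodified.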
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, None, loss, cloned_inputs)
        return None


def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()