pytorch/benchmarks/dynamo/torchbench.py
Will Constable 6eab5e88c8 Graph-break on allowed modules if they have hooks (#97184)
Allowed modules are stuck into dynamo's fx graph as call_module
nodes, without dynamo doing any tracing of the module.  This means
during AOT trace time, hooks will fire during tracing when the
call_module is executed, but the hooks themselves will disappear
after that and not be present in the compiled program.
  (Worse, if the hooks performed any tensor operations, those would get
   traced, so you could end up with part of the hook's functionality.)

To circumvent this, there are two options for 'allowed modules' with hooks.
1) don't treat them as 'allowed' - trace into them
2) graph-break, so the module is no longer part of the dynamo trace at all

(1) will fail for users that opted into allowed modules because they know
    their module has problems being traced by dynamo.
(2) causes graph breaks on common modules such as nn.Linear, just because they
    are marked as 'allowed'.

It would help matters if we could differentiate between types of allowed modules
  (A) allowed to avoid overheads - used for common ops like nn.Linear
  (B) allowed to avoid dynamo graphbreaks caused by unsupported code

Ideally, we'd use method (1) for group (A) and (2) for (B).

For now, graph-break on all cases of allowed modules.
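
For illustration, a minimal sketch of the situation (not code from this PR;
the doubling hook is made up):

    import torch
    import torch.nn as nn

    linear = nn.Linear(4, 4)
    linear.register_forward_hook(lambda mod, inp, out: out * 2)

    @torch.compile(backend="eager")
    def fn(x):
        return linear(x)

    fn(torch.randn(2, 4))
    # With this change, dynamo graph-breaks at the hooked nn.Linear instead of
    # inlining it as a call_module node, so the hook still runs at runtime.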

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97184
Approved by: https://github.com/jansel
2023-04-15 01:46:15 +00:00

#!/usr/bin/env python3
import gc
import importlib
import logging
import os
import re
import sys
import warnings

from os.path import abspath, exists

import torch

try:
    from .common import BenchmarkRunner, main
except ImportError:
    from common import BenchmarkRunner, main

from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs

# We are primarily interested in the tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True
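
# Find a torchbench checkout near the current working directory, chdir into
# it, and add it to sys.path so `torchbenchmark` imports resolve; returns the
# original cwd so the caller can restore it later.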
def setup_torchbench_cwd():
    original_dir = abspath(os.getcwd())

    os.environ["KALDI_ROOT"] = "/tmp"  # avoids some spam
    for torchbench_dir in (
        "./torchbenchmark",
        "../torchbenchmark",
        "../torchbench",
        "../benchmark",
        "../../torchbenchmark",
        "../../torchbench",
        "../../benchmark",
    ):
        if exists(torchbench_dir):
            break

    if exists(torchbench_dir):
        torchbench_dir = abspath(torchbench_dir)
        os.chdir(torchbench_dir)
        sys.path.append(torchbench_dir)

    return original_dir

# Some models have large datasets that don't fit in memory. Lower the batch
# size to test the accuracy.
USE_SMALL_BATCH_SIZE = {
    "demucs": 4,
    "dlrm": 1024,
    "densenet121": 4,
    "hf_Reformer": 4,
    "timm_efficientdet": 1,
}

DETECTRON2_MODELS = {
    "detectron2_fasterrcnn_r_101_c4",
    "detectron2_fasterrcnn_r_101_dc5",
    "detectron2_fasterrcnn_r_101_fpn",
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "detectron2_maskrcnn_r_101_c4",
    "detectron2_maskrcnn_r_101_fpn",
    "detectron2_maskrcnn_r_50_fpn",
}

SKIP = {
    # https://github.com/pytorch/torchdynamo/issues/101
    "detectron2_maskrcnn",
    # https://github.com/pytorch/torchdynamo/issues/145
    "fambench_xlmr",
    # https://github.com/pytorch/pytorch/issues/99201
    "opacus_cifar10",
}

SKIP_FOR_CUDA = {
    "gat",  # only works on CPU
    "gcn",  # only works on CPU
    "sage",  # only works on CPU
}

# Additional models that are skipped in training
SKIP_TRAIN = {
    # not designed for training
    "pyhpc_equation_of_state",
    "pyhpc_isoneutral_mixing",
    "pyhpc_turbulent_kinetic_energy",
    # Unusual training setup
    "opacus_cifar10",
    "maml",
    # segfault: Internal Triton PTX codegen error
    "timm_efficientdet",
}
SKIP_TRAIN.update(DETECTRON2_MODELS)

# These models only support train mode, so accuracy checking can't be done in
# eval mode.
ONLY_TRAINING_MODE = {
    "tts_angular",
    "tacotron2",
    "demucs",
    "hf_Reformer",
    "pytorch_struct",
    "yolov3",
}
ONLY_TRAINING_MODE.update(DETECTRON2_MODELS)

# These models need a looser (higher) tolerance on GPU, as their GPU kernels
# are nondeterministic.
REQUIRE_HIGHER_TOLERANCE = {
    "alexnet",
    "attention_is_all_you_need_pytorch",
    "densenet121",
    "hf_Albert",
    "vgg16",
    "mobilenet_v3_large",
    "nvidia_deeprecommender",
    "timm_efficientdet",
    "vision_maskrcnn",
}

# These models need >1e-3 tolerance
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "soft_actor_critic",
    "tacotron2",
}

REQUIRE_HIGHER_FP16_TOLERANCE = {
    "drq",
}

REQUIRE_COSINE_TOLERACE = {
    # Intentionally empty; kept around in case we need it in the future.
}

# non-deterministic output / can't check correctness
NONDETERMINISTIC = {
    # https://github.com/pytorch/pytorch/issues/98355
    "mobilenet_v3_large",
    "vision_maskrcnn",  # eager variant
}

# These benchmarks took >600s on an i9-11900K CPU
VERY_SLOW_BENCHMARKS = {
    "hf_BigBird",  # 3339s
    "hf_Longformer",  # 3062s
    "hf_T5",  # 930s
}

# These benchmarks took >60s on an i9-11900K CPU
SLOW_BENCHMARKS = {
    *VERY_SLOW_BENCHMARKS,
    "BERT_pytorch",  # 137s
    "demucs",  # 116s
    "fastNLP_Bert",  # 242s
    "hf_Albert",  # 221s
    "hf_Bart",  # 400s
    "hf_Bert",  # 334s
    "hf_DistilBert",  # 187s
    "hf_GPT2",  # 470s
    "hf_Reformer",  # 141s
    "speech_transformer",  # 317s
    "vision_maskrcnn",  # 99s
}

TRT_NOT_YET_WORKING = {
    "alexnet",
    "resnet18",
    "resnet50",
    "mobilenet_v2",
    "mnasnet1_0",
    "squeezenet1_1",
    "shufflenetv2_x1_0",
    "vgg16",
    "resnext50_32x4d",
}

DYNAMIC_SHAPES_NOT_YET_WORKING = {
    "demucs",
    "timm_nfnet",
}

DONT_CHANGE_BATCH_SIZE = {
    "demucs",
    "pytorch_struct",
    "pyhpc_turbulent_kinetic_energy",
}

SKIP_ACCURACY_CHECK_MODELS = {
    # Models too large to hold eager, dynamo and fp64_numbers simultaneously,
    # even on a 40 GB machine. We have tested accuracy on smaller versions of
    # these models.
    "hf_GPT2_large",
    "hf_T5_large",
    "timm_vision_transformer_large",
    "maml",  # accuracy https://github.com/pytorch/pytorch/issues/93847
}

MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = {
    "hf_GPT2": 2,
    "pytorch_unet": 2,
}
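
# Torchbench flavor of the shared BenchmarkRunner: the harness in common.py
# drives benchmarking through the hooks and properties overridden below.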
class TorchBenchmarkRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "torchbench"
        self.optimizer = None

    @property
    def skip_models(self):
        return SKIP

    @property
    def skip_models_for_cuda(self):
        return SKIP_FOR_CUDA

    @property
    def slow_models(self):
        return SLOW_BENCHMARKS

    @property
    def very_slow_models(self):
        return VERY_SLOW_BENCHMARKS

    @property
    def non_deterministic_models(self):
        return NONDETERMINISTIC

    @property
    def skip_not_suitable_for_training_models(self):
        return SKIP_TRAIN

    @property
    def failing_fx2trt_models(self):
        return TRT_NOT_YET_WORKING

    @property
    def failing_dynamic_shape_models(self):
        return DYNAMIC_SHAPES_NOT_YET_WORKING

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return SKIP_ACCURACY_CHECK_MODELS
        return set()
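
    # Import the requested torchbench model, pick a batch size (honoring the
    # per-model overrides above), build it in train or eval mode, and return
    # (device, name, model, example_inputs, batch_size).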
    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        part=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dynamic_shapes = self.args.dynamic_shapes
        try:
            module = importlib.import_module(f"torchbenchmark.models.{model_name}")
        except ModuleNotFoundError:
            module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}")
        benchmark_cls = getattr(module, "Model", None)
        if not hasattr(benchmark_cls, "name"):
            benchmark_cls.name = model_name

        cant_change_batch_size = (
            not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
            or model_name in DONT_CHANGE_BATCH_SIZE
        )
        if cant_change_batch_size:
            batch_size = None
        if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
            batch_size = USE_SMALL_BATCH_SIZE[model_name]

        # Control the memory footprint for a few models
        if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
            batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])

        # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
        torch.backends.__allow_nonbracketed_mutation_flag = True

        extra_args = []
        if part:
            extra_args = ["--part", part]
        if is_training:
            benchmark = benchmark_cls(
                test="train",
                device=device,
                jit=False,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        else:
            benchmark = benchmark_cls(
                test="eval",
                device=device,
                jit=False,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        model, example_inputs = benchmark.get_module()

        # Models that must be in train mode while training
        if is_training and (not use_eval_mode or model_name in ONLY_TRAINING_MODE):
            model.train()
        else:
            model.eval()
        gc.collect()
        batch_size = benchmark.batch_size

        # Torchbench has a quite different setup for yolov3, so we directly pass
        # the right example_inputs
        if model_name == "yolov3":
            example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
        # global current_name, current_device
        # current_device = device
        # current_name = benchmark.name

        if self.args.trace_on_xla:
            # workaround for: https://github.com/pytorch/xla/issues/4174
            import torch_xla  # noqa: F401

        self.validate_model(model, example_inputs)
        return device, benchmark.name, model, example_inputs, batch_size
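
    # Yield the names of the torchbench models assigned to this shard,
    # applying the filter/exclude arguments and the skip list.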
    def iter_model_names(self, args):
        from torchbenchmark import _list_model_paths

        models = _list_model_paths()
        start, end = self.get_benchmark_indices(len(models))
        for index, model_path in enumerate(models):
            if index < start or index >= end:
                continue
            model_name = os.path.basename(model_path)
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        if is_training or name in ("maml",):
            return torch.enable_grad()
        else:
            return torch.no_grad()
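
    # Choose the allclose tolerance (and whether to fall back to cosine
    # similarity) from the dtype/device and the per-model sets above.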
    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        tolerance = 1e-4
        cosine = self.args.cosine
        # Increase the tolerance for torch.allclose
        if self.args.float16 or self.args.amp:
            if name in REQUIRE_HIGHER_FP16_TOLERANCE:
                return 1e-2, cosine
            return 1e-3, cosine
        if is_training and current_device == "cuda":
            tolerance = 1e-3
            if name in REQUIRE_COSINE_TOLERACE:
                cosine = True
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 1e-3
            elif name in REQUIRE_EVEN_HIGHER_TOLERANCE:
                tolerance = 8 * 1e-2
        return tolerance, cosine
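
    # Hooks used by the harness for accuracy and performance runs: reduce the
    # model output to a scalar loss, and run inference or one full training
    # step (forward, loss, backward, optimizer step) under autocast.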
    def compute_loss(self, pred):
        return reduce_to_scalar_loss(pred)

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast():
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast():
            pred = mod(*cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None

def torchbench_main():
    original_dir = setup_torchbench_cwd()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TorchBenchmarkRunner(), original_dir)


if __name__ == "__main__":
    torchbench_main()
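
# Example invocation (a sketch; the exact flags are defined by the shared
# parser in common.py, but they correspond to the self.args fields used
# above, e.g. args.training, args.accuracy, args.filter):
#
#   python benchmarks/dynamo/torchbench.py --accuracy --training --filter hf_Bert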