pytorch/benchmarks/dynamo/torchao_backend.py
HDCharles 374747818d Run performance test non-alternately (#131935)
Summary:
By default, performance tests (speedup experiments) will run the baseline and test backend alternately.

However, this does not work for the torchao backend, which changes the model in place: once the model has been quantized, the subsequent "baseline" iterations also run with the torchao backend.

Add a new experiment, "latency_experiment", that runs performance tests non-alternately: first run the baseline for a few iterations, then run the test backend (sketched below).
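
A rough sketch of the non-alternating flow; time_iterations, model, example_inputs, model_iter_fn, and quantization here are illustrative stand-ins for what the harness provides, not its actual API:

import time

import torch

def time_iterations(fn, model, example_inputs, iters=10):
    # Illustrative timing helper, not the real harness code.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(model, example_inputs)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# The baseline finishes all of its iterations before the test backend ever
# touches the model, so the in-place quantization cannot contaminate it.
baseline = time_iterations(model_iter_fn, model, example_inputs)
test = time_iterations(torchao_optimize_ctx(quantization)(model_iter_fn), model, example_inputs)
speedup = baseline / test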

Other changes:

We need to call torch.compiler.cudagraph_mark_step_begin() at the start of each iteration to avoid the slowdown from the "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards" warning; a sketch follows.
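
A minimal sketch of where the call sits (the iteration function is illustrative):

def model_iter_fn(mod, inputs):
    # Mark each iteration as a new CUDAGraphs step so that outputs from the
    # previous iteration are not treated as pending and the fast path stays
    # available.
    torch.compiler.cudagraph_mark_step_begin()
    return mod(*inputs)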

Also updated the torchao APIs to the current versions.

X-link: https://github.com/pytorch/benchmark/pull/2394

Test Plan:
python run_benchmark.py torchao --only AlbertForMaskedLM --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune
python run_benchmark.py torchao --only BartForCausalLM --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune
python run_benchmark.py torchao --only timm_efficientnet --quantization noquant --performance --inference --bfloat16 --inductor-compile-mode max-autotune

(should all be ~1.0x)
0.997x
1.006x
0.994x

Reviewed By: xuzhao9

Differential Revision: D60252821

Pulled By: HDCharles

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131935
Approved by: https://github.com/xuzhao9
2024-08-08 00:23:20 +00:00


from typing import Any, Callable

import torch


def setup_baseline():
    from torchao.quantization.utils import recommended_inductor_config_setter

    # Apply torchao's recommended inductor settings, then pin dynamo behavior
    # so benchmark runs stay stable across iterations.
    recommended_inductor_config_setter()
    torch._dynamo.config.automatic_dynamic_shapes = False
    torch._dynamo.config.cache_size_limit = 10000


def torchao_optimize_ctx(quantization: str):
    from torchao.quantization.quant_api import (
        autoquant,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            # Quantize in place on the first call only; the _quantized flag
            # guards against re-quantizing an already-converted model.
            if getattr(module, "_quantized", None) is None:
                if quantization == "int8dynamic":
                    quantize_(
                        module,
                        int8_dynamic_activation_int8_weight(),
                        set_inductor_config=False,
                    )
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
                if quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    # Run one forward pass so autoquant can profile the layers.
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable"
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )
                else:
                    unwrap_tensor_subclass(module)
                setattr(module, "_quantized", True)  # noqa: B010
            model_iter_fn(module, example_inputs)

        return _torchao_apply

    return inner
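
A minimal usage sketch of the two entry points above (not part of the file; the toy model, inputs, and model_iter_fn are illustrative placeholders and assume a CUDA device):

import torch

setup_baseline()
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
example_inputs = (torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16),)

def model_iter_fn(mod, inputs):
    with torch.no_grad():
        return mod(*inputs)

run_fn = torchao_optimize_ctx("int8weightonly")(model_iter_fn)
run_fn(model, example_inputs)  # first call quantizes the model in place, then runs it
run_fn(model, example_inputs)  # later calls just run the already-quantized model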