Add option to run AOT Precompile in benchmark (#164906)
Use the existing benchmark infra to get some signals for the AOT precompile pass rate on OSS models. Here we also measure and log the loading time.

```
python ./benchmarks/dynamo/huggingface.py --accuracy --inference --aot-precompile
python ./benchmarks/dynamo/timm_models.py --accuracy --inference --aot-precompile
python ./benchmarks/dynamo/torchbench.py --accuracy --inference --aot-precompile
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164906
Approved by: https://github.com/zhxchen17
This commit is contained in:
parent 382d04a51e
commit 102b7885ff
```diff
@@ -1060,6 +1060,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         frozen_model_iter_fn = export_nativert(model, example_inputs)
     elif args.torchscript_jit_trace:
         frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
+    elif args.aot_precompile:
+        frozen_model_iter_fn = aot_precompile(model, example_inputs)
     else:
         if kwargs["hf_llm"]:
             # If it's an llm, we want to optimize model.forward, and use
```
```diff
@@ -1495,6 +1497,37 @@ def export(model, example_inputs):
     return opt_export
 
 
+def aot_precompile(model, example_inputs):
+    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        save_path = f.name
+
+    with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
+        compiled_fn = torch.compile(
+            model,
+            fullgraph=True,
+            options={"guard_filter_fn": lambda guards: [False for _ in guards]},
+        ).forward.aot_compile((example_args, example_kwargs))
+
+    compiled_fn.save_compiled_function(save_path)
+
+    torch._dynamo.reset()
+    with open(save_path, "rb") as f:
+        load_start_time = time.perf_counter()
+        loaded_fn = torch.compiler.load_compiled_function(f)
+        load_end_time = time.perf_counter()
+        print(
+            f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
+        )
+
+    def opt_aot_precompile(_, example_inputs, collect_outputs=False):
+        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+        return loaded_fn(model, *example_args, **example_kwargs)
+
+    return opt_aot_precompile
+
+
 def export_nativert(model, example_inputs):
     optimized = NativeRTCache.load(model, example_inputs)
 
```
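For readers who want to exercise the new path outside the harness, the save/load round trip performed by `aot_precompile()` can be reproduced standalone. Below is a minimal sketch, assuming a PyTorch build where the `enable_aot_compile` config and the `aot_compile`/`load_compiled_function` hooks used in the diff are available; the toy `Linear` module and input shapes are illustrative only:

```python
# Standalone sketch of the AOT precompile round trip in aot_precompile()
# above. Toy model and inputs; the API calls mirror the diff.
import tempfile
import time

import torch

model = torch.nn.Linear(8, 8)
example_args = (torch.randn(2, 8),)
example_kwargs = {}

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn = torch.compile(
        model,
        fullgraph=True,
        # Drop every guard so the loaded artifact is reused unconditionally.
        options={"guard_filter_fn": lambda guards: [False for _ in guards]},
    ).forward.aot_compile((example_args, example_kwargs))

with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    save_path = f.name
compiled_fn.save_compiled_function(save_path)

torch._dynamo.reset()  # clear in-memory state so the load below is "cold"

with open(save_path, "rb") as f:
    t0 = time.perf_counter()
    loaded_fn = torch.compiler.load_compiled_function(f)
    print(f"AOT Precompile loading time: {time.perf_counter() - t0} seconds")

# As in opt_aot_precompile, the module itself is the first argument.
out = loaded_fn(model, *example_args, **example_kwargs)
```

Note the all-`False` guard filter: it strips every dynamo guard at compile time, which is what lets the benchmark call the loaded function directly without re-checking input properties on each iteration.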
```diff
@@ -2274,6 +2307,7 @@ class BenchmarkRunner:
             or self.args.export_aot_inductor
             or self.args.export_nativert
             or self.args.torchscript_jit_trace
+            or self.args.aot_precompile
         ):
             # apply export on module directly
             # no need for n iterations
```
```diff
@@ -2729,6 +2763,7 @@ class BenchmarkRunner:
                 self.args.export_aot_inductor
                 or self.args.export_nativert
                 or self.args.torchscript_jit_trace
+                or self.args.aot_precompile
             ):
                 optimized_model_iter_fn = optimize_ctx
             else:
```
```diff
@@ -3505,6 +3540,11 @@ def parse_args(args=None):
         action="store_true",
         help="Measure pass rate with Export+AOTInductor",
     )
+    group.add_argument(
+        "--aot-precompile",
+        action="store_true",
+        help="Measure pass rate with AOT Precompile",
+    )
     group.add_argument(
         "--export-nativert",
         action="store_true",
```
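The new flag follows the same argparse pattern as the neighboring export options. A minimal standalone sketch of that pattern (the parser and mutually exclusive group here are a simplified assumption, not the literal wiring in the benchmark script, and the `--export-nativert` help text is assumed):

```python
import argparse

parser = argparse.ArgumentParser()
# Backend toggles are grouped so a run selects at most one mode
# (assumption: simplified reconstruction of the benchmark's group).
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--aot-precompile",
    action="store_true",
    help="Measure pass rate with AOT Precompile",
)
group.add_argument(
    "--export-nativert",
    action="store_true",
    help="Measure pass rate with Export+NativeRT",  # help text assumed
)

args = parser.parse_args(["--aot-precompile"])
assert args.aot_precompile and not args.export_nativert
```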
```diff
@@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
         optimize_ctx = export
         experiment = speedup_experiment
         output_filename = "export.csv"
+    elif args.aot_precompile:
+        optimize_ctx = aot_precompile
+        experiment = speedup_experiment
+        output_filename = "aot_precompile.csv"
     elif args.export_nativert:
         optimize_ctx = export_nativert
         experiment = speedup_experiment
```