[AOTI][dashboard] Fix mis-calculated memory compression ratio (#150695)
Summary: https://github.com/pytorch/pytorch/pull/149817 introduced an extra warmup run to compute the AOTI memory compression ratio. But because weights are loaded only once in the AOTI run, the peak memory seen during that extra warmup does not include the weights, which yields an artificially high memory compression ratio. This PR removes the extra warmup run and instead calls reset_peak_memory_stats at the proper point.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150695
Approved by: https://github.com/yushangdi
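To make the measurement logic concrete, here is a minimal, self-contained sketch (not the benchmark harness itself) of how a peak-memory compression ratio can be computed around torch.cuda.reset_peak_memory_stats; the callables run_eager and run_aoti are hypothetical stand-ins for the eager reference pass and the AOTI run.

```python
# Minimal sketch, assuming a CUDA device. `run_eager` / `run_aoti` are
# hypothetical stand-ins for the eager reference run and the AOTI run.
import torch

def peak_memory_of(fn, device: str = "cuda") -> int:
    """Run fn and return the peak allocated bytes attributable to it."""
    torch.cuda.empty_cache()
    # Reset the high-water mark so transient spikes from earlier runs do not
    # leak into this measurement; tensors still resident (e.g. weights that
    # were loaded once) continue to count toward the new peak.
    torch.cuda.reset_peak_memory_stats(device)
    fn()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)

# eager_peak = peak_memory_of(run_eager)  # weights + eager activations
# aoti_peak = peak_memory_of(run_aoti)    # weights + AOTI activations
# compression_ratio = eager_peak / aoti_peak
```

Without the reset before the measured run, the peak carries over from whichever earlier pass allocated the most; and if the measured run is a repeat warmup that never reloads the weights, its reported peak understates real usage, inflating the ratio — the bug this PR fixes.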
parent 6c38b9be73
commit 6a8ab902a2
```diff
@@ -1395,6 +1395,8 @@ class AOTInductorModelCache:
         with torch.no_grad():
             # copy.deepcopy is required to prevent any surprising side-effect,
             # see https://github.com/pytorch/pytorch/issues/113029
+            # This will cause memory stats to be overshadowed by this eager run.
+            # To fix that, memory stats will be reset later.
             example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

         if pytree.is_namedtuple_instance(example_outputs):
@@ -1411,6 +1413,14 @@ class AOTInductorModelCache:
             _produce_dynamic_shapes_for_export, combined_args
         )

+        # delete example_outputs and reset memory stats here
+        del example_outputs
+        if current_device == "cuda":
+            torch.cuda.reset_peak_memory_stats()
+            empty_gpu_cache(current_device)
+        elif current_device == "hpu":
+            torch.hpu.reset_peak_memory_stats()
+
         ep = torch.export.export(
             model,
             example_args,
@@ -3735,10 +3745,6 @@ def run(runner, args, original_dir=None):
         # AOTInductor doesn't support control flow yet
         runner.skip_models.update(runner.skip_models_due_to_control_flow)
         runner.skip_models.update(runner.skip_models_due_to_export_not_supported)
-
-        # For AOTI, we only measure the memory compression ratio at the run time
-        # instead of the compile time, so use a warmup run to trigger AOTI compilation.
-        args.use_warm_peak_memory = True
     elif args.backend == "torchao":
         assert "cuda" in args.devices, "Quantization requires CUDA device."
         assert args.bfloat16, "Quantization requires dtype bfloat16."
```
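As the comments added in the second hunk indicate, the ordering matters: the eager reference call that produces example_outputs is allowed to dominate the memory stats, then example_outputs is deleted and the peak counters are reset immediately before torch.export.export. The peak recorded afterwards therefore reflects the AOTI path, including its one-time weight load, rather than the eager warmup.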