Fix torch.load (torch.utils.benchmark) after #137602 (#139810)

After #137602, the default `weights_only` has been set to True.  This test is failing in trunk slow jobs atm

benchmark_utils/test_benchmark_utils.py::TestBenchmarkUtils::test_collect_callgrind [GH job link](https://github.com/pytorch/pytorch/actions/runs/11672436111/job/32502454946) [HUD commit link](1aa71be56c)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139810
Approved by: https://github.com/kit1980
This commit is contained in:
Huy Do 2024-11-06 03:08:29 +00:00 committed by PyTorch MergeBot
parent 63b01f328e
commit c19c384690

View File

@ -457,7 +457,10 @@ class GlobalsBridge:
elif wrapped_value.serialization == Serialization.TORCH:
path = os.path.join(self._data_dir, f"{name}.pt")
load_lines.append(f"{name} = torch.load({repr(path)})")
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT: