Fix operator benchmark issue #162708 (#162744)

This PR skips the memory metric calculation for ops that don't take tensor inputs, fixing the operator_benchmark bug.

Fixes https://github.com/pytorch/pytorch/issues/162708

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162744
Approved by: https://github.com/huydhn
This commit is contained in:
jainapurva 2025-09-12 06:51:12 +00:00 committed by PyTorch MergeBot
parent 00e9ba75cd
commit 5f66902ecf

View File

@ -373,9 +373,14 @@ class BenchmarkRunner:
curr_test_total_time = 0
time_trace = []
peak_memory = 0
sample_input = next(iter(test_case.op_bench.inputs.values()))
device = sample_input.device
device_module = torch.get_device_module(device.type)
input_values = test_case.op_bench.inputs.values()
device, device_module = None, None
if input_values and isinstance(next(iter(input_values)), torch.Tensor):
# The device and device module information are crucial for memory metric calculation,
# In case of ops where inputs are integers (not tensor), memory metrics need not be calculated.
sample_input = next(iter(input_values))
device = sample_input.device
device_module = torch.get_device_module(device.type)
# TODO: add support for cpu memory measurement
while True:
if hasattr(device_module, "reset_peak_memory_stats"):