Fix operator benchmark issue#162708 (#162744)
This PR skips memory metric calculation for ops that don't take tensor inputs, fixing the operator_benchmark bug.

Fixes https://github.com/pytorch/pytorch/issues/162708
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162744
Approved by: https://github.com/huydhn
This commit is contained in:
parent 00e9ba75cd
commit 5f66902ecf
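For context, a minimal repro of the failure mode (inferred from the removed lines in the diff below, not taken verbatim from the issue report): the pre-fix code read `.device` from the first benchmark input unconditionally, which raises for ops whose inputs are not tensors.

    # Inferred repro: the old code path assumed the first input is a tensor.
    inputs = {"n": 128}  # hypothetical op config whose only input is an int
    sample_input = next(iter(inputs.values()))
    print(sample_input.device)  # AttributeError: 'int' object has no attribute 'device'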
@@ -373,9 +373,14 @@ class BenchmarkRunner:
         curr_test_total_time = 0
         time_trace = []
         peak_memory = 0
-        sample_input = next(iter(test_case.op_bench.inputs.values()))
-        device = sample_input.device
-        device_module = torch.get_device_module(device.type)
+        input_values = test_case.op_bench.inputs.values()
+        device, device_module = None, None
+        if input_values and isinstance(next(iter(input_values)), torch.Tensor):
+            # The device and device module information are crucial for memory metric calculation,
+            # In case of ops where inputs are integers (not tensor), memory metrics need not be calculated.
+            sample_input = next(iter(input_values))
+            device = sample_input.device
+            device_module = torch.get_device_module(device.type)
         # TODO: add support for cpu memory measurement
         while True:
             if hasattr(device_module, "reset_peak_memory_stats"):
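To illustrate the guard, here is a minimal standalone sketch (not the operator_benchmark source; `memory_metrics_enabled` and its inputs are illustrative): `device_module` stays None for non-tensor inputs, and since hasattr(None, "reset_peak_memory_stats") is False, the peak-memory path is skipped without raising.

    import torch

    def memory_metrics_enabled(inputs: dict) -> bool:
        # Mirrors the commit's guard: only resolve a device module when the
        # first input is a tensor; otherwise leave it as None.
        values = list(inputs.values())
        device_module = None
        if values and isinstance(values[0], torch.Tensor):
            device_module = torch.get_device_module(values[0].device.type)
        # hasattr(None, ...) is False, so the memory-metric branch is skipped.
        return hasattr(device_module, "reset_peak_memory_stats")

    # A CPU tensor resolves to torch.cpu, which (per the TODO in the diff)
    # has no reset_peak_memory_stats yet, so this prints False; a CUDA
    # tensor would resolve to torch.cuda and print True.
    print(memory_metrics_enabled({"x": torch.randn(4, 4)}))
    # Non-tensor input: device_module stays None, so metrics are skipped.
    print(memory_metrics_enabled({"n": 128}))  # False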