add loop mm benchmark (#149932)

results:
compile time instruction count for iteration 4 is 67947323682

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149932
Approved by: https://github.com/bobrenjc93, https://github.com/eellison
Laith Sakka, 2025-03-24 23:22:33 -07:00, committed by PyTorch MergeBot
parent 79e8a69257
commit 7379c66344
2 changed files with 82 additions and 20 deletions

benchmark_base.py

@@ -3,6 +3,8 @@ import gc
 import json
 import os
 from abc import ABC, abstractmethod
+from typing import Optional
+from typing_extensions import Self
 
 import torch._C._instruction_counter as i_counter
 import torch._dynamo.config as config
@@ -76,7 +78,7 @@ class BenchmarkBase(ABC):
         backend: str = "",
         mode: str = "",
         dynamic=None,
-    ):
+    ) -> None:
         # These individual attributes are used to support different filters on the
         # dashboard later
         self._category = category
@@ -85,51 +87,51 @@ class BenchmarkBase(ABC):
         self._mode = mode  # Training or inference
         self._dynamic = dynamic
 
-    def with_iterations(self, value):
+    def with_iterations(self, value: int) -> Self:
         self._num_iterations = value
         return self
 
-    def enable_instruction_count(self):
+    def enable_instruction_count(self) -> Self:
         self._enable_instruction_count = True
         return self
 
-    def enable_compile_time_instruction_count(self):
+    def enable_compile_time_instruction_count(self) -> Self:
         self._enable_compile_time_instruction_count = True
         return self
 
-    def name(self):
+    def name(self) -> str:
         return ""
 
-    def backend(self):
+    def backend(self) -> str:
         return self._backend
 
-    def mode(self):
+    def mode(self) -> str:
         return self._mode
 
-    def category(self):
+    def category(self) -> str:
         return self._category
 
-    def device(self):
+    def device(self) -> str:
         return self._device
 
-    def is_dynamic(self):
+    def is_dynamic(self) -> Optional[bool]:
         return self._dynamic
 
-    def description(self):
+    def description(self) -> str:
         return ""
 
     @abstractmethod
-    def _prepare(self):
+    def _prepare(self) -> None:
         pass
 
     @abstractmethod
-    def _work(self):
+    def _work(self) -> None:
         pass
 
-    def _prepare_once(self):  # noqa: B027
+    def _prepare_once(self) -> None:  # noqa: B027
         pass
 
-    def _count_instructions(self):
+    def _count_instructions(self) -> int:
         print(f"collecting instruction count for {self.name()}")
         results = []
         for i in range(self._num_iterations):
@@ -141,7 +143,7 @@ class BenchmarkBase(ABC):
             results.append(count)
         return min(results)
 
-    def _count_compile_time_instructions(self):
+    def _count_compile_time_instructions(self) -> int:
         gc.disable()
         try:
@@ -169,7 +171,7 @@ class BenchmarkBase(ABC):
         finally:
             gc.enable()
 
-    def _write_to_json(self, output_dir: str):
+    def _write_to_json(self, output_dir: str) -> None:
         """
         Write the result into JSON format, so that it can be uploaded to the benchmark database
         to be displayed on OSS dashboard. The JSON format is defined at
@@ -209,7 +211,7 @@ class BenchmarkBase(ABC):
         with open(os.path.join(output_dir, f"{self.name()}.json"), "w") as f:
             json.dump(records, f)
 
-    def append_results(self, path):
+    def append_results(self, path: str) -> None:
         with open(path, "a", newline="") as csvfile:
             # Create a writer object
             writer = csv.writer(csvfile)
@@ -221,11 +223,11 @@ class BenchmarkBase(ABC):
         # as the CSV writer for now
         self._write_to_json(os.path.dirname(os.path.abspath(path)))
 
-    def print(self):
+    def print(self) -> None:
         for entry in self.results:
             print(f"{entry[0]},{entry[1]},{entry[2]}")
 
-    def collect_all(self):
+    def collect_all(self) -> Self:
         self._prepare_once()
         self.results = []
         if (
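
The -> Self annotations above (from typing_extensions) are what keep the builder-style setters chainable with correct types: a subclass calling with_iterations() or enable_compile_time_instruction_count() gets its own type back rather than BenchmarkBase, so the chained call in the new benchmark below type-checks end to end. A minimal sketch of the pattern, with illustrative names that are not part of this PR:

    from typing_extensions import Self


    class Base:
        def with_flag(self) -> Self:
            # Self resolves to the caller's class, so Derived().with_flag()
            # is typed as Derived, not Base.
            self.flag = True
            return self


    class Derived(Base):
        def run(self) -> None:
            print("running with flag =", self.flag)


    # Chaining preserves the Derived type, so .run() still type-checks:
    Derived().with_flag().run()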

mm_loop.py (new file; name inferred from category="mm_loop")

@@ -0,0 +1,60 @@
+import sys
+
+from benchmark_base import BenchmarkBase
+
+import torch
+from torch._inductor.utils import fresh_inductor_cache
+
+
+class Benchmark(BenchmarkBase):
+    def __init__(self) -> None:
+        super().__init__(
+            category="mm_loop",
+            backend="inductor",
+            device="cuda",
+            dynamic=False,
+        )
+
+    def name(self) -> str:
+        prefix = f"{self.category()}_{self.backend()}"
+        if self.is_dynamic():
+            prefix += "_dynamic"
+        if self.device() == "cuda":
+            prefix += "_gpu"
+        return prefix
+
+    def description(self) -> str:
+        return "a mm 200 times in a loop with max autotune on"
+
+    def _prepare_once(self) -> None:
+        self.a = torch.ones(10, 10, device=self.device())
+        self.b = torch.ones(10, 10, device=self.device())
+
+    def _prepare(self) -> None:
+        torch._dynamo.reset()
+
+    def _work(self) -> None:
+        @torch.compile(
+            backend="inductor",
+            fullgraph=True,
+            dynamic=False,
+        )
+        def f(a, b):
+            z = torch.mm(a, b)
+            for i in range(200):
+                z = torch.mm(z, b)
+            return z
+
+        with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True):
+            f(self.a, self.b)
+
+
+def main():
+    result_path = sys.argv[1]
+    Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
+        result_path
+    )
+
+
+if __name__ == "__main__":
+    main()
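
A note on the two wrappers in _work(): torch._dynamo.reset() in _prepare() drops Dynamo's cached graphs between iterations, and fresh_inductor_cache() points Inductor at an empty temporary cache directory, so every iteration pays the full compilation cost that the compile-time instruction counter is measuring; max_autotune=True additionally sends the mm ops down Inductor's autotuning path, which is what this benchmark stresses. A standalone sketch of that recompile-from-scratch pattern (the function and shapes are illustrative, not from this PR):

    import torch
    from torch._inductor.utils import fresh_inductor_cache


    @torch.compile(backend="inductor", fullgraph=True)
    def square(x):
        return torch.mm(x, x)


    torch._dynamo.reset()  # forget any previously compiled graphs
    with fresh_inductor_cache():  # fresh temp cache dir: no Inductor cache hits
        square(torch.ones(8, 8))  # full Dynamo + Inductor compile happens here

The benchmark itself is invoked with the output path as its first argument, e.g. python <benchmark>.py results.csv: main() reads the CSV path from argv[1], and append_results() both appends the collected rows to that CSV and mirrors them to JSON for the OSS dashboard.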