mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add loop mm benchmark (#149932)
results: compile time instruction count for iteration 4 is 67947323682 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149932 Approved by: https://github.com/bobrenjc93, https://github.com/eellison
This commit is contained in:
parent
79e8a69257
commit
7379c66344
|
|
@ -3,6 +3,8 @@ import gc
|
|||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch._C._instruction_counter as i_counter
|
||||
import torch._dynamo.config as config
|
||||
|
|
@ -76,7 +78,7 @@ class BenchmarkBase(ABC):
|
|||
backend: str = "",
|
||||
mode: str = "",
|
||||
dynamic=None,
|
||||
):
|
||||
) -> None:
|
||||
# These individual attributes are used to support different filters on the
|
||||
# dashboard later
|
||||
self._category = category
|
||||
|
|
@ -85,51 +87,51 @@ class BenchmarkBase(ABC):
|
|||
self._mode = mode # Training or inference
|
||||
self._dynamic = dynamic
|
||||
|
||||
def with_iterations(self, value):
|
||||
def with_iterations(self, value: int) -> Self:
|
||||
self._num_iterations = value
|
||||
return self
|
||||
|
||||
def enable_instruction_count(self):
|
||||
def enable_instruction_count(self) -> Self:
|
||||
self._enable_instruction_count = True
|
||||
return self
|
||||
|
||||
def enable_compile_time_instruction_count(self):
|
||||
def enable_compile_time_instruction_count(self) -> Self:
|
||||
self._enable_compile_time_instruction_count = True
|
||||
return self
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
def backend(self):
|
||||
def backend(self) -> str:
|
||||
return self._backend
|
||||
|
||||
def mode(self):
|
||||
def mode(self) -> str:
|
||||
return self._mode
|
||||
|
||||
def category(self):
|
||||
def category(self) -> str:
|
||||
return self._category
|
||||
|
||||
def device(self):
|
||||
def device(self) -> str:
|
||||
return self._device
|
||||
|
||||
def is_dynamic(self):
|
||||
def is_dynamic(self) -> Optional[bool]:
|
||||
return self._dynamic
|
||||
|
||||
def description(self):
|
||||
def description(self) -> str:
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(self):
|
||||
def _prepare(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _work(self):
|
||||
def _work(self) -> None:
|
||||
pass
|
||||
|
||||
def _prepare_once(self): # noqa: B027
|
||||
def _prepare_once(self) -> None: # noqa: B027
|
||||
pass
|
||||
|
||||
def _count_instructions(self):
|
||||
def _count_instructions(self) -> int:
|
||||
print(f"collecting instruction count for {self.name()}")
|
||||
results = []
|
||||
for i in range(self._num_iterations):
|
||||
|
|
@ -141,7 +143,7 @@ class BenchmarkBase(ABC):
|
|||
results.append(count)
|
||||
return min(results)
|
||||
|
||||
def _count_compile_time_instructions(self):
|
||||
def _count_compile_time_instructions(self) -> int:
|
||||
gc.disable()
|
||||
|
||||
try:
|
||||
|
|
@ -169,7 +171,7 @@ class BenchmarkBase(ABC):
|
|||
finally:
|
||||
gc.enable()
|
||||
|
||||
def _write_to_json(self, output_dir: str):
|
||||
def _write_to_json(self, output_dir: str) -> None:
|
||||
"""
|
||||
Write the result into JSON format, so that it can be uploaded to the benchmark database
|
||||
to be displayed on OSS dashboard. The JSON format is defined at
|
||||
|
|
@ -209,7 +211,7 @@ class BenchmarkBase(ABC):
|
|||
with open(os.path.join(output_dir, f"{self.name()}.json"), "w") as f:
|
||||
json.dump(records, f)
|
||||
|
||||
def append_results(self, path):
|
||||
def append_results(self, path: str) -> None:
|
||||
with open(path, "a", newline="") as csvfile:
|
||||
# Create a writer object
|
||||
writer = csv.writer(csvfile)
|
||||
|
|
@ -221,11 +223,11 @@ class BenchmarkBase(ABC):
|
|||
# as the CSV writer for now
|
||||
self._write_to_json(os.path.dirname(os.path.abspath(path)))
|
||||
|
||||
def print(self):
|
||||
def print(self) -> None:
|
||||
for entry in self.results:
|
||||
print(f"{entry[0]},{entry[1]},{entry[2]}")
|
||||
|
||||
def collect_all(self):
|
||||
def collect_all(self) -> Self:
|
||||
self._prepare_once()
|
||||
self.results = []
|
||||
if (
|
||||
|
|
|
|||
60
benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py
Normal file
60
benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
|
||||
|
||||
class Benchmark(BenchmarkBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
category="mm_loop",
|
||||
backend="inductor",
|
||||
device="cuda",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
def name(self) -> str:
|
||||
prefix = f"{self.category()}_{self.backend()}"
|
||||
if self.is_dynamic():
|
||||
prefix += "_dynamic"
|
||||
if self.device() == "cuda":
|
||||
prefix += "_gpu"
|
||||
return prefix
|
||||
|
||||
def description(self) -> str:
|
||||
return "a mm 100 times in a loop with max auto tune on"
|
||||
|
||||
def _prepare_once(self) -> None:
|
||||
self.a = torch.ones(10, 10, device=self.device())
|
||||
self.b = torch.torch.ones(10, 10, device=self.device())
|
||||
|
||||
def _prepare(self) -> None:
|
||||
torch._dynamo.reset()
|
||||
|
||||
def _work(self) -> None:
|
||||
@torch.compile(
|
||||
backend="inductor",
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
def f(a, b):
|
||||
z = torch.mm(a, b)
|
||||
for i in range(200):
|
||||
z = torch.mm(z, b)
|
||||
return z
|
||||
|
||||
with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True):
|
||||
f(self.a, self.b)
|
||||
|
||||
|
||||
def main():
|
||||
result_path = sys.argv[1]
|
||||
Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
|
||||
result_path
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user