OSS-only copy of https://github.com/pytorch/pytorch/pull/85710
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85894
Approved by: https://github.com/drisspg
46 lines
1.4 KiB
Python
import argparse

import torch


def bench(nt_a, nt_b, niter):
    # Warmup so the first kernel launch does not skew the timing.
    nt_c = nt_a.bmm(nt_b)

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        nt_c = nt_a.bmm(nt_b)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; convert to seconds per iteration.
    runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
    return runtime


def sweep_n(ntensor, niter, dtype):
    print("n, dtype, ntensor, tflop, runtime, tflop/s")
    for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
        nt_a = torch.nested_tensor(
            [torch.randn(n, n).to(dtype).cuda() for _ in range(ntensor)]
        )
        nt_b = torch.nested_tensor(
            [torch.randn(n, n).to(dtype).cuda() for _ in range(ntensor)]
        )
        runtime = bench(nt_a, nt_b, niter)
        # Each n x n matmul costs 2 * n^3 FLOPs; one bmm call does ntensor of them.
        tflop = n * n * n * ntensor * 2 / 1e12
        print(n, dtype, ntensor, tflop, runtime, tflop / runtime)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument("--niter", default=10, type=int)
    parser.add_argument("--ntensor", default=20, type=int)

    args = parser.parse_args()
    niter = args.niter
    ntensor = args.ntensor

    sweep_n(ntensor, niter, torch.float32)
    sweep_n(ntensor, niter, torch.float16)
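As a quick sanity check of what the benchmark is timing, the sketch below (not part of the file above) multiplies a small nested tensor and compares each constituent result against a plain dense matmul. It assumes a CUDA-capable GPU and a PyTorch build where bmm is implemented for nested tensors; the constructor name differs across releases (the script uses the prototype torch.nested_tensor, newer releases expose torch.nested.nested_tensor), and the relaxed tolerances allow for the nested and dense paths dispatching to different GEMM kernels.

import torch

ntensor, n = 4, 64
a_list = [torch.randn(n, n, device="cuda") for _ in range(ntensor)]
b_list = [torch.randn(n, n, device="cuda") for _ in range(ntensor)]

# Same constructor as the benchmark above; swap in torch.nested.nested_tensor
# on newer PyTorch releases.
nt_a = torch.nested_tensor(a_list)
nt_b = torch.nested_tensor(b_list)

# bmm on nested tensors multiplies the constituents pairwise, so unbinding
# the result should recover one n x n product per input pair.
nt_c = nt_a.bmm(nt_b)
for c, a, b in zip(nt_c.unbind(), a_list, b_list):
    torch.testing.assert_close(c, a @ b, rtol=1e-3, atol=1e-3)
print("nested bmm matches dense matmul on all", ntensor, "pairs")

The benchmark itself is driven from the command line through the --niter and --ntensor flags defined above, and sweeps the matrix sizes once in float32 and once in float16.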