mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nnc] Added micro-benchmark to show perf improvement with cat subgraph optimization (#59581)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59581

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D28955317

Pulled By: navahgar

fbshipit-source-id: 53bb3dbfafbd3b146063f305523c2e6ec96cf6b8
This commit is contained in:
parent
d0c4ace00f
commit
47bbc01e0b
|
|
@ -1,5 +1,6 @@
|
|||
from . import benchmark
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
class Concat2D2InputBench(benchmark.Benchmark):
|
||||
def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
|
||||
|
|
@ -50,7 +51,67 @@ class Concat2D2InputBench(benchmark.Benchmark):
|
|||
[1, 580, 1, 174, 1],
|
||||
[20, 160, 20, 14, 1],
|
||||
[20, 580, 20, 174, 1],
|
||||
[8, 512, 8, 512, 1]
|
||||
[8, 512, 8, 512, 1],
|
||||
[1 << 13, 1060, 1 << 13, 1040, 1],
|
||||
[1 << 13, 2000, 1 << 13, 1074, 1],
|
||||
[1 << 15, 1060, 1 << 15, 2670, 1],
|
||||
[1 << 15, 5120, 1 << 15, 2512, 1]
|
||||
]
|
||||
|
||||
# Pass Concat2D2InputBench to benchmark.register_benchmark_class. What
# registration does is defined in the benchmark module, not visible here.
benchmark.register_benchmark_class(Concat2D2InputBench)
|
||||
|
||||
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark an ``add -> cat -> relu`` chain over two 2-D inputs.

    ``forward`` adds a small constant to each input, concatenates the two
    results along ``concat_dim`` and applies ``relu`` to the concatenation.
    Constructing an instance also switches on two process-wide TorchScript
    flags (see ``__init__``).
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the shape configuration and create the two input tensors.

        Args:
            mode: Forwarded to the base class; ``memory_workload`` compares
                it against ``"fwd"``.
            device: Forwarded to the base class and to ``self.randn``.
            dtype: Forwarded to the base class and to ``self.randn``.
            I1_D1, I1_D2: Sizes of the two dimensions of the first input.
            I2_D1, I2_D2: Sizes of the two dimensions of the second input.
            concat_dim: Dimension along which ``forward`` concatenates.
        """
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        # ``self.randn`` and ``self.requires_grad`` come from the base class
        # (not visible in this file section).
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # Global side effects: these calls change process-wide JIT settings,
        # not just this instance. Presumably they enable CPU fusion and the
        # conditional-free lowering of ``cat`` - confirm against torch._C.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        """Return ``relu(cat((input1 + 1e-5, input2 + 1e-5), dim=concat_dim))``."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        """Return the NumPy concatenation of the two stored inputs.

        NOTE(review): only the ``cat`` step is mirrored here; the ``add`` and
        ``relu`` steps of ``forward`` are not applied. Confirm whether that
        is intended before relying on this as a correctness reference.
        """
        # Read the axis from the instance: a bare ``concat_dim`` is not
        # defined in this scope and raised NameError on every call.
        return np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )

    def config(self):
        """Return ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]`` for this instance."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name string of this benchmark, ``"concatGraphOpt"``."""
        return "concatGraphOpt"

    def memory_workload(self):
        """Return element-count workload estimates keyed ``"sol"`` / ``"algorithmic"``.

        Each value is the combined element count of both inputs multiplied by
        a per-mode factor; for any mode other than ``"fwd"`` the factors are
        doubled.
        """
        # Presumably the factors count buffer-sized memory passes
        # (reads + writes) - TODO confirm against the base class contract.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations, in the same field order as ``config``."""
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
|
||||
|
||||
# Pass ConcatGraphOptBench to benchmark.register_benchmark_class. What
# registration does is defined in the benchmark module, not visible here.
benchmark.register_benchmark_class(ConcatGraphOptBench)
|
||||
|
|
|
|||
|
|
@ -50,6 +50,15 @@ class TorchTensorEngine(object):
|
|||
def cat(self, inputs, dim=0):
    """Concatenate the tensors in ``inputs`` along ``dim`` via ``torch.cat``."""
    joined = torch.cat(inputs, dim=dim)
    return joined
|
||||
|
||||
def clamp(self, data, min, max):
    """Limit ``data`` to the range ``[min, max]`` via ``torch.clamp``."""
    clamped = torch.clamp(data, min=min, max=max)
    return clamped
|
||||
|
||||
def relu(self, data):
    """Apply ``torch.nn.functional.relu`` to ``data`` and return the result."""
    activation = torch.nn.functional.relu
    return activation(data)
|
||||
|
||||
def tanh(self, data):
    """Apply ``torch.tanh`` to ``data`` and return the result."""
    squashed = torch.tanh(data)
    return squashed
|
||||
|
||||
def max_pool2d(self, data, kernel_size, stride=1):
    """Run ``torch.nn.functional.max_pool2d`` on ``data``.

    ``stride`` defaults to 1 here and is always passed explicitly to the
    underlying torch call.
    """
    pool = torch.nn.functional.max_pool2d
    return pool(data, kernel_size, stride=stride)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user