Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: In AOTInductor-generated CPU model code, there can be direct references to some aten/c10 utility functions and data structures, e.g. at::vec and c10::Half. These are performance-critical, so it doesn't make sense to create C shims for them. Instead, we make sure they are implemented in a header-only way, and use this set of tests to guard future changes. There are more header files to be updated, but we will do that in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123848
Approved by: https://github.com/jansel
ghstack dependencies: #123847
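A minimal sketch of the scenario described above, assuming a PyTorch build with AOTInductor: AOT-compiling a CPU model on half-precision inputs makes the generated C++ refer to c10::Half (and, for vectorized kernels, at::vec) directly rather than through a C shim. The module and shapes here are illustrative, not part of the test file below:

    import torch
    from torch._export import aot_compile

    class HalfNet(torch.nn.Module):
        def forward(self, x):
            # fp16 CPU kernels in the generated code refer to c10::Half.
            return torch.relu(x) + 1.0

    with torch.no_grad():
        so_path = aot_compile(HalfNet(), (torch.randn(8, 8, dtype=torch.half),))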
112 lines
3.3 KiB
Python
import torch

from torch._export import aot_compile
from torch.export import Dim

torch.manual_seed(1337)


class Net(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.w_pre = torch.randn(4, 4, device=device)
        self.w_add = torch.randn(4, 4, device=device)

    def forward(self, x):
        w_transpose = torch.transpose(self.w_pre, 0, 1)
        w_relu = torch.nn.functional.relu(w_transpose)
        w = w_relu + self.w_add
        return torch.matmul(x, w)


class NetWithTensorConstants(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.randn(30, 1, device="cuda")

    def forward(self, x, y):
        z = self.w * x * y
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]


data = {}
data_with_tensor_constants = {}


# Basic AOTI model test generation.
def generate_basic_tests():
    for device in ["cpu", "cuda"]:
        for use_runtime_constant_folding in [True, False]:
            if device == "cpu" and use_runtime_constant_folding:
                # We do not test runtime const folding for cpu mode.
                continue
            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            torch._dynamo.reset()
            with torch.no_grad():
                dim0_x = Dim("dim0_x", min=1, max=1024)
                dynamic_shapes = {"x": {0: dim0_x}}
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes=dynamic_shapes,
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = f"{device}"
            if use_runtime_constant_folding:
                suffix += "_use_runtime_constant_folding"
            data.update(
                {
                    f"model_so_path_{suffix}": model_so_path,
                    f"inputs_{suffix}": [x],
                    f"outputs_{suffix}": [ref_output],
                    f"w_pre_{suffix}": model.w_pre,
                    f"w_add_{suffix}": model.w_add,
                }
            )
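

# A possible sanity check from the Python side (illustrative sketch, not used
# by the C++ tests): load a compiled model back and compare against the
# recorded reference output. torch._export.aot_load is assumed to be available
# in the same build that provides aot_compile.
def check_basic_test(suffix):
    runner = torch._export.aot_load(
        data[f"model_so_path_{suffix}"], device=suffix.split("_")[0]
    )
    actual = runner(*data[f"inputs_{suffix}"])
    return torch.allclose(actual, data[f"outputs_{suffix}"][0])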


# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
    model = NetWithTensorConstants()
    x = torch.randn((30, 1), device="cuda")
    y = torch.randn((30, 1), device="cuda")
    with torch.no_grad():
        ref_output = model(x, y)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, (x, y))

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )


generate_basic_tests()
generate_test_with_additional_tensors()


# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    def __init__(self, data):
        super().__init__()
        for key in data:
            setattr(self, key, data[key])


torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)
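

# A possible load-back check (illustrative sketch, not invoked here), mirroring
# what the C++ side does via torch::jit::load(...).attr(...).toTensor():
def verify_serialized_data():
    loaded = torch.jit.load("data.pt")
    return torch.equal(loaded.w_pre_cuda, data["w_pre_cuda"])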