Mirror of https://github.com/zebrajr/pytorch.git
Synced 2025-12-06 12:20:52 +01:00
Summary: Add a test for enabling the CUDACachingAllocator for AOTInductor's weights. Implementation TBD.

Test Plan: N/A, commit is adding a test.

Rollback Plan:

Differential Revision: D79107507

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159279
Approved by: https://github.com/desertfire, https://github.com/jingsh
220 lines
6.9 KiB
Python
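# This script generates AOTInductor artifacts (compiled .so / .pt2 packages)
# together with reference inputs, outputs, and weights; everything is
# serialized at the bottom of the file so the C++ test side can consume it.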
import torch
import torch._inductor.config
from torch._export import aot_compile
from torch.export import Dim


torch.manual_seed(1337)


class Net(torch.nn.Module):
    def __init__(self, device, size=4):
        super().__init__()
        self.w_pre = torch.randn(size, size, device=device)
        self.w_add = torch.randn(size, size, device=device)

    def forward(self, x):
        w_transpose = torch.transpose(self.w_pre, 0, 1)
        w_relu = torch.nn.functional.relu(w_transpose)
        w = w_relu + self.w_add
        return torch.matmul(x, w)
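

# Model whose forward uses advanced indexing; when exported, the index list
# below is presumably materialized as an additional tensor constant in the
# graph, which is what the data_with_tensor_constants bucket exercises.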
class NetWithTensorConstants(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = torch.randn(30, 1, device="cuda")

    def forward(self, x, y):
        z = self.w * x * y
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]
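

# Buckets for the generated artifacts; each dict is serialized to its own
# .pt file at the bottom of this script.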
data = {}
large_data = {}
cuda_alloc_data = {}
data_with_tensor_constants = {}


# Basic AOTI model test generation.
def generate_basic_tests():
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
        for use_runtime_constant_folding in [True, False]:
            if device == "cpu" and use_runtime_constant_folding:
                # We do not test runtime const folding for cpu mode.
                continue
            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            torch._dynamo.reset()
            with torch.no_grad():
                dim0_x = Dim("dim0_x", min=1, max=1024)
                dynamic_shapes = {"x": {0: dim0_x}}
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes=dynamic_shapes,
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )
                # Also store a .pt2 file using the aoti_compile_and_package API
                pt2_package_path = torch._inductor.aoti_compile_and_package(
                    torch.export.export(
                        model,
                        (x,),
                        dynamic_shapes=dynamic_shapes,
                    ),
                    inductor_configs={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = f"{device}"
            if use_runtime_constant_folding:
                suffix += "_use_runtime_constant_folding"
            data.update(
                {
                    f"model_so_path_{suffix}": model_so_path,
                    f"pt2_package_path_{suffix}": pt2_package_path,
                    f"inputs_{suffix}": [x],
                    f"outputs_{suffix}": [ref_output],
                    f"w_pre_{suffix}": model.w_pre,
                    f"w_add_{suffix}": model.w_add,
                }
            )
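

# Re-run the basic generation with the consts-asm build disabled, so the
# constants are exercised through the C++ build path instead.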
def generate_basic_tests_consts_cpp():
    backup_consts_asm_cfg: bool = (
        torch._inductor.config.aot_inductor.use_consts_asm_build
    )
    torch._inductor.config.aot_inductor.use_consts_asm_build = False

    # Test consts cpp build again.
    generate_basic_tests()

    torch._inductor.config.aot_inductor.use_consts_asm_build = backup_consts_asm_cfg


def generate_large_tests():
    device = "cuda"
    model = Net(device, size=4096).to(device=device)
    x = torch.randn((4096, 4096), device=device)
    with torch.no_grad():
        ref_output = model(x)

    torch._dynamo.reset()
    for use_runtime_constant_folding in [True, False]:
        with torch.no_grad():
            model_so_path = aot_compile(
                model,
                (x,),
                options={
                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                },
            )
            # Also store a .pt2 file using the aoti_compile_and_package API
            pt2_package_path = torch._inductor.aoti_compile_and_package(
                torch.export.export(
                    model,
                    (x,),
                ),
                inductor_configs={
                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                },
            )

        suffix = "_use_runtime_constant_folding" if use_runtime_constant_folding else ""
        large_data.update(
            {  # noqa: F541
                f"model_so_path{suffix}": model_so_path,
                f"pt2_package_path{suffix}": pt2_package_path,
                "inputs": [x],
                "outputs": [ref_output],
                "w_pre": model.w_pre,
                "w_add": model.w_add,
            }
        )
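

# Per the commit summary, this exercises the (not yet implemented)
# aot_inductor.weight_use_caching_allocator option, intended to place
# AOTInductor's weights in memory obtained from the CUDACachingAllocator.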
def generate_cuda_alloc_test():
    device = "cuda"
    model = Net(device, size=4096).to(device=device)
    x = torch.randn((4096, 4096), device=device)
    with torch.no_grad():
        ref_output = model(x)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(
            model,
            (x,),
            options={"aot_inductor.weight_use_caching_allocator": True},
        )

    cuda_alloc_data.update(
        {  # noqa: F541
            "model_so_path": model_so_path,
            "inputs": [x],
            "outputs": [ref_output],
            "w_pre": model.w_pre,
            "w_add": model.w_add,
        }
    )


# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    x = torch.randn((30, 1), device="cuda")
    y = torch.randn((30, 1), device="cuda")
    with torch.no_grad():
        ref_output = model(x, y)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, (x, y))
        # Also store a .pt2 file using the aoti_compile_and_package API
        pt2_package_path = torch._inductor.aoti_compile_and_package(
            torch.export.export(model, (x, y))
        )

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "pt2_package_path": pt2_package_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )
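

# Note: generate_large_tests and generate_cuda_alloc_test assume a CUDA
# device is available; only generate_basic_tests and
# generate_test_with_additional_tensors guard on torch.cuda.is_available().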
generate_basic_tests()
generate_basic_tests_consts_cpp()
generate_large_tests()
generate_test_with_additional_tensors()
generate_cuda_alloc_test()


# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    def __init__(self, data):
        super().__init__()
        for key in data:
            setattr(self, key, data[key])


torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(large_data)).save("large_data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)
torch.jit.script(Serializer(cuda_alloc_data)).save("cuda_alloc_data.pt")
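
# A consumer can recover the tensors by loading the scripted module and
# reading its attributes back. A minimal Python sketch (the attribute name
# is illustrative; actual keys depend on the suffixes built above):
#
#   loaded = torch.jit.load("data.pt")
#   inputs = loaded.inputs_cpu  # hypothetical key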