[AOTInductor] Add interface for user managed buffer in package api. (#151325)

Summary:
https://github.com/pytorch/pytorch/pull/151141
We add interface for user managed buffer in the package api.

Test Plan:
Included in commit.]

Reviewed By: henrylhtsang

Differential Revision: D72985440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151325
Approved by: https://github.com/angelayi
This commit is contained in:
Mu-Chu Lee 2025-04-16 04:25:37 +00:00 committed by PyTorch MergeBot
parent 82200e33b5
commit 107121dfad
5 changed files with 83 additions and 9 deletions

View File

@ -467,6 +467,63 @@ class TestAOTInductorPackage(TestCase):
output = compiled(test_inputs)
self.assertEqual(expected, output)
@skipif(
lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
"No support for cpp only and cpu",
)
def test_package_user_managed_weight(self):
class Model(torch.nn.Module):
def __init__(self, n, k, device):
super().__init__()
self.linear = torch.nn.Linear(k, n, device=device)
def forward(self, a):
return self.linear(a)
M, N, K = 128, 4096, 4096
model = Model(N, K, self.device)
example_inputs = (torch.randn(M, K, device=self.device),)
inductor_configs = {
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
}
compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
self.assertEqual(
set(compiled.get_constant_fqns()), set(model.state_dict().keys())
)
compiled.load_constants(
model.state_dict(), check_full_update=True, user_managed=False
)
test_inputs = torch.randn(M, K, device=self.device)
expected = model(test_inputs)
output = compiled(test_inputs)
self.assertEqual(expected, output)
# Let's try to modify the weight in-place, result shouldn't change.
model.linear.weight.data *= 3.7
new_output = compiled(test_inputs)
self.assertEqual(new_output, output)
# Recreate a new model that we will test against user_managed=True
new_compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
new_compiled.load_constants(
model.state_dict(), check_full_update=True, user_managed=True
)
expected = model(test_inputs)
new_output = new_compiled(test_inputs)
self.assertEqual(expected, new_output)
# Try to modify the weight in-place, result should change.
model.linear.weight.data *= 3.7
expected = model(test_inputs)
new_output = new_compiled(test_inputs)
self.assertEqual(new_output, expected)
def test_deepcopy_compiled_model(self):
class Model(torch.nn.Module):
def forward(self, x, y):

View File

@ -259,6 +259,7 @@ class AOTICompiledModel:
constants_map: dict[str, torch.Tensor],
*,
check_full_update: bool,
user_managed: bool = False,
) -> None:
"""
Given a mapping of constant fqns to tensors, load the constants into the model.
@ -270,7 +271,9 @@ class AOTICompiledModel:
check_full_update: Whether to add check to see if all the constants
are updated and have values.
"""
self.loader.load_constants(constants_map, False, check_full_update) # type: ignore[attr-defined]
self.loader.load_constants( # type: ignore[attr-defined]
constants_map, False, check_full_update, user_managed
)
def get_constant_fqns(self) -> list[str]:
return self.loader.get_constant_fqns() # type: ignore[attr-defined]

View File

@ -523,7 +523,8 @@ std::vector<std::string> AOTIModelPackageLoader::get_call_spec() {
void AOTIModelPackageLoader::load_constants(
std::unordered_map<std::string, at::Tensor>& constants_map,
bool use_inactive,
bool check_full_update) {
bool check_full_update,
bool user_managed) {
std::unordered_map<std::string, std::string> constant_name_to_fqn =
runner_->getConstantNamesToOriginalFQNs();
std::unordered_map<std::string, at::string> fqn_to_constant_name;
@ -541,7 +542,7 @@ void AOTIModelPackageLoader::load_constants(
}
return runner_->update_constant_buffer(
updated_constants_map, use_inactive, check_full_update);
updated_constants_map, use_inactive, check_full_update, user_managed);
}
std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
@ -558,9 +559,10 @@ std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
void AOTIModelPackageLoader::update_constant_buffer(
std::unordered_map<std::string, at::Tensor>& tensor_map,
bool use_inactive,
bool validate_full_updates) {
bool validate_full_updates,
bool user_managed) {
runner_->update_constant_buffer(
tensor_map, use_inactive, validate_full_updates);
tensor_map, use_inactive, validate_full_updates, user_managed);
}
} // namespace torch::inductor
#endif

View File

@ -30,13 +30,15 @@ class TORCH_API AOTIModelPackageLoader {
void load_constants(
std::unordered_map<std::string, at::Tensor>& constants_map,
bool use_inactive,
bool check_full_update);
bool check_full_update,
bool user_managed = false);
std::vector<std::string> get_constant_fqns();
void update_constant_buffer(
std::unordered_map<std::string, at::Tensor>& tensor_map,
bool use_inactive,
bool validate_full_updates);
bool validate_full_updates,
bool user_managed = false);
private:
std::string temp_dir_;

View File

@ -69,9 +69,19 @@ void initAOTIPackageBindings(PyObject* module) {
.def("get_call_spec", &AOTIModelPackageLoaderPybind::get_call_spec)
.def(
"get_constant_fqns", &AOTIModelPackageLoaderPybind::get_constant_fqns)
.def("load_constants", &AOTIModelPackageLoaderPybind::load_constants)
.def(
"load_constants",
&AOTIModelPackageLoaderPybind::load_constants,
py::arg("constants_map"),
py::arg("use_inactive"),
py::arg("check_full_update"),
py::arg("user_managed") = false)
.def(
"update_constant_buffer",
&AOTIModelPackageLoaderPybind::update_constant_buffer);
&AOTIModelPackageLoaderPybind::update_constant_buffer,
py::arg("tensor_map"),
py::arg("use_inactive"),
py::arg("validate_full_updates"),
py::arg("user_managed") = false);
}
} // namespace torch::inductor