[AOTInductor] Add interface for user managed buffer in package api. (#151325)
Summary: https://github.com/pytorch/pytorch/pull/151141

We add an interface for user-managed buffers to the package API.

Test Plan: Included in commit.

Reviewed By: henrylhtsang

Differential Revision: D72985440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151325
Approved by: https://github.com/angelayi
This commit is contained in:
parent 82200e33b5
commit 107121dfad
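For readers skimming the diff, the sketch below shows how the new user_managed flag is meant to be exercised through the package API. It is not part of the commit: the aoti_compile_and_package / aoti_load_package entry points and the "cuda" device are assumptions, while the inductor_configs keys and the load_constants arguments mirror the test added in this PR.

# Minimal usage sketch (assumed entry points, not from this commit).
import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package


class Model(torch.nn.Module):
    def __init__(self, n, k, device):
        super().__init__()
        self.linear = torch.nn.Linear(k, n, device=device)

    def forward(self, a):
        return self.linear(a)


M, N, K = 128, 4096, 4096
device = "cuda"  # assumption: the new test skips cpu / cpp-only packaging
model = Model(N, K, device)
example_inputs = (torch.randn(M, K, device=device),)

# Package without baking the constants into the shared library, so the
# weights have to be supplied at load time through load_constants().
pkg_path = aoti_compile_and_package(
    torch.export.export(model, example_inputs),
    inductor_configs={
        "always_keep_tensor_constants": True,
        "aot_inductor.package_constants_in_so": False,
    },
)
compiled = aoti_load_package(pkg_path)

# user_managed=True keeps the runtime pointing at the caller-owned tensors
# instead of copying them, so in-place weight updates on `model` remain
# visible to the compiled artifact without reloading constants.
compiled.load_constants(
    model.state_dict(), check_full_update=True, user_managed=True
)

x = torch.randn(M, K, device=device)
print(torch.allclose(compiled(x), model(x)))

With user_managed=False the loader copies the provided tensors, which is why the in-place weight update in the first half of the new test does not change the compiled model's output, while the user_managed=True half does track it.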
@@ -467,6 +467,63 @@ class TestAOTInductorPackage(TestCase):
         output = compiled(test_inputs)
         self.assertEqual(expected, output)
 
+    @skipif(
+        lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
+        "No support for cpp only and cpu",
+    )
+    def test_package_user_managed_weight(self):
+        class Model(torch.nn.Module):
+            def __init__(self, n, k, device):
+                super().__init__()
+                self.linear = torch.nn.Linear(k, n, device=device)
+
+            def forward(self, a):
+                return self.linear(a)
+
+        M, N, K = 128, 4096, 4096
+        model = Model(N, K, self.device)
+        example_inputs = (torch.randn(M, K, device=self.device),)
+
+        inductor_configs = {
+            "always_keep_tensor_constants": True,
+            "aot_inductor.package_constants_in_so": False,
+        }
+        compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
+
+        self.assertEqual(
+            set(compiled.get_constant_fqns()), set(model.state_dict().keys())
+        )
+
+        compiled.load_constants(
+            model.state_dict(), check_full_update=True, user_managed=False
+        )
+
+        test_inputs = torch.randn(M, K, device=self.device)
+        expected = model(test_inputs)
+        output = compiled(test_inputs)
+        self.assertEqual(expected, output)
+
+        # Let's try to modify the weight in-place, result shouldn't change.
+        model.linear.weight.data *= 3.7
+        new_output = compiled(test_inputs)
+        self.assertEqual(new_output, output)
+
+        # Recreate a new model that we will test against user_managed=True
+        new_compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
+        new_compiled.load_constants(
+            model.state_dict(), check_full_update=True, user_managed=True
+        )
+
+        expected = model(test_inputs)
+        new_output = new_compiled(test_inputs)
+        self.assertEqual(expected, new_output)
+
+        # Try to modify the weight in-place, result should change.
+        model.linear.weight.data *= 3.7
+        expected = model(test_inputs)
+        new_output = new_compiled(test_inputs)
+        self.assertEqual(new_output, expected)
+
     def test_deepcopy_compiled_model(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
@@ -259,6 +259,7 @@ class AOTICompiledModel:
         constants_map: dict[str, torch.Tensor],
         *,
         check_full_update: bool,
+        user_managed: bool = False,
     ) -> None:
         """
         Given a mapping of constant fqns to tensors, load the constants into the model.
@@ -270,7 +271,9 @@ class AOTICompiledModel:
             check_full_update: Whether to add check to see if all the constants
                 are updated and have values.
         """
-        self.loader.load_constants(constants_map, False, check_full_update)  # type: ignore[attr-defined]
+        self.loader.load_constants(  # type: ignore[attr-defined]
+            constants_map, False, check_full_update, user_managed
+        )
 
     def get_constant_fqns(self) -> list[str]:
         return self.loader.get_constant_fqns()  # type: ignore[attr-defined]
@@ -523,7 +523,8 @@ std::vector<std::string> AOTIModelPackageLoader::get_call_spec() {
 void AOTIModelPackageLoader::load_constants(
     std::unordered_map<std::string, at::Tensor>& constants_map,
     bool use_inactive,
-    bool check_full_update) {
+    bool check_full_update,
+    bool user_managed) {
   std::unordered_map<std::string, std::string> constant_name_to_fqn =
       runner_->getConstantNamesToOriginalFQNs();
   std::unordered_map<std::string, std::string> fqn_to_constant_name;
@@ -541,7 +542,7 @@ void AOTIModelPackageLoader::load_constants(
   }
 
   return runner_->update_constant_buffer(
-      updated_constants_map, use_inactive, check_full_update);
+      updated_constants_map, use_inactive, check_full_update, user_managed);
 }
 
 std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
@@ -558,9 +559,10 @@ std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
 void AOTIModelPackageLoader::update_constant_buffer(
     std::unordered_map<std::string, at::Tensor>& tensor_map,
     bool use_inactive,
-    bool validate_full_updates) {
+    bool validate_full_updates,
+    bool user_managed) {
   runner_->update_constant_buffer(
-      tensor_map, use_inactive, validate_full_updates);
+      tensor_map, use_inactive, validate_full_updates, user_managed);
 }
 } // namespace torch::inductor
 #endif
@@ -30,13 +30,15 @@ class TORCH_API AOTIModelPackageLoader {
   void load_constants(
       std::unordered_map<std::string, at::Tensor>& constants_map,
       bool use_inactive,
-      bool check_full_update);
+      bool check_full_update,
+      bool user_managed = false);
   std::vector<std::string> get_constant_fqns();
 
   void update_constant_buffer(
       std::unordered_map<std::string, at::Tensor>& tensor_map,
       bool use_inactive,
-      bool validate_full_updates);
+      bool validate_full_updates,
+      bool user_managed = false);
 
  private:
   std::string temp_dir_;
@@ -69,9 +69,19 @@ void initAOTIPackageBindings(PyObject* module) {
       .def("get_call_spec", &AOTIModelPackageLoaderPybind::get_call_spec)
       .def(
           "get_constant_fqns", &AOTIModelPackageLoaderPybind::get_constant_fqns)
-      .def("load_constants", &AOTIModelPackageLoaderPybind::load_constants)
+      .def(
+          "load_constants",
+          &AOTIModelPackageLoaderPybind::load_constants,
+          py::arg("constants_map"),
+          py::arg("use_inactive"),
+          py::arg("check_full_update"),
+          py::arg("user_managed") = false)
       .def(
           "update_constant_buffer",
-          &AOTIModelPackageLoaderPybind::update_constant_buffer);
+          &AOTIModelPackageLoaderPybind::update_constant_buffer,
+          py::arg("tensor_map"),
+          py::arg("use_inactive"),
+          py::arg("validate_full_updates"),
+          py::arg("user_managed") = false);
 }
 } // namespace torch::inductor