mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTInductor] Add interface for user managed buffer in package api. (#151325)
Summary: https://github.com/pytorch/pytorch/pull/151141 We add interface for user managed buffer in the package api. Test Plan: Included in commit.] Reviewed By: henrylhtsang Differential Revision: D72985440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151325 Approved by: https://github.com/angelayi
This commit is contained in:
parent
82200e33b5
commit
107121dfad
|
|
@ -467,6 +467,63 @@ class TestAOTInductorPackage(TestCase):
|
|||
output = compiled(test_inputs)
|
||||
self.assertEqual(expected, output)
|
||||
|
||||
@skipif(
|
||||
lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
|
||||
"No support for cpp only and cpu",
|
||||
)
|
||||
def test_package_user_managed_weight(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, n, k, device):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(k, n, device=device)
|
||||
|
||||
def forward(self, a):
|
||||
return self.linear(a)
|
||||
|
||||
M, N, K = 128, 4096, 4096
|
||||
model = Model(N, K, self.device)
|
||||
example_inputs = (torch.randn(M, K, device=self.device),)
|
||||
|
||||
inductor_configs = {
|
||||
"always_keep_tensor_constants": True,
|
||||
"aot_inductor.package_constants_in_so": False,
|
||||
}
|
||||
compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
|
||||
|
||||
self.assertEqual(
|
||||
set(compiled.get_constant_fqns()), set(model.state_dict().keys())
|
||||
)
|
||||
|
||||
compiled.load_constants(
|
||||
model.state_dict(), check_full_update=True, user_managed=False
|
||||
)
|
||||
|
||||
test_inputs = torch.randn(M, K, device=self.device)
|
||||
expected = model(test_inputs)
|
||||
output = compiled(test_inputs)
|
||||
self.assertEqual(expected, output)
|
||||
|
||||
# Let's try to modify the weight in-place, result shouldn't change.
|
||||
model.linear.weight.data *= 3.7
|
||||
new_output = compiled(test_inputs)
|
||||
self.assertEqual(new_output, output)
|
||||
|
||||
# Recreate a new model that we will test against user_managed=True
|
||||
new_compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
|
||||
new_compiled.load_constants(
|
||||
model.state_dict(), check_full_update=True, user_managed=True
|
||||
)
|
||||
|
||||
expected = model(test_inputs)
|
||||
new_output = new_compiled(test_inputs)
|
||||
self.assertEqual(expected, new_output)
|
||||
|
||||
# Try to modify the weight in-place, result should change.
|
||||
model.linear.weight.data *= 3.7
|
||||
expected = model(test_inputs)
|
||||
new_output = new_compiled(test_inputs)
|
||||
self.assertEqual(new_output, expected)
|
||||
|
||||
def test_deepcopy_compiled_model(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
|
|||
|
|
@ -259,6 +259,7 @@ class AOTICompiledModel:
|
|||
constants_map: dict[str, torch.Tensor],
|
||||
*,
|
||||
check_full_update: bool,
|
||||
user_managed: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Given a mapping of constant fqns to tensors, load the constants into the model.
|
||||
|
|
@ -270,7 +271,9 @@ class AOTICompiledModel:
|
|||
check_full_update: Whether to add check to see if all the constants
|
||||
are updated and have values.
|
||||
"""
|
||||
self.loader.load_constants(constants_map, False, check_full_update) # type: ignore[attr-defined]
|
||||
self.loader.load_constants( # type: ignore[attr-defined]
|
||||
constants_map, False, check_full_update, user_managed
|
||||
)
|
||||
|
||||
def get_constant_fqns(self) -> list[str]:
|
||||
return self.loader.get_constant_fqns() # type: ignore[attr-defined]
|
||||
|
|
|
|||
|
|
@ -523,7 +523,8 @@ std::vector<std::string> AOTIModelPackageLoader::get_call_spec() {
|
|||
void AOTIModelPackageLoader::load_constants(
|
||||
std::unordered_map<std::string, at::Tensor>& constants_map,
|
||||
bool use_inactive,
|
||||
bool check_full_update) {
|
||||
bool check_full_update,
|
||||
bool user_managed) {
|
||||
std::unordered_map<std::string, std::string> constant_name_to_fqn =
|
||||
runner_->getConstantNamesToOriginalFQNs();
|
||||
std::unordered_map<std::string, at::string> fqn_to_constant_name;
|
||||
|
|
@ -541,7 +542,7 @@ void AOTIModelPackageLoader::load_constants(
|
|||
}
|
||||
|
||||
return runner_->update_constant_buffer(
|
||||
updated_constants_map, use_inactive, check_full_update);
|
||||
updated_constants_map, use_inactive, check_full_update, user_managed);
|
||||
}
|
||||
|
||||
std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
|
||||
|
|
@ -558,9 +559,10 @@ std::vector<std::string> AOTIModelPackageLoader::get_constant_fqns() {
|
|||
void AOTIModelPackageLoader::update_constant_buffer(
|
||||
std::unordered_map<std::string, at::Tensor>& tensor_map,
|
||||
bool use_inactive,
|
||||
bool validate_full_updates) {
|
||||
bool validate_full_updates,
|
||||
bool user_managed) {
|
||||
runner_->update_constant_buffer(
|
||||
tensor_map, use_inactive, validate_full_updates);
|
||||
tensor_map, use_inactive, validate_full_updates, user_managed);
|
||||
}
|
||||
} // namespace torch::inductor
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -30,13 +30,15 @@ class TORCH_API AOTIModelPackageLoader {
|
|||
void load_constants(
|
||||
std::unordered_map<std::string, at::Tensor>& constants_map,
|
||||
bool use_inactive,
|
||||
bool check_full_update);
|
||||
bool check_full_update,
|
||||
bool user_managed = false);
|
||||
std::vector<std::string> get_constant_fqns();
|
||||
|
||||
void update_constant_buffer(
|
||||
std::unordered_map<std::string, at::Tensor>& tensor_map,
|
||||
bool use_inactive,
|
||||
bool validate_full_updates);
|
||||
bool validate_full_updates,
|
||||
bool user_managed = false);
|
||||
|
||||
private:
|
||||
std::string temp_dir_;
|
||||
|
|
|
|||
|
|
@ -69,9 +69,19 @@ void initAOTIPackageBindings(PyObject* module) {
|
|||
.def("get_call_spec", &AOTIModelPackageLoaderPybind::get_call_spec)
|
||||
.def(
|
||||
"get_constant_fqns", &AOTIModelPackageLoaderPybind::get_constant_fqns)
|
||||
.def("load_constants", &AOTIModelPackageLoaderPybind::load_constants)
|
||||
.def(
|
||||
"load_constants",
|
||||
&AOTIModelPackageLoaderPybind::load_constants,
|
||||
py::arg("constants_map"),
|
||||
py::arg("use_inactive"),
|
||||
py::arg("check_full_update"),
|
||||
py::arg("user_managed") = false)
|
||||
.def(
|
||||
"update_constant_buffer",
|
||||
&AOTIModelPackageLoaderPybind::update_constant_buffer);
|
||||
&AOTIModelPackageLoaderPybind::update_constant_buffer,
|
||||
py::arg("tensor_map"),
|
||||
py::arg("use_inactive"),
|
||||
py::arg("validate_full_updates"),
|
||||
py::arg("user_managed") = false);
|
||||
}
|
||||
} // namespace torch::inductor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user