mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[GR v0] AOTI Enablement - Fix GR model AOTI inplace update by skipping empty named (#165970) (#166037)
Summary: Add a gflag to allow us skip empty constant named parameter during dense loading. In [vm_parameters.py](https://fburl.com/code/7xr9ihwy), there is a constant _empty_tensor parameter used for the model. This constant parameter is skipped in XL weights during model publish because it is empty. This will break model inplace update later because it will be reported by the AOTI container but cannot be found from the model merge weights. This diff will allow us to solve the problem. Test Plan: Verified inplace update in job https://www.internalfb.com/vanguard/serving_test_cases/1165842932095688 Reviewed By: muchulee8, joannec3634 Differential Revision: D85082330 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166037 Approved by: https://github.com/muchulee8, https://github.com/jcwchen
This commit is contained in:
parent
add37bacda
commit
17bdb232e1
|
|
@ -348,8 +348,19 @@ class AOTInductorModelContainer {
|
|||
return constant_type == ConstantType::Buffer;
|
||||
}
|
||||
|
||||
bool _is_tensor_constant_or_buffer_type(const size_t idx) const {
|
||||
return _is_tensor_constant_type(idx) || _is_buffer_type(idx);
|
||||
bool _is_empty_parameter_type(const size_t idx) const {
|
||||
auto constant_type = models_[0]->constant_type(static_cast<int64_t>(idx));
|
||||
auto constant_data_size =
|
||||
models_[0]->constant_data_size(static_cast<int64_t>(idx));
|
||||
// Empty parameters are skipped and not provided by the upstream services,
|
||||
// it is OK to skip.
|
||||
return constant_type == ConstantType::Parameter && constant_data_size == 0;
|
||||
}
|
||||
|
||||
bool _is_tensor_constant_or_buffer_type_or_empty_parameter(
|
||||
const size_t idx) const {
|
||||
return _is_tensor_constant_type(idx) || _is_buffer_type(idx) ||
|
||||
_is_empty_parameter_type(idx);
|
||||
}
|
||||
|
||||
void assert_all_constants(
|
||||
|
|
@ -364,11 +375,11 @@ class AOTInductorModelContainer {
|
|||
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
|
||||
auto it = constants_map.find(constant_name);
|
||||
if (it == constants_map.end()) {
|
||||
if (_is_tensor_constant_or_buffer_type(idx)) {
|
||||
if (_is_tensor_constant_or_buffer_type_or_empty_parameter(idx)) {
|
||||
// tracing sometimes creates tensors that are non-existent in
|
||||
// original graph. We could skip those and do a direct copy.
|
||||
std::cerr << "[WARNING] Found constant or module state buffer "
|
||||
<< constant_name
|
||||
std::cerr << "[WARNING] Found constant or module state buffer or "
|
||||
<< "empty module state parameter " << constant_name
|
||||
<< " in model, but not provided by user!\n";
|
||||
continue;
|
||||
}
|
||||
|
|
@ -453,7 +464,8 @@ class AOTInductorModelContainer {
|
|||
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
|
||||
auto it = constants_map.find(constant_name);
|
||||
if (it == constants_map.end() &&
|
||||
!(use_inactive && _is_tensor_constant_or_buffer_type(idx))) {
|
||||
!(use_inactive &&
|
||||
_is_tensor_constant_or_buffer_type_or_empty_parameter(idx))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user