pytorch/torch/csrc/jit/passes/freeze_module.cpp
cyy 70d7638b0d Fix clang-tidy suppression in torch/csrc/jit (#152271)
Remove some clang-tidy suppression in torch/csrc/jit by applying fixes or refactoring.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152271
Approved by: https://github.com/Skylion007, https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-27 21:18:39 +00:00

1048 lines
36 KiB
C++

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/jit_log.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/autocast.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/eliminate_no_ops.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
#include <stack>
#include <utility>
namespace torch::jit {
namespace {
// Splits a dot-qualified name such as "a.b.c" into its components.
// An empty component between two dots is kept, a single trailing dot does
// not produce a trailing empty component, and an empty input yields an
// empty vector.
std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> parts;
  std::string::size_type start = 0;
  while (start < name.size()) {
    const auto dot = name.find('.', start);
    if (dot == std::string::npos) {
      // Last component: everything up to the end of the string.
      parts.push_back(name.substr(start));
      break;
    }
    parts.push_back(name.substr(start, dot - start));
    start = dot + 1;
  }
  return parts;
}
// Joins the strings in [begin, end) with '.' separators.
// The separator is emitted only when the accumulated result is already
// non-empty, so leading empty components contribute nothing to the output.
template <typename Iter>
std::string concatName(const Iter& begin, const Iter& end) {
  std::string joined;
  Iter cursor = begin;
  while (cursor != end) {
    const std::string& piece = *cursor;
    if (!joined.empty()) {
      joined.push_back('.');
    }
    joined.append(piece);
    ++cursor;
  }
  return joined;
}
class AttributePropagator {
public:
// Builds a propagator operating on `module`.
//
// `preservedAttrs` lists (possibly dot-qualified) attribute or method names
// the caller wants to keep in the frozen module. Each name is resolved with
// resolveName(); an unresolvable name raises via TORCH_CHECK
// ("Unknown name: <name>").
// `freezeInterfaces` and `preserveParameters` are stored as-is and consulted
// by the later passes.
AttributePropagator(
    Module& module,
    std::vector<std::string>& preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters)
    : module_(module),
      freezeInterfaces_(freezeInterfaces),
      preserveParameters_(preserveParameters) {
  // Registers a single user-requested name; returns false when the name
  // cannot be resolved against the module hierarchy.
  auto checkName = [this](const std::string& name) {
    const auto resolved_name = resolveName(name);
    if (!resolved_name) {
      return false;
    }
    const auto& parent_module = resolved_name->first;
    const auto& attr_name = resolved_name->second;
    if (parent_module.hasattr(attr_name)) {
      auto value = parent_module.attr(attr_name);
      // Freezing client wants to preserve this submodule. When cleaning
      // the frozen module, make sure it will be preserved entirely.
      if (value.isModule()) {
        preservedSubModule_.insert(value.toModule()._ivalue());
      }
      insertMutableAttr(attr_name, value, parent_module._ivalue());
    } else {
      // resolveName() only succeeds for attributes or methods, so a name
      // that is not an attribute is looked up as a method.
      auto fn = parent_module.get_method(attr_name);
      preservedMethods_.insert(&fn.function());
    }
    return true;
  };
  // forward is preserved by default, but
  // not all modules have a forward function defined
  if (auto method = module_.find_method("forward")) {
    preservedMethods_.insert(&method->function());
  }
  // Iterate by const reference: the names are only read, so no per-name
  // string copy is needed.
  for (const auto& name : preservedAttrs) {
    TORCH_CHECK(checkName(name), "Unknown name: " + name);
  }
}
void optimizeSubGraphs(
std::shared_ptr<Graph>& graph,
const std::function<void(std::shared_ptr<Graph>&)>& func) {
func(graph);
std::stack<Block*> blocks({graph->block()});
while (!blocks.empty()) {
Block* block = blocks.top();
blocks.pop();
for (auto n : block->nodes()) {
for (Block* sub_block : n->blocks()) {
blocks.push(sub_block);
}
if (n->kind() == prim::fork) {
auto subgraph = n->g(attr::Subgraph);
optimizeSubGraphs(subgraph, func);
}
}
}
}
void run() {
auto applyInline = [](std::shared_ptr<Graph>& subgraph) {
Inline(*subgraph);
ClearProfilingInformation(subgraph);
};
auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
#ifndef C10_MOBILE
Autocast(subgraph);
#endif
runOptimization(
subgraph,
/* unroll_non_constant_loops? */ false,
/* const_prop_user_classes? */ false);
EliminateNoOps(subgraph);
LowerSimpleTuples(subgraph);
};
std::unordered_map<std::string, std::unordered_set<std::string>>
interfacesToReassignType;
for (auto function : preservedMethods_) {
GRAPH_DEBUG("Analyzing function: " + function->name());
auto graph = toGraphFunction(*function).graph();
optimizeSubGraphs(graph, applyInline);
if (freezeInterfaces_) {
inlineInterfaceCalls(graph, interfacesToReassignType);
}
}
reassignInterfaceTypes(interfacesToReassignType);
for (auto function : preservedMethods_) {
GRAPH_DEBUG("Recording mutable attrs for function: " + function->name());
auto graph = toGraphFunction(*function).graph();
// Record Attributes that are explicitly set in the module.
// They cannot be folded.
recordMutableAttrs(graph);
}
for (auto function : preservedMethods_) {
GRAPH_DEBUG("Propagating function: " + function->name());
auto graph = toGraphFunction(*function).graph();
propagateAttributes(graph);
optimizeSubGraphs(graph, applyOptimizations);
}
GRAPH_DEBUG("Cleaning up module");
cleanupFrozenModule();
}
private:
using ResolvedName = std::pair<Module, std::string>;
// Try to resolve qualified names (submodule1.submodule2.foo). If
// the qualified name exists in the root module, return the unqualified
// attribute/function name and the parent module. Else, return nullopt.
// Examples:
// submodule1.submodule2.foo -> {submodule2, "foo"}
// submodule1.non_existent_module.foo -> nullopt
std::optional<ResolvedName> resolveName(const std::string& name) {
auto sub_names = splitName(name);
if (sub_names.empty()) {
return std::nullopt;
}
auto& attr_name = sub_names.back();
auto cur_module = module_;
std::vector<ResolvedName> attr_infos;
attr_infos.reserve(sub_names.size() - 1);
for (size_t i = 0; i < sub_names.size() - 1; ++i) {
bool found = false;
const auto& sub_name = sub_names[i];
for (const auto& child_module : cur_module.named_children()) {
if (child_module.name == sub_name) {
attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
cur_module = child_module.value;
found = true;
break;
}
}
if (!found) {
return std::nullopt;
}
}
if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
// We don't want to mark these modules as mutable yet; that could
// interfere with the inlining procedure. Instead, we'll record
// the fact that the user wants to preserve them. They will be
// processed during clean-up preparation (recordReferenceAttrs)
for (auto& attr_info : attr_infos) {
const auto& parent_module = attr_info.first;
auto& sub_name = attr_info.second;
userPreservedAttrs_[parent_module._ivalue()].insert(
std::move(sub_name));
}
return std::make_pair(std::move(cur_module), std::move(attr_name));
}
return std::nullopt;
}
// Rebuilds names_ with the chain of prim::GetAttr names leading from the
// graph's first input to `input`. Walks producers backwards until a node
// whose first output has the same type as the graph's first input is
// reached. Returns false (leaving names_ partially filled) when a node on
// the way is not a prim::GetAttr.
bool _loadModulePath(Value* input, std::shared_ptr<Graph>& graph) {
  names_.clear();
  for (Node* node = input->node();
       !(node->outputs()[0]->type() == graph->inputs()[0]->type());
       node = node->inputs()[0]->node()) {
    if (node->kind() != prim::GetAttr) {
      return false;
    }
    names_.push_front(node->s(attr::name));
  }
  return true;
}
// Returns a copy of the attribute-name path computed by _loadModulePath()
// for `input`, or nullopt when that path could not be established.
// Note that names_ is overwritten as a side effect.
std::optional<std::deque<std::string>> getModulePath(
    Value* input,
    std::shared_ptr<Graph>& graph) {
  if (_loadModulePath(input, graph)) {
    return names_;
  }
  return std::nullopt;
}
// Descends from `attrModule` along the attribute names in [begin, end),
// updating `attrModule` in place after every hop. Stops and returns false
// as soon as a hop refers to a value recorded in preservedAttrs_; in that
// case `attrModule` is left at the last module reached.
template <typename Iter>
bool getModuleFromPath(
    Module& attrModule,
    const Iter& begin,
    const Iter& end) {
  Iter cursor = begin;
  while (cursor != end) {
    const std::string& childName = *cursor;
    auto child = attrModule.attr(childName);
    if (preservedAttrs_.count(child) != 0) {
      return false;
    }
    attrModule = child.toModule();
    ++cursor;
  }
  return true;
}
// findConstantAttr function locates the sub Module where attributes are
// defined. The algorithm chases getAttr chains to locate the submodules.
// For example:
// module M {
// attributes {
// A = <SubModule at ...>
// }
// ...
// %A = prim::GetAttr[name="A"](%self)
// ...
// %B = prim::GetAttr[name="B"](%A)
// ...
// %weight = prim::GetAttr[name="scale"](%B)
// ...
// submodules {
// module SubModule {
// attributes {
// B = <SubModule2 at ...>
// }
// submodules {
// module SubModule2 {
// attributes {
// scale = 2
// }
// }
// }
// }
// }
//
// findConstantAttr(%B, "scale", M) returns true because there are no
// explicit SetAttr that modifies %B. attrModule points to the module where
// attribute lives (in this example it is <SubModule2 at ...>).
//
// Note inplace mutations to attributes are checked later using alias
// analysis.
//
// We can use a more efficient algorithm to hash each constant GetAttr to its
// corresponding value. Based on initial test on resnet50 and other torch
// vision tests. GetAttrs are not too frequent so it is ok to chase GetAttr
// chain to retrieve their values.
// Returns true when attribute `name`, read through `input`, may be treated
// as a constant. On success `attrModule` has been advanced to the module
// that owns the attribute. names_ is overwritten with the path to `input`.
bool findConstantAttr(
    Value* input,
    std::string& name,
    Module& attrModule,
    std::shared_ptr<Graph>& graph) {
  // Only interface values and module-typed class values are considered.
  if (!input->type()->cast<InterfaceType>()) {
    if (!input->type()->expectRef<ClassType>().is_module()) {
      return false;
    }
  }
  // Load the path into names_, then move attrModule along that path.
  if (!_loadModulePath(input, graph) ||
      !getModuleFromPath(attrModule, names_.begin(), names_.end())) {
    return false;
  }
  auto attr = attrModule.attr(name);
  if (!AliasDb::isMutableType(attr.type())) {
    // Immutable values are tracked per (module, attribute name).
    auto scalars = preservedScalarAttrs_.find(attrModule._ivalue());
    if (scalars == preservedScalarAttrs_.end()) {
      return true;
    }
    return scalars->second.count(name) == 0;
  }
  if (preservedAttrs_.count(attr) != 0) {
    return false;
  }
  if (attr.type()->cast<ClassType>()) {
    return true;
  }
  // For non-class values, also reject anything overlapping a preserved
  // non-object value.
  for (const auto& preserved : preservedAttrs_) {
    if (!preserved.isObject() && preserved.overlaps(attr)) {
      return false;
    }
  }
  return true;
}
// Marks an attribute as "must not be folded". Values of a mutable type are
// remembered by IValue in preservedAttrs_; all other values are remembered
// by name under their owning module in preservedScalarAttrs_.
void insertMutableAttr(
    const std::string& name,
    const IValue& attr,
    const ModulePtr& attrModule) {
  if (!AliasDb::isMutableType(attr.type())) {
    preservedScalarAttrs_[attrModule].insert(name);
    return;
  }
  preservedAttrs_.insert(attr);
}
void recordMutableAttrs(std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks({graph->block()});
std::unique_ptr<AliasDb> aliasDb =
std::make_unique<AliasDb>(graph, /* isFrozen */ true);
while (!blocks.empty()) {
Block* block = blocks.top();
blocks.pop();
for (auto n : block->nodes()) {
for (Block* sub_block : n->blocks()) {
blocks.push(sub_block);
}
// Modules with prim::ModuleContainerIndex cannot be frozen because they
// return InterfaceTypes.
TORCH_CHECK(
n->kind() != prim::ModuleContainerIndex,
"Freezing modules containing prim::ModuleContainerIndex is not supported");
if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) {
// By default if interface attributes are present then fail freezing.
// If freezingInterfaces is on then Interfaces are folded similarly
// to other attributes.
TORCH_CHECK(
freezeInterfaces_ ||
!(n->kind() == prim::GetAttr &&
n->output()->type()->cast<InterfaceType>()),
"attempted to freeze a module that uses interface attributes");
auto name = n->s(attr::name);
auto attrModule = module_;
if (!findConstantAttr(n->inputs()[0], name, attrModule, graph)) {
continue;
}
auto attr = attrModule.attr(name);
if (n->kind() == prim::GetAttr) {
auto type = n->output()->type();
// Do not record submodules. Their attributes are tracked
// individually.
if (attr.isObject() || !AliasDb::isMutableType(attr.type())) {
continue;
}
usedAttrs_.insert(attr);
}
if (n->kind() == prim::SetAttr || aliasDb->hasOutputWriters(n)) {
GRAPH_DEBUG(
n->kind() == prim::GetAttr ? "attribute: " + name + " in %" +
n->output()->debugName() + " has inplace writer"
: "attribute: " + name + " is set");
auto mptr = attrModule._ivalue();
insertMutableAttr(name, attr, mptr);
}
} else if (n->kind() == prim::fork) {
applyToForkSubgraph(
n,
graph,
// NOLINTNEXTLINE(modernize-avoid-bind)
std::bind(
&AttributePropagator::recordMutableAttrs,
*this,
std::placeholders::_1));
}
}
}
// FIXME: Current Alias analysis fails to track subvalues.
// This is not a common scenario, for freezing, detect and error out.
IValue::HashAliasedIValues seen;
for (auto& val : usedAttrs_) {
IValue::HashAliasedIValues subValues;
val.getSubValues(subValues);
TORCH_CHECK(
std::all_of(
subValues.begin(),
subValues.end(),
[&seen](const IValue& v) { return seen.count(v) == 0; }),
"module contains attributes values that overlaps ",
val);
seen.insert(subValues.begin(), subValues.end());
}
}
// Returns `attr` with every tensor that requires grad replaced by a
// detached tensor with requires_grad == false. Containers (tuples, lists,
// dicts) and non-module objects are processed recursively; their elements
// are updated in place on the underlying container/object, which is then
// returned. Values of any other kind are returned unchanged.
IValue overrideGradient(IValue attr) {
  if (attr.isTensor()) {
    auto& t = attr.toTensor();
    if (t.requires_grad()) {
      auto detached = t.detach();
      detached.set_requires_grad(false);
      attr = IValue(std::move(detached));
    }
  } else if (attr.isTuple()) {
    auto tuple = std::move(attr).toTuple();
    const auto& elems = tuple->elements();
    for (const auto idx : c10::irange(elems.size())) {
      tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
    }
    attr = std::move(tuple);
  } else if (attr.isList()) {
    c10::List<IValue> elems = std::move(attr).toList();
    for (const auto i : c10::irange(elems.size())) {
      elems.set(i, overrideGradient(elems.extract(i)));
    }
    attr = elems;
  } else if (attr.isGenericDict()) {
    auto dict = std::move(attr).toGenericDict();
    for (const auto& pair : dict) {
      // Write the converted value back into the entry. Assigning it to a
      // local copy would silently drop the detached tensor.
      pair.setValue(overrideGradient(pair.value()));
    }
    attr = dict;
  } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
    auto obj_type = attr.type()->expect<ClassType>();
    auto obj_value = std::move(attr).toObject();
    auto sub_attributes = obj_type->getAttributes();
    for (const auto& sub_attr : sub_attributes) {
      const auto& sub_name = sub_attr.getName();
      // Store the converted value back on the object for the same reason
      // as in the dict branch.
      obj_value->setAttr(
          sub_name, overrideGradient(obj_value->getAttr(sub_name)));
    }
    return obj_value;
  }

  return attr;
}
// This method is invoked only when 'freezeInterfaces' parameter is on.
// `n` is a node whose single output is an interface value and `attr` is the
// concrete object stored in that attribute. Every prim::CallMethod that
// uses the output of `n` is looked up on the concrete class type and, when
// the method is a graph function, inlined at the call site.
// Returns true when at least one call was inlined.
bool inlineInterfaceCall(Node* n, const IValue& attr) {
  auto class_type = attr.type()->expect<ClassType>();
  bool inlined = false;
  // Iterate over a snapshot of the uses. inlineCallTo() inserts the callee
  // body (which may add uses of this value) and destroys the call node
  // (which removes one), so walking the live use list while inlining would
  // invalidate the iteration.
  const auto uses_snapshot = n->output()->uses();
  // Call nodes already replaced; a node can appear in several uses and must
  // not be touched again once it has been destroyed.
  std::unordered_set<Node*> replaced;
  for (const auto& use : uses_snapshot) {
    Node* user_node = use.user;
    if (replaced.count(user_node) != 0) {
      continue;
    }
    if (user_node->kind() != prim::CallMethod) {
      continue;
    }
    const std::string& methodName = user_node->s(attr::name);
    Function& function = class_type->getMethod(methodName);
    if (auto graphFunction = tryToGraphFunction(function)) {
      GRAPH_UPDATE(
          "Inlining interface method '",
          function.name(),
          "' to ",
          *user_node);

      GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph());
      inlineCallTo(user_node, graphFunction);
      replaced.insert(user_node);
      inlined = true;
    }
  }
  return inlined;
}
// [Note: Inlining interfaces strategy]
// There's two structures that are relevant to freezing:
// - the graph describing the computation in a method
// - the module describing the data structure of the module instance.
//
// First, in inlineInterfaceCalls, we inline interfaces. This is done in a
// separate step from normal inlining because CallMethod on an interface type
// requires extra steps compared to inlining a normal CallMethod.
//
// Next we need to simplify the structure of the module data structure, which
// is done for the most part by the usual steps in cleanupFrozenModule.
//
// However, there's a complication that comes from the fact that within a
// method, you can change the value of an interface to another module that
// implements that interface.
//
// For example:
//
// impl: MyInterface
// ...
// def forward(self, x):
// if x > 0:
// self.impl = my_interface_impl
//
// This is disallowed in freezing, because in this case we can't flatten out
// the module structure, since the type of self.impl will change.
//
// To handle this, we do the following:
// 1. inlineInterfaceCalls:
// a. inline the graph, and in the process record all interfaces
// b. simultaneously, check (throw error) for disallowed SetAttr calls.
// 2. call reassignInterfaceTypes, which reassigns interface types to their
// concrete types. This is done in a separate step to avoid interfering
// with inlineInterfaceCalls (note: this may not need to be done as a
// separate step)
// 3. eventually cleanupFrozenModule will reorder the module data structure
// and it will expect that all interface types have been removed.
// Walks every block of `graph` and, for each prim::GetAttr producing an
// interface-typed value, inlines the calls made on it, retypes the GetAttr
// output to the concrete attribute type and records (module path ->
// attribute name) in `interfacesToRetype`. For each prim::SetAttr on a
// module/interface value it asserts that the target attribute is not an
// interface. prim::fork subgraphs are visited through applyToForkSubgraph.
void inlineInterfaceCalls(
std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, std::unordered_set<std::string>>&
interfacesToRetype) {
auto block = graph->block();
std::stack<Block*> blocks({block});
while (!blocks.empty()) {
// Note: this local shadows the outer `block` declared above.
Block* block = blocks.top();
blocks.pop();
for (auto n : block->nodes()) {
for (Block* sub_block : n->blocks()) {
blocks.push(sub_block);
}
if (n->kind() == prim::GetAttr) {
// Only GetAttr nodes yielding an interface value are of interest here.
if (!n->output()->type()->cast<InterfaceType>()) {
continue;
}
auto name = n->s(attr::name);
auto attrModule = module_;
auto input = n->inputs()[0];
TORCH_CHECK(
findConstantAttr(input, name, attrModule, graph),
"failed to freeze interface attribute '" + name + "'");
TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
auto attr = attrModule.attr(name);
inlineInterfaceCall(n, attr);
// Reset the GetAttr to concrete module type.
n->output()->setType(attr.type());
// Record this so that we can reassign the type later
// in reassignInterfaceTypes()
// See [Note: Inlining interfaces strategy]
auto path = getModulePath(input, graph);
TORCH_INTERNAL_ASSERT(path.has_value());
auto path_str = concatName(path->begin(), path->end());
interfacesToRetype[path_str].insert(name);
} else if (n->kind() == prim::SetAttr) {
// Check to make sure we're not assigning the value of any parameters
// that are interface types.
// See [Note: Inlining interfaces strategy]
auto name = n->s(attr::name);
auto attrModule = module_;
auto input = n->inputs()[0];
if (!input->type()->cast<InterfaceType>() &&
!input->type()->expectRef<ClassType>().is_module()) {
// We only care about setattr["thing"](%mod) when %mod is a module or
// an interface value; anything else is skipped.
continue;
}
// note: this will modify attrModule until it is the parent of the
// "name" attr. In other words, attrModule is now the module that
// matches "input".
// We can't use findConstantAttr in case the base item is an object,
// instead of a module/interface.
auto path = getModulePath(input, graph);
TORCH_INTERNAL_ASSERT(path.has_value());
// The boolean result of getModuleFromPath is ignored here.
getModuleFromPath(attrModule, path->begin(), path->end());
const auto& attrType = attrModule.type()->getAttribute(name);
TORCH_INTERNAL_ASSERT(
!attrType->cast<InterfaceType>(),
"Freezing does not support SetAttr on an interface type. ",
"SetAttr is attempted on '",
name,
"'");
} else if (n->kind() == prim::fork) {
// NOTE(review): std::bind stores copies of its bound arguments, so the
// subgraph is processed by a copy of *this and of interfacesToRetype;
// entries recorded inside the fork subgraph do not reach the caller's
// map. Confirm whether that is intended before changing it: the
// recorded paths appear to be relative to whatever module_ is during
// the fork visit, not necessarily to the root module.
applyToForkSubgraph(
n,
graph,
// NOLINTNEXTLINE(modernize-avoid-bind)
std::bind(
&AttributePropagator::inlineInterfaceCalls,
*this,
std::placeholders::_1,
interfacesToRetype));
}
}
}
}
// See [Note: Inlining interfaces strategy]
// This modifies the internal structure of module types to reassign the
// type from an interface type to its concrete type.
void reassignInterfaceTypes(
const std::unordered_map<std::string, std::unordered_set<std::string>>&
interfacesToRetype) {
for (const auto& it : interfacesToRetype) {
const std::string& modulePath = it.first;
const std::vector<std::string>& splitPath = splitName(modulePath);
Module attrModule = module_;
getModuleFromPath(attrModule, splitPath.begin(), splitPath.end());
for (const std::string& name : it.second) {
auto subvalue = attrModule.attr(name);
auto subvalueType = subvalue.type();
attrModule.type()->unsafeChangeAttributeType(name, subvalueType);
}
}
}
void propagateAttributes(std::shared_ptr<Graph>& graph) {
std::unordered_map<ModulePtr, std::unordered_map<std::string, Value*>>
attrValues;
auto isEval = !module_.hasattr("training") || !module_.is_training();
GRAPH_DEBUG("Freezing Module: ", module_.type()->name()->name());
auto block = graph->block();
std::stack<Block*> blocks({block});
Node* m = *block->nodes().begin();
WithInsertPoint guard(m);
while (!blocks.empty()) {
Block* block = blocks.top();
blocks.pop();
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++; // advance iterator bc the current node may be destroyed
for (Block* sub_block : n->blocks()) {
blocks.push(sub_block);
}
if (n->kind() == prim::GetAttr) {
auto name = n->s(attr::name);
auto attrModule = module_;
auto input = n->inputs()[0];
if (!findConstantAttr(input, name, attrModule, graph)) {
GRAPH_DEBUG(
input->type()->cast<InterfaceType>() ||
input->type()->expectRef<ClassType>().is_module()
? "attribute: " + name + " is mutable."
: "");
continue;
}
TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
Value* paramConst = nullptr;
auto iter = attrValues.find(attrModule._ivalue());
if (iter != attrValues.end()) {
auto iter2 = iter->second.find(name);
if (iter2 != iter->second.end())
paramConst = iter2->second;
}
if (!paramConst) {
auto attr = attrModule.attr(name);
if (!isEval || preserveParameters_) {
auto type = attrModule.type();
auto slot = *type->findAttributeSlot(name);
if (type->is_parameter(slot) || type->is_buffer(slot) ||
(attr.isObject() &&
!attr.toObjectRef().type()->is_module())) {
continue;
} else {
attr = overrideGradient(attr);
}
if (!isEval && name == "training") {
continue;
}
} else {
attr = overrideGradient(attr);
}
if (attr.isObject()) {
if (object_memo_.count(attr.toObject())) {
attr = object_memo_[attr.toObject()];
} else {
auto weak_class_obj =
attr.toObject()->copy_to_weak_compilation_ref();
object_memo_[attr.toObject()] = weak_class_obj;
attr = weak_class_obj;
}
}
if (auto attrVal = tryInsertConstant(*graph, attr)) {
paramConst = *attrVal;
} else {
GRAPH_DEBUG(
attr.type()->cast<ClassType>() ? "" : "attribute: ",
name,
" is not materializable.");
continue;
}
std::string fullName("self.");
for (auto& name : names_) {
fullName += name + '.';
}
fullName += name;
paramConst->setDebugName(fullName);
attrValues[attrModule._ivalue()][name] = paramConst;
}
GRAPH_UPDATE(
"Folding GetAttr %",
n->outputs()[0]->debugName(),
" with ",
paramConst->debugName());
n->outputs().at(0)->replaceAllUsesWith(paramConst);
n->removeAllInputs();
} else if (n->kind() == prim::fork) {
applyToForkSubgraph(
n,
graph,
// NOLINTNEXTLINE(modernize-avoid-bind)
std::bind(
&AttributePropagator::propagateAttributes,
*this,
std::placeholders::_1));
}
}
}
}
// Runs `func` on the Subgraph attribute of the prim::fork node `n`.
// When the first input of the fork is produced by a prim::GetAttr yielding
// a class-typed value, that attribute (a module) temporarily replaces
// module_ while `func` runs, so that GetAttrs in the subgraph are resolved
// against it (similar to 'self' in forward). If that attribute cannot be
// treated as constant, the subgraph is skipped entirely. Otherwise `func`
// runs with module_ unchanged.
void applyToForkSubgraph(
    Node* n,
    std::shared_ptr<Graph>& graph,
    const std::function<void(std::shared_ptr<Graph>&)>& func) {
  TORCH_CHECK(n->kind() == prim::fork);
  auto attrModule = module_;
  Node* producer = n->inputs()[0]->node();
  const bool forksOnModuleAttr = producer->kind() == prim::GetAttr &&
      producer->output()->type()->cast<ClassType>();
  if (forksOnModuleAttr) {
    auto name = producer->s(attr::name);
    auto input = producer->inputs()[0];
    if (!findConstantAttr(input, name, attrModule, graph)) {
      // Module needs to be preserved.
      return;
    }
    attrModule = attrModule.attr(name).toModule();
    // After the swap, attrModule holds the previous module_ so it can be
    // restored below.
    std::swap(module_, attrModule);
  }

  auto subgraph = n->g(attr::Subgraph);
  func(subgraph);
  module_ = attrModule;
}
// Returns true when `subModule` must be kept whole: either its type is a
// subtype of the type of one of the graph outputs, or it was recorded in
// preservedSubModule_.
bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
  const auto outputs = graph->outputs();
  const bool matchesOutput = std::any_of(
      outputs.begin(), outputs.end(), [&subModule](Value* output) {
        return subModule.type()->isSubtypeOf(*output->type());
      });
  if (matchesOutput) {
    return true;
  }
  return preservedSubModule_.count(subModule._ivalue()) != 0;
}
// Removes aten::wait nodes of block `b` whose input is not of Future type,
// forwarding the input to all users of the wait output, and then does the
// same recursively for every nested block.
void removeExtraWaitCalls(Block* b) {
  auto nodes = b->nodes();
  for (auto it = nodes.begin(); it != nodes.end(); it++) {
    Node* node = *it;
    if (node->kind() != aten::wait) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(node->inputs().size() == 1);
    TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
    // A wait whose input is a Future is kept.
    if (node->input()->type()->kind() == TypeKind::FutureType) {
      continue;
    }
    // If input type is not a from aten::fork call then the
    // aten::wait operator can be deleted.
    node->output()->replaceAllUsesWith(node->input());
    it.destroyCurrent();
  }
  // For the remaining nodes, recurse.
  for (Node* node : nodes) {
    for (Block* nested : node->blocks()) {
      removeExtraWaitCalls(nested);
    }
  }
}
// cleanupFrozenModule function cleans up the Frozen module. It performs the
// following:
// 1) Remove unused attributes.
// 2) Remove unreferenced submodules
// 3) Remove non public unreferenced methods.
void cleanupFrozenModule() {
for (auto function : preservedMethods_) {
auto graph = toGraphFunction(*function).graph();
recordReferencedAttrs(graph);
handleSharedClassType(module_, graph);
removeExtraWaitCalls(graph->block());
toGraphFunction(*function).clear_optimized_graphs();
}
removeUnusedAttrs();
}
// Preparing for clean up phase. At this point, record all subModules that
// contains mutable attributes.
void recordReferencedAttrs(std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks({graph->block()});
std::set<ModulePtr> modules({module_._ivalue()});
while (!blocks.empty()) {
Block* block = blocks.top();
blocks.pop();
for (auto n : block->nodes()) {
for (Block* subBlock : n->blocks()) {
blocks.push(subBlock);
}
if (n->kind() == prim::GetAttr) {
auto& name = n->s(attr::name);
// For now, use all module ivalues which are the same type
// and could be the module that this GetAttr resolves to
// TODO: we could attempt to follow the GetAttr chain and
// find the exact ivalue, we would have to be careful
// that the chain does not contain any attributes which
// get written to (setAttr calls)
for (auto& mptr : modules) {
auto module = Module(mptr);
if (module.type() == n->inputs()[0]->type()) {
TORCH_INTERNAL_ASSERT(module.hasattr(name));
auto module = Module(mptr);
auto attr = module.attr(name);
// TODO: this could be insertReferencedAttr to be more clear,
// these are attributes we could not inline, which include
// other reasons besides mutation (unsupported constant,
// getAttr resolving to non-getAttr node, etc)
insertMutableAttr(name, attr, mptr);
if (attr.isModule()) {
modules.insert(attr.toModule()._ivalue());
}
}
}
} else if (n->kind() == prim::fork) {
applyToForkSubgraph(
n,
graph,
// NOLINTNEXTLINE(modernize-avoid-bind)
std::bind(
&AttributePropagator::recordReferencedAttrs,
*this,
std::placeholders::_1));
}
}
}
// We have to process the attributes that the user wants to preserve
// separately since it's possible that the user-preserved module is
// never referenced in the graph.
for (const auto& attr_info : userPreservedAttrs_) {
const auto& parent_module = attr_info.first;
for (const auto& attr_name : attr_info.second) {
const auto value = parent_module->getAttr(attr_name);
insertMutableAttr(attr_name, value, parent_module);
}
}
}
// Recursively visits `module` and its preserved submodules to determine,
// per class type, which attribute slots must be kept.
//
// Note 'attrsToKeep_[type].insert(type->numAttributes())' means all
// attribute slots of 'type' and its methods are preserved. That is done for
// a module that escapes (see moduleEscapes()); its submodules are then not
// visited.
void handleSharedClassType(Module& module, std::shared_ptr<Graph>& graph) {
  auto type = module.type();
  const size_t numAttrs = type->numAttributes();
  if (moduleEscapes(module, graph)) {
    // Preserve all its attributes and methods.
    attrsToKeep_[type].insert(numAttrs);
    return;
  }
  auto scalarsIt = preservedScalarAttrs_.find(module._ivalue());
  SharedTypeSubModules_[type].insert(module._ivalue());
  // operator[] creates the entry for `type` if it is missing; the empty
  // initializer list itself adds no slot.
  attrsToKeep_[type].insert({});
  for (const auto slot : c10::irange(numAttrs)) {
    auto name = type->getAttributeName(slot);
    auto attr = module.attr(name);
    auto attrTy = attr.type();

    // Mutable-typed values are tracked by IValue, everything else by name.
    const bool keep = AliasDb::isMutableType(attrTy)
        ? preservedAttrs_.count(attr) != 0
        : (scalarsIt != preservedScalarAttrs_.end() &&
           scalarsIt->second.count(name) != 0);
    if (!keep) {
      continue;
    }
    attrsToKeep_[type].insert(slot);
    if (attr.isModule()) {
      // See [Note: Inlining interfaces strategy]
      TORCH_CHECK(
          !type->getAttribute(slot)->cast<InterfaceType>(),
          "Unexpected interface attribute '" + name + "' during freezing");

      auto attrModule = attr.toModule();
      handleSharedClassType(attrModule, graph);
    }
  }
}
// For every class type recorded in attrsToKeep_ that is not kept whole,
// removes the attribute slots that were not marked as kept (from the
// recorded module instances of that type and from the type itself) and the
// methods that are not in preservedMethods_.
void removeUnusedAttrs() {
  for (auto& entry : attrsToKeep_) {
    auto& type = entry.first;
    auto& keptSlots = entry.second;
    const size_t numAttrs = type->numAttributes();
    // A slot index equal to the attribute count marks "keep everything".
    if (keptSlots.count(numAttrs) != 0) {
      continue;
    }

    // Collect names first: removing attributes changes the slot layout.
    std::vector<std::string> attrsToRemove;
    for (const auto slot : c10::irange(numAttrs)) {
      if (keptSlots.count(slot) == 0) {
        attrsToRemove.push_back(type->getAttributeName(slot));
      }
    }

    std::vector<Function*> funcsToRemove;
    for (auto& fn : type->methods()) {
      if (preservedMethods_.count(fn) == 0) {
        funcsToRemove.push_back(fn);
      }
    }

    for (auto& name : attrsToRemove) {
      for (auto& val : SharedTypeSubModules_[type]) {
        auto mod = val.toModule();
        mod._ivalue()->unsafeRemoveAttr(name);
      }
      type->unsafeRemoveAttribute(name);
    }

    for (auto fn : funcsToRemove) {
      type->unsafeRemoveMethod(fn->name());
      auto mod = SharedTypeSubModules_[type].begin()->toModule();
      mod._ivalue()->compilation_unit()->unsafeRemoveMethod(fn->qualname());
    }
  }
}
// Attribute values (of mutable types) that must not be folded, either
// because they cannot be or because the user asked to keep them.
IValue::HashAliasedIValues preservedAttrs_;
// Attributes of immutable types (scalars) that must not be folded. They are
// tracked by attribute name under their owning module rather than by
// IValue.
std::unordered_map<ModulePtr, std::unordered_set<std::string>>
preservedScalarAttrs_;
// Methods to be preserved in the frozen module (user-specified ones plus
// `forward` when it exists).
std::unordered_set<Function*> preservedMethods_;
// User-specified submodules to be preserved in the frozen module.
std::unordered_set<ModulePtr> preservedSubModule_;
// Track all used attributes ivalues that can be aliased.
IValue::HashAliasedIValues usedAttrs_;
// Contains the attribute slots that need to be preserved for each ClassType.
std::unordered_map<ClassTypePtr, std::unordered_set<size_t>> attrsToKeep_;
// Contains the sub modules that share the same ClassType.
std::unordered_map<ClassTypePtr, IValue::HashAliasedIValues>
SharedTypeSubModules_;
// The module being frozen. Held by reference; applyToForkSubgraph
// temporarily swaps its contents while a fork subgraph is processed.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
Module& module_;
// Allow to freeze modules containing interfaces.
bool freezeInterfaces_;
// Preserve module parameters
bool preserveParameters_;
// Attribute names on the path from the graph's first input to the value
// most recently passed to _loadModulePath (e.g. {"subModule", "a"}).
// Overwritten on every call.
std::deque<std::string> names_;
// see [Constant Object Weak CompilationUnit Reference]
// Maps an object to the copy made with copy_to_weak_compilation_ref().
std::unordered_map<
c10::intrusive_ptr<at::ivalue::Object>,
c10::intrusive_ptr<at::ivalue::Object>>
object_memo_;
// Contains names of attributes that the user wants to preserve with
// their owning modules.
std::unordered_map<ModulePtr, std::unordered_set<std::string>>
userPreservedAttrs_;
}; // class AttributePropagator
void checkModuleDoesNotReturnSelf(const Module& module) {
if (module.find_method("forward")) {
Method method = module.get_method("forward");
// Check that module does not return itself.
for (auto& output : method.graph()->outputs()) {
TORCH_CHECK(
output->type() != module.type(),
"attempted to freeze a module that return itself");
}
}
}
} // namespace
// Freezes a clone of `module` and returns it; `module` itself is not
// passed to the propagator. `preservedAttrs` names attributes/methods to
// keep; the two flags are forwarded to AttributePropagator. Raises when
// `forward` returns the module itself or when a preserved name is unknown.
Module freeze_module(
    const Module& module,
    std::vector<std::string> preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters) {
  checkModuleDoesNotReturnSelf(module);

  Module frozen = module.clone(true);
  AttributePropagator propagator(
      frozen, preservedAttrs, freezeInterfaces, preserveParameters);
  propagator.run();
  return frozen;
}
// Freezes `*module` in place (no clone is made). `module` must not be
// null. `preservedAttrs` names attributes/methods to keep; the two flags
// are forwarded to AttributePropagator.
void freeze_module_inplace(
    Module* module,
    std::vector<std::string> preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters) {
  TORCH_CHECK(module != nullptr, "module cannot be nullptr");
  Module& target = *module;
  checkModuleDoesNotReturnSelf(target);
  AttributePropagator propagator(
      target, preservedAttrs, freezeInterfaces, preserveParameters);
  propagator.run();
}
} // namespace torch::jit