[JIT] Add optimize_for_inference API (#58193)

Summary:
Freezing is an existing pass that partially evaluates your model and applies generic optimizations that should speed it up regardless of machine. `optimize_for_inference` is a counterpart that additionally runs build- and server-specific optimizations. The interaction with the existing `optimize_frozen_module` API is not great; I guess we could just deprecate that API entirely? It was never officially released and only existed to document the `optimize_numerics` keyword.
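
To make the intended split concrete, here is a minimal usage sketch (it mirrors the new test and docstrings below; the choice of module is arbitrary):

import torch
import torch.nn as nn

mod = torch.jit.script(nn.Linear(20, 30).eval())

# generic, machine-independent optimizations (attributes folded into constants, etc.)
frozen = torch.jit.freeze(mod)

# freezing plus build/server-specific passes (MKLDNN conversion, conv-add-relu fusion);
# freezes automatically if the module has not been frozen yet
optimized = torch.jit.optimize_for_inference(mod)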

Eventually I would like to add a way to pass example inputs, but I didn't add that here because they are not being used at all yet. I also have not yet included a way to blacklist individual optimizations; I would like to wait until we move this to Beta and have a little more clarity on how everything will fit together. I also think blacklisting will be an uncommon use case for the current optimizations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58193

Reviewed By: bertmaher, navahgar

Differential Revision: D28443714

Pulled By: eellison

fbshipit-source-id: b032355bb2585720a6d2f00c89d0d9a7ef60e649
Elias Ellison 2021-05-15 15:48:42 -07:00 committed by Facebook GitHub Bot
parent fad2ce439e
commit 211bac53ef
8 changed files with 84 additions and 14 deletions


@ -365,7 +365,10 @@ TEST(ModuleAPITest, Freezing) {
auto frozen_mod = torch::jit::freeze(m);
auto forward_g = frozen_mod.get_method("forward").graph();
testing::FileCheck().check_not("GetAttr")->run(*forward_g);
auto frozen_mod2 = torch::jit::optimize_for_inference(m);
forward_g = frozen_mod2.get_method("forward").graph();
testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)


@ -1590,14 +1590,14 @@ class TestFrozenOptimizations(JitTestCase):
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
- # set optimize to False here, by default freezing runs optimize_frozen_module
+ # set optimize to False here, by default freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
# inspect frozen mod
FileCheck().check("batch_norm").run(frozen_mod.graph)
- torch.jit.optimize_frozen_module(frozen_mod)
+ torch.jit.run_frozen_optimizations(frozen_mod)
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
- # optimize_frozen_module should be run
+ # run_frozen_optimizations should be run
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
@ -1849,7 +1849,7 @@ class TestFrozenOptimizations(JitTestCase):
else:
scripted_mod = torch.jit.script(mod_eager)
- frozen_mod = torch.jit.freeze(scripted_mod)
+ frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
if add_z:
FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
else:
@ -1993,6 +1993,19 @@ class TestFrozenOptimizations(JitTestCase):
# and we aren't testing aten impls anyways
self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense()))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_optimize_for_inference(self):
with set_default_dtype(torch.float):
mod = nn.Linear(20, 30).eval()
scripted_mod = torch.jit.script(mod)
optimized = torch.jit.optimize_for_inference(scripted_mod)
FileCheck().check("to_mkldnn").run(optimized.graph)
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
optimized = torch.jit.optimize_for_inference(frozen_mod)
FileCheck().check("to_mkldnn").run(optimized.graph)
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
def setUp(self):


@ -182,6 +182,7 @@ def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
def _is_tracing() -> _bool: ...


@ -10,7 +10,9 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
@ -484,6 +486,18 @@ Module freeze(
return out_mod;
}
Module optimize_for_inference(Module& module) {
// freeze first if the module has not been frozen yet
Module frozen_mod = module;
if (module._ivalue()->type()->hasAttribute("training")) {
frozen_mod = freeze(module, {}, true);
}
// run the build-specific passes on the frozen module's graph
auto graph = frozen_mod.get_method("forward").graph();
FuseFrozenConvAddRelu(graph);
ConvertFrozenOpsToMKLDNN(graph);
return frozen_mod;
}
buffer_list Module::buffers(bool recurse) const {
return buffer_list(*this, recurse, /*return_module=*/false);
}


@ -295,6 +295,10 @@ TORCH_API Module freeze(
c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
bool optimize_numerics = true);
// C++ equivalent of the `torch.jit.optimize_for_inference` Python API. See the
// Python documentation for details.
TORCH_API Module optimize_for_inference(Module& module);
namespace detail {
struct TORCH_API SlotCursor {


@ -1,6 +1,5 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
- #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
@ -22,7 +21,6 @@ void OptimizeFrozenGraph(
FoldFrozenConvMulOrDiv(graph);
}
}
- FuseFrozenConvAddRelu(graph);
}
} // namespace jit


@ -49,7 +49,7 @@ from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
- from torch.jit._freeze import freeze, optimize_frozen_module
+ from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
# For backwards compatibility
_fork = fork


@ -20,6 +20,10 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
Freezing currently only accepts ScriptModules that are in eval mode.
Freezing applies generic optimizations that will speed up your model regardless of machine.
To further optimize using server-specific settings, run `optimize_for_inference` after
freezing.
Args:
mod (:class:`ScriptModule`): a module to be frozen
@ -27,7 +31,7 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
Attributes modified in preserved methods will also be preserved.
optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
- preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.
+ preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.
Returns:
Frozen :class:`ScriptModule`.
@ -83,6 +87,12 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
If you're not sure why an attribute is not being inlined as a constant, you can run
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
attribute is being modified.
Note:
Because freezing makes weights constants and removes module hierarchy, `to` and other
nn.Module methods for manipulating device or dtype no longer work. As a workaround,
you can remap devices by specifying `map_location` in `torch.jit.load`; however,
device-specific logic may have been baked into the model.
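For example (a minimal sketch; the file name is arbitrary and `frozen_mod` is a frozen module):
torch.jit.save(frozen_mod, "frozen.pt")
# remap storages that were saved on GPU onto the CPU at load time
loaded = torch.jit.load("frozen.pt", map_location="cpu")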
"""
if not isinstance(mod, ScriptModule):
raise RuntimeError(
@ -100,12 +110,11 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
- optimize_frozen_module(out, optimize_numerics)
+ run_frozen_optimizations(out, optimize_numerics)
return out
- def optimize_frozen_module(mod, optimize_numerics: bool = True):
+ def run_frozen_optimizations(mod, optimize_numerics: bool = True):
r"""
Runs a series of optimizations looking for patterns that occur in frozen graphs.
The current set of optimizations is:
@ -136,11 +145,11 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
- # set optimize to False here, by default freezing runs optimize_frozen_module
+ # set optimize to False here, by default freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
assert "batch_norm" in str(frozen_mod.graph)
- torch.jit.optimize_frozen_module(frozen_mod)
+ torch.jit.run_frozen_optimizations(frozen_mod)
assert "batch_norm" not in str(frozen_mod.graph)
"""
@ -153,4 +162,32 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
def optimize_for_inference(mod: ScriptModule) -> ScriptModule:
"""
Performs a set of optimization passes to optimize a model for the
purposes of inference. If the model is not already frozen, optimize_for_inference
will invoke `torch.jit.freeze` automatically.
In addition to generic optimizations that should speed up your model regardless
of environment, optimize_for_inference also bakes in build-specific settings
such as the presence of CUDNN or MKLDNN, and may in the future make transformations
that speed things up on one machine but slow things down on another. Accordingly,
serialization after invoking `optimize_for_inference` is not implemented and
is not guaranteed to work.
This API is still in prototype and may slow down your model in some cases.
The primary use cases targeted so far are vision models on CPU and, to a lesser
extent, GPU.
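Example (a minimal sketch; module and sizes are arbitrary):
mod = torch.nn.Linear(20, 30).eval()
opt_mod = torch.jit.optimize_for_inference(torch.jit.script(mod))
# build-specific ops (e.g. to_mkldnn on MKL-DNN builds) may now appear in the graph
print(opt_mod.graph)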
"""
if not isinstance(mod, ScriptModule):
raise RuntimeError(
"optimize_for_inference expects a ScriptModule as input. "
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
if hasattr(mod, "training"):
mod = freeze(mod.eval())
torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
torch._C._jit_pass_fuse_frozen_conv_add_relu(mod.graph)
return mod