[JIT] Add optimize_for_inference API (#58193)
Summary: Freezing exists as a pass that partially evaluates your model and applies generic optimizations which should speed it up. Optimize for inference is a counterpart to these optimizations that runs build- and server-specific optimizations. The interaction with the existing `optimize_frozen_module` is not great; I guess we could just deprecate that API entirely? It was never officially released and only existed to document the `optimize_numerics` keyword. Eventually, I would like to add a way of supplying example inputs, but I didn't add that here because they are not being used at all yet. I also have not yet included a way to blacklist individual optimizations, and would like to wait until we move this to Beta and have a little more clarity on how everything will fit together. I also think blacklisting will be an uncommon use case for the current optimizations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58193

Reviewed By: bertmaher, navahgar

Differential Revision: D28443714

Pulled By: eellison

fbshipit-source-id: b032355bb2585720a6d2f00c89d0d9a7ef60e649
This commit is contained in:
parent
fad2ce439e
commit
211bac53ef
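For orientation, a minimal usage sketch of the API surface this commit introduces, pieced together from the docstrings and tests in the diff below; the toy Conv2d/BatchNorm2d module and its shapes are illustrative only, not part of the change:

import torch
import torch.nn as nn

# Any scripted nn.Module in eval mode works; this conv + bn pair is just an example.
mod = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.BatchNorm2d(16)).eval()
scripted = torch.jit.script(mod)

# Generic, machine-independent optimizations: freeze, then (optionally) re-run
# the renamed frozen-graph passes by hand.
frozen = torch.jit.freeze(scripted, optimize_numerics=False)
torch.jit.run_frozen_optimizations(frozen)

# Build/server-specific optimizations (MKLDNN / cuDNN rewrites). This freezes the
# module automatically if it is not frozen yet; the result is not guaranteed to serialize.
optimized = torch.jit.optimize_for_inference(scripted)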

@@ -365,7 +365,10 @@ TEST(ModuleAPITest, Freezing) {
  auto frozen_mod = torch::jit::freeze(m);
  auto forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
  ;

  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

@@ -1590,14 +1590,14 @@ class TestFrozenOptimizations(JitTestCase):
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize to False here, by default freezing runs optimize_frozen_module
        # set optimize to False here, by default freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect frozen mod
        FileCheck().check("batch_norm").run(frozen_mod.graph)
        torch.jit.optimize_frozen_module(frozen_mod)
        torch.jit.run_frozen_optimizations(frozen_mod)
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

        # optimize_frozen_module should be run
        # run_frozen_optimizations should be run
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

@@ -1849,7 +1849,7 @@ class TestFrozenOptimizations(JitTestCase):
        else:
            scripted_mod = torch.jit.script(mod_eager)

        frozen_mod = torch.jit.freeze(scripted_mod)
        frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
        if add_z:
            FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
        else:

@@ -1993,6 +1993,19 @@ class TestFrozenOptimizations(JitTestCase):
        # and we aren't testing aten impls anyways
        self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense()))

    @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
    def test_optimize_for_inference(self):
        with set_default_dtype(torch.float):
            mod = nn.Linear(20, 30).eval()
            scripted_mod = torch.jit.script(mod)

            optimized = torch.jit.optimize_for_inference(scripted_mod)
            FileCheck().check("to_mkldnn").run(optimized.graph)

            frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
            optimized = torch.jit.optimize_for_inference(scripted_mod)
            FileCheck().check("to_mkldnn").run(optimized.graph)

@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
    def setUp(self):

@@ -182,6 +182,7 @@ def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...

def _is_tracing() -> _bool: ...

@@ -10,7 +10,9 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>

@@ -484,6 +486,18 @@ Module freeze(
  return out_mod;
}

Module optimize_for_inference(Module& module) {
  // not frozen yet
  if (module._ivalue()->type()->hasAttribute("training")) {
    auto mod = freeze(module, {}, true);
  }

  auto graph = module.get_method("forward").graph();
  FuseFrozenConvAddRelu(graph);
  ConvertFrozenOpsToMKLDNN(graph);
  return module;
}

buffer_list Module::buffers(bool recurse) const {
  return buffer_list(*this, recurse, /*return_module=*/false);
}

@@ -295,6 +295,10 @@ TORCH_API Module freeze(
    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
    bool optimize_numerics = true);

// C++ equivalent api of `torch.jit.optimize_for_inference`. See documentation
// there for details.
TORCH_API Module optimize_for_inference(Module& module);

namespace detail {

struct TORCH_API SlotCursor {

@@ -1,6 +1,5 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>

@@ -22,7 +21,6 @@ void OptimizeFrozenGraph(
      FoldFrozenConvMulOrDiv(graph);
    }
  }
  FuseFrozenConvAddRelu(graph);
}

} // namespace jit

@@ -49,7 +49,7 @@ from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph

from torch.jit._freeze import freeze, optimize_frozen_module
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations

# For backwards compatibility
_fork = fork

@@ -20,6 +20,10 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:

    Freezing currently only accepts ScriptModules that are in eval mode.

    Freezing applies generic optimization that will speed up your model regardless of machine.
    To further optimize using server-specific settings, run `optimize_for_inference` after
    freezing.

    Args:
        mod (:class:`ScriptModule`): a module to be frozen

@@ -27,7 +31,7 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
            Attributes modified in preserved methods will also be preserved.

        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
            preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.
            preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.

@@ -83,6 +87,12 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.

    Note:
        Because freezing makes weights constants and removes module hierarchy, `to` and other
        nn.Module methods to manipulate device or dtype no longer work. As a workaround,
        You can remap devices by specifying `map_location` in `torch.jit.load`, however
        device-specific logic may have been baked into the model.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(

@@ -100,12 +110,11 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:

    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)
    optimize_frozen_module(out, optimize_numerics)
    run_frozen_optimizations(out, optimize_numerics)

    return out


def optimize_frozen_module(mod, optimize_numerics: bool = True):
def run_frozen_optimizations(mod, optimize_numerics: bool = True):
    r"""
    Runs a series of optimizations looking for patterns that occur in frozen graphs.
    The current set of optimizations is:

@@ -136,11 +145,11 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize to False here, by default freezing runs optimize_frozen_module
        # set optimize to False here, by default freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.optimize_frozen_module(frozen_mod)
        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)

    """

@@ -153,4 +162,32 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
        torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
        torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
        torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)

def optimize_for_inference(mod: ScriptModule) -> ScriptModule:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. If the model is not already frozen, optimize_for_inference
    will invoke `torch.jit.freeze` automatically.

    In addition to generic optimizations that should speed up your model regardless
    of environment, prepare for inference will also bake in build specific settings
    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
    which speed things up on one machine but slow things down on another. Accordingly,
    serialization is not implemented following invoking `optimize_for_inference` and
    is not guaranteed.

    This is still in prototype, and may have the potential to slow down your model.
    Primary use cases that have been targeted so far have been vision models on cpu
    and gpu to a lesser extent.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")

    if hasattr(mod, "training"):
        mod = freeze(mod.eval())

    torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
    torch._C._jit_pass_fuse_frozen_conv_add_relu(mod.graph)
    return mod