Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[JIT] Add optimize_for_inference API (#58193)
Summary: Freezing exists as a pass that partially evaluates your model and applies generic optimizations that should speed it up. Optimize for inference is a counterpart to these optimizations that additionally runs build- and server-specific optimizations.

The interaction with the existing `optimize_frozen_module` is not great; we could probably deprecate that API entirely, since it was never officially released and only existed to document the `optimize_numerics` keyword. Eventually I would like to add a way of providing example inputs, but I didn't add that here because they are not being used at all yet. I also have not yet included a way to blacklist individual optimizations, and would like to wait until we move this to Beta and have a little more clarity on how everything will fit together. I also think blacklisting will be an uncommon use case for the current optimizations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58193
Reviewed By: bertmaher, navahgar
Differential Revision: D28443714
Pulled By: eellison
fbshipit-source-id: b032355bb2585720a6d2f00c89d0d9a7ef60e649
This commit is contained in: parent fad2ce439e, commit 211bac53ef
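As a rough usage sketch of the API this commit adds (mirroring the Python test added below; the Linear shape comes from that test, while the input tensor and the allclose comparison are illustrative additions, not part of the diff):

import torch
import torch.nn as nn

# optimize_for_inference expects a ScriptModule, so script an eval-mode module first.
mod = nn.Linear(20, 30).eval()
scripted = torch.jit.script(mod)

# freeze() applies the generic, machine-independent optimizations.
frozen = torch.jit.freeze(scripted)

# optimize_for_inference() additionally bakes in build/server-specific choices
# (e.g. MKLDNN or CUDNN rewrites) and freezes the module itself if needed.
optimized = torch.jit.optimize_for_inference(scripted)

# Both variants should still agree numerically on ordinary dense inputs.
x = torch.rand(2, 20)
print(torch.allclose(frozen(x), optimized(x), atol=1e-4))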
@@ -365,7 +365,10 @@ TEST(ModuleAPITest, Freezing) {
   auto frozen_mod = torch::jit::freeze(m);
   auto forward_g = frozen_mod.get_method("forward").graph();
   testing::FileCheck().check_not("GetAttr")->run(*forward_g);
+  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
+  forward_g = frozen_mod.get_method("forward").graph();
+  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -1590,14 +1590,14 @@ class TestFrozenOptimizations(JitTestCase):
         conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
         bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
         mod = torch.nn.Sequential(conv, bn)
-        # set optimize to False here, by default freezing runs optimize_frozen_module
+        # set optimize to False here, by default freezing runs run_frozen_optimizations
         frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
         # inspect frozen mod
         FileCheck().check("batch_norm").run(frozen_mod.graph)
-        torch.jit.optimize_frozen_module(frozen_mod)
+        torch.jit.run_frozen_optimizations(frozen_mod)
         FileCheck().check_not("batch_norm").run(frozen_mod.graph)
 
-        # optimize_frozen_module should be run
+        # run_frozen_optimizations should be run
         frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
         FileCheck().check_not("batch_norm").run(frozen_mod.graph)
@@ -1849,7 +1849,7 @@ class TestFrozenOptimizations(JitTestCase):
         else:
             scripted_mod = torch.jit.script(mod_eager)
 
-        frozen_mod = torch.jit.freeze(scripted_mod)
+        frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
         if add_z:
             FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
         else:
@@ -1993,6 +1993,19 @@ class TestFrozenOptimizations(JitTestCase):
             # and we aren't testing aten impls anyways
             self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense()))
 
+    @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+    def test_optimize_for_inference(self):
+        with set_default_dtype(torch.float):
+            mod = nn.Linear(20, 30).eval()
+            scripted_mod = torch.jit.script(mod)
+
+            optimized = torch.jit.optimize_for_inference(scripted_mod)
+            FileCheck().check("to_mkldnn").run(optimized.graph)
+
+            frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
+            optimized = torch.jit.optimize_for_inference(scripted_mod)
+            FileCheck().check("to_mkldnn").run(optimized.graph)
+
 @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
 class TestMKLDNNReinplacing(JitTestCase):
     def setUp(self):
@@ -182,6 +182,7 @@ def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
 def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
 def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
 def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
+def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
 def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
 
 def _is_tracing() -> _bool: ...
@@ -10,7 +10,9 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/runtime/operator.h>
@@ -484,6 +486,18 @@ Module freeze(
   return out_mod;
 }
 
+Module optimize_for_inference(Module& module) {
+  // not frozen yet
+  if (module._ivalue()->type()->hasAttribute("training")) {
+    auto mod = freeze(module, {}, true);
+  }
+
+  auto graph = module.get_method("forward").graph();
+  FuseFrozenConvAddRelu(graph);
+  ConvertFrozenOpsToMKLDNN(graph);
+  return module;
+}
+
 buffer_list Module::buffers(bool recurse) const {
   return buffer_list(*this, recurse, /*return_module=*/false);
 }
@@ -295,6 +295,10 @@ TORCH_API Module freeze(
     c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
     bool optimize_numerics = true);
 
+// C++ equivalent api of `torch.jit.optimize_for_inference`. See documentation
+// there for details.
+TORCH_API Module optimize_for_inference(Module& module);
+
 namespace detail {
 
 struct TORCH_API SlotCursor {
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/ir_views.h>
-#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
@@ -22,7 +21,6 @@ void OptimizeFrozenGraph(
       FoldFrozenConvMulOrDiv(graph);
     }
   }
-  FuseFrozenConvAddRelu(graph);
 }
 
 } // namespace jit
@@ -49,7 +49,7 @@ from torch.jit._async import fork, wait
 from torch.jit._serialization import save, load
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
 
-from torch.jit._freeze import freeze, optimize_frozen_module
+from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
 
 # For backwards compatibility
 _fork = fork
@@ -20,6 +20,10 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
 
     Freezing currently only accepts ScriptModules that are in eval mode.
 
+    Freezing applies generic optimization that will speed up your model regardless of machine.
+    To further optimize using server-specific settings, run `optimize_for_inference` after
+    freezing.
+
     Args:
         mod (:class:`ScriptModule`): a module to be frozen
@@ -27,7 +31,7 @@
             Attributes modified in preserved methods will also be preserved.
 
         optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
-            preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.
+            preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.
 
     Returns:
         Frozen :class:`ScriptModule`.
@@ -83,6 +87,12 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
         If you're not sure why an attribute is not being inlined as a constant, you can run
         `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
         attribute is being modified.
 
+    Note:
+        Because freezing makes weights constants and removes module hierarchy, `to` and other
+        nn.Module methods to manipulate device or dtype no longer work. As a workaround,
+        You can remap devices by specifying `map_location` in `torch.jit.load`, however
+        device-specific logic may have been baked into the model.
+
     """
     if not isinstance(mod, ScriptModule):
         raise RuntimeError(
@@ -100,12 +110,11 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics:
 
     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
     RecursiveScriptModule._finalize_scriptmodule(out)
-    optimize_frozen_module(out, optimize_numerics)
+    run_frozen_optimizations(out, optimize_numerics)
 
     return out
 
-def optimize_frozen_module(mod, optimize_numerics: bool = True):
+def run_frozen_optimizations(mod, optimize_numerics: bool = True):
     r"""
     Runs a series of optimizations looking for patterns that occur in frozen graphs.
     The current set of optimizations is:
@@ -136,11 +145,11 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
         conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
         bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
         mod = torch.nn.Sequential(conv, bn)
-        # set optimize to False here, by default freezing runs optimize_frozen_module
+        # set optimize to False here, by default freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
         # inspect frozen mod
         assert "batch_norm" in str(frozen_mod.graph)
-        torch.jit.optimize_frozen_module(frozen_mod)
+        torch.jit.run_frozen_optimizations(frozen_mod)
         assert "batch_norm" not in str(frozen_mod.graph)
 
     """
@@ -153,4 +162,32 @@ def optimize_frozen_module(mod, optimize_numerics: bool = True):
         torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
         torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
         torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
+
+def optimize_for_inference(mod: ScriptModule) -> ScriptModule:
+    """
+    Performs a set of optimization passes to optimize a model for the
+    purposes of inference. If the model is not already frozen, optimize_for_inference
+    will invoke `torch.jit.freeze` automatically.
+
+    In addition to generic optimizations that should speed up your model regardless
+    of environment, prepare for inference will also bake in build specific settings
+    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
+    which speed things up on one machine but slow things down on another. Accordingly,
+    serialization is not implemented following invoking `optimize_for_inference` and
+    is not guaranteed.
+
+    This is still in prototype, and may have the potential to slow down your model.
+    Primary use cases that have been targeted so far have been vision models on cpu
+    and gpu to a lesser extent.
+    """
+    if not isinstance(mod, ScriptModule):
+        raise RuntimeError(
+            "optimize_for_inference expects a ScriptModule as input. "
+            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
+
+    if hasattr(mod, "training"):
+        mod = freeze(mod.eval())
+
+    torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
     torch._C._jit_pass_fuse_frozen_conv_add_relu(mod.graph)
+    return mod