move misc implementation out of jit/__init__.py (#41154)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41154

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D22445213

Pulled By: suo

fbshipit-source-id: 200545715c5ef13beb1437f49e01efb21498ddb7
This commit is contained in:
Michael Suo 2020-07-13 16:57:41 -07:00 committed by Facebook GitHub Bot
parent 6392713584
commit ca1b8ebbcb
21 changed files with 403 additions and 357 deletions

View File

@ -487,7 +487,7 @@ class TestModels(JitTestCase):
return self.seq.forward(input)
# disabled due to a jitter issues that will be fixed by using load/store in the compiler
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
# TODO: toggle export_import once above issues are fixed
self.checkTrace(Traced(), (torch.rand(3, 4),),
export_import=False)

View File

@ -1817,7 +1817,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
def weight(self, w):
self._packed_weight = torch.ops.quantized.linear_prepack(w)
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
x = torch.jit.script(Linear(10, 10))
torch._C._jit_pass_erase_shape_information(x.graph)

View File

@ -2159,7 +2159,7 @@ graph(%Ra, %Rb):
self.assertExpected(cu.foo.code)
def test_import_method(self):
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
@ -3596,7 +3596,7 @@ def foo(x):
mod.ninf = float("-inf")
mod.nan = float("nan")
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
@ -9122,7 +9122,7 @@ a")
x[seq_lens[b]:, b, :] = 0
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
self.assertEqual(eager_seq, script_seq)
@ -9145,7 +9145,7 @@ a")
lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
self.checkModule(lstm, [torch.ones(2, 2)])
def test_script_pad_sequence_pack_sequence(self):
@ -9165,7 +9165,7 @@ a")
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5])
tensor3 = torch.tensor([6])
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
self.checkScript(pad_sequence_func,
([ones3, ones4, ones5],))
self.checkScript(pad_sequence_func,
@ -9361,7 +9361,7 @@ a")
def test_tuples(self):
# TODO: jitter issue.
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
def foo(i):
a = (i + 4, i * 2)
c = a
@ -12613,7 +12613,7 @@ a")
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
def test_bool_dispatch(self):
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
def kwarg_false(x):
# type: (Tensor) -> Tensor
return F.max_pool1d(x, 1, 1, return_indices=False)
@ -14237,7 +14237,7 @@ a")
# type: (str) -> Tensor
return self.table[key] + self.x
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
# TODO: re-enable module hook when Python printing of attributes is
# supported
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
@ -15393,7 +15393,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(),
self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
if test_name in EXCLUDE_PYTHON_PRINT:
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
run_test()
else:
run_test()

View File

@ -474,7 +474,7 @@ class TestFuser(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
@torch.jit._disable_emit_hooks_decorator
@torch._jit_internal._disable_emit_hooks_decorator
@_inline_everything
def test_fuse_decompose_normalization(self):
class ResLike(torch.jit.ScriptModule):

View File

@ -507,7 +507,7 @@ class TestFuser(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
@torch.jit._disable_emit_hooks_decorator
@torch._jit_internal._disable_emit_hooks_decorator
@_inline_everything
def test_fuse_decompose_normalization(self):
class ResLike(torch.jit.ScriptModule):

View File

@ -4,6 +4,8 @@ can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems
"""
import contextlib
import collections
import inspect
import weakref
import warnings
@ -767,3 +769,45 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
def fake_range():
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
def _try_get_dispatched_fn(fn):
if not callable(fn):
return None
return boolean_dispatched.get(fn)
def _get_named_tuple_properties(obj):
assert issubclass(obj, tuple) and hasattr(obj, '_fields')
fields = list(obj._fields)
annotations = []
has_annotations = hasattr(obj, '__annotations__')
for field in fields:
if has_annotations and field in obj.__annotations__:
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
return type(obj).__name__, fields, annotations
def _create_named_tuple(t, unqual_name, field_names):
TupleType = collections.namedtuple(unqual_name, field_names)
return TupleType(*t)
@contextlib.contextmanager
def _disable_emit_hooks():
hooks = torch._C._jit_get_emit_hooks()
torch._C._jit_set_emit_hooks(None, None)
yield
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
def _disable_emit_hooks_decorator(_DecoratorContextManager): # noqa: F811
def __enter__(self):
self.hooks = torch._C._jit_get_emit_hooks()
torch._C._jit_set_emit_hooks(None, None)
def __exit__(self, *args):
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])

View File

@ -61,7 +61,7 @@ class _OpNamespace(types.ModuleType):
op = torch._C._jit_get_operation(qualified_op_name)
# let the script frontend know that op is identical to the builtin op
# with qualified_op_name
torch.jit._register_builtin(op, qualified_op_name)
torch.jit._builtins._register_builtin(op, qualified_op_name)
setattr(self, op_name, op)
op.__module__ = self.__module__ + "." + self.name
return op

View File

@ -839,7 +839,7 @@ inline py::object toPyObject(IValue ivalue) {
auto fieldNames = fmap(
tuple->type()->schema()->arguments(),
[](const Argument& arg) { return arg.name(); });
return py::module::import("torch.jit")
return py::module::import("torch._jit_internal")
.attr("_create_named_tuple")(t, unqualName, fieldNames);
} else {
return std::move(t);

View File

@ -686,8 +686,8 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
}
}
py::object props =
py::module::import("torch.jit").attr("_get_named_tuple_properties")(obj);
py::object props = py::module::import("torch._jit_internal")
.attr("_get_named_tuple_properties")(obj);
std::string unqualName;
std::vector<std::string> fields;
std::vector<TypePtr> annotations;
@ -788,7 +788,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
}
py::object builtin_name =
py::module::import("torch.jit").attr("_find_builtin")(obj);
py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
if (!builtin_name.is_none()) {
return std::make_shared<BuiltinFunction>(
Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
@ -801,8 +801,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
}
}
py::object dispatched_fn =
py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
py::object dispatched_fn = py::module::import("torch._jit_internal")
.attr("_try_get_dispatched_fn")(obj);
if (!dispatched_fn.is_none()) {
return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
}

View File

@ -160,7 +160,7 @@ def _wait_all_workers():
is_leader_worker = leader_worker_name == self_worker_name
# Set a long enough timeout for all shutdown messages to be processed.
timeout = 5 # seconds
timeout = 5 # second
# Phase 1: Followers send intents.
# All followers report intents to the leader.
@ -522,7 +522,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
qualified_name = torch.jit._find_builtin(func)
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = torch.autograd._profiler_enabled()
@ -594,7 +594,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
if not callable(func):
raise TypeError("function should be callable.")
qualified_name = torch.jit._find_builtin(func)
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
# TODO: profiling logic does not really belong in invoke_rpc, it should be

View File

@ -1,93 +1,52 @@
import torch._C
import torch._jit_internal as _jit_internal
from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin # noqa
from torch._jit_internal import Future
from torch.nn import Module
from torch.utils import set_module
from torch.autograd.grad_mode import _DecoratorContextManager
from typing import Optional, List
import collections
import contextlib
import functools
import os
import pathlib
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload, _overload_method
from torch._jit_internal import ignore, export, unused
from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
RecursiveScriptModule, ScriptWarning, interface
from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
from torch._jit_internal import (
Final,
Future,
_overload,
_overload_method,
ignore,
export,
unused,
)
from torch.jit._script import (
script,
Attribute,
ScriptModule,
is_scripting,
script_method,
RecursiveScriptModule,
ScriptWarning,
interface,
CompilationUnit,
ScriptFunction,
_unwrap_optional,
)
from torch.jit._trace import (
trace,
trace_module,
TracedModule,
TracerWarning,
TracingCheckError,
is_tracing,
ONNXTracedModule,
TopLevelTracedModule,
_unique_state_dict,
_flatten,
_script_if_tracing,
_get_trace_graph,
)
from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
set_module(Future, "torch.jit")
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
# For backwards compatibility
_fork = fork
_wait = wait
@contextlib.contextmanager
def optimized_execution(should_optimize):
"""
A context manager that controls whether the JIT's executor will run
optimizations before executing a function.
"""
stored_flag = torch._C._get_graph_executor_optimize()
torch._C._set_graph_executor_optimize(should_optimize)
try:
yield
finally:
torch._C._set_graph_executor_optimize(stored_flag)
@contextlib.contextmanager
def fuser(name):
"""
A context manager that facilitates switching between
backend fusers.
Valid names:
* ``fuser0`` - enables only legacy fuser
* ``fuser1`` - enables only NNC
* ``fuser2`` - enables only nvFuser
"""
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
if name == 'fuser0': # legacy fuser
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
elif name == 'fuser1': # NNC
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
elif name == 'fuser2': # nvFuser
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
else:
raise Exception("unrecognized fuser option")
try:
yield
finally:
if name == 'fuser1': # NNC
torch._C._jit_set_profiling_executor(old_profiling_executor)
torch._C._jit_set_profiling_mode(old_profiling_mode)
# recover the previous values
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
def export_opnames(m):
r"""
@ -95,212 +54,6 @@ def export_opnames(m):
"""
return torch._C._export_opnames(m._c)
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
return_inputs=False, _return_inputs_states=False):
"""
.. warning::
This function is internal-only and should only be used by the ONNX
exporter. If you are trying to get a graph through tracing, please go
through the public API instead::
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
trace_graph = trace.graph
Trace a function or model, returning a tuple consisting of the both the
*trace* of an execution, as well as the original return value. If return_inputs,
also returns the trace inputs as part of the tuple
Tracing is guaranteed not to change the semantics of the function/module
that is traced.
Arguments:
f (torch.nn.Module or function): the function or module
to be traced.
args (tuple or Tensor): the positional arguments to pass to the
function/module to be traced. A non-tuple is assumed to
be a single positional argument to be passed to the model.
kwargs (dict): the keyword arguments to pass to the function/module
to be traced.
Example (trace a cell):
.. testcode::
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
"""
if kwargs is None:
kwargs = {}
if not isinstance(args, tuple):
args = (args,)
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
return outs
def freeze(mod, preserved_attrs : Optional[List[str]] = None):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
By default, `forward` will be preserved, as well as attributes & methods specified in
`preserved_attrs`. Additionally, any attribute that is modified within a preserved
method will be preserved.
Freezing currently only accepts ScriptModules that are in eval mode.
Arguments:
mod (:class:`ScriptModule`): a module to be frozen
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
Attributes modified in preserved methods will also be preserved.
Returns:
Frozen :class:`ScriptModule`.
Example (Freezing a simple module with a Parameter):
.. testcode::
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
self.linear = torch.nn.Linear(N, M)
def forward(self, input):
output = self.weight.mm(input)
output = self.linear(output)
return output
scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)
Example (Freezing a module with preserved attributes)
.. testcode::
import torch
class MyModule2(torch.nn.Module):
def __init__(self):
super(MyModule2, self).__init__()
self.modified_tensor = torch.tensor(10.)
self.version = 1
def forward(self, input):
self.modified_tensor += 1
return input + self.modified_tensor
scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)
Note:
If you're not sure why an attribute is not being inlined as a constant, you can run
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
attribute is being modified.
"""
if not isinstance(mod, ScriptModule):
raise RuntimeError("Freezing expects a ScriptModule as input. "
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
if mod.training:
raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
"Please call .eval() on your module before freezing.")
preserved_attrs = preserved_attrs if preserved_attrs is not None else []
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
return out
class CompilationUnit(object):
def __init__(self, lang=None, _frames_up=0):
self._c = torch._C.CompilationUnit()
if lang is not None:
self.define(lang, _frames_up=_frames_up + 1)
def define(self, lang, rcb=None, _frames_up=0):
if not rcb:
rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
self._c.define(lang, rcb)
def __getattr__(self, attr):
r = self._c.find_function(attr)
if r is None:
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
return r
def _try_get_dispatched_fn(fn):
if not callable(fn):
return None
return _jit_internal.boolean_dispatched.get(fn)
def _try_get_overloaded_fn(mod, field):
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
@contextlib.contextmanager
def _disable_emit_hooks():
hooks = torch._C._jit_get_emit_hooks()
torch._C._jit_set_emit_hooks(None, None)
yield
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
def _disable_emit_hooks_decorator(_DecoratorContextManager): # noqa: F811
def __enter__(self):
self.hooks = torch._C._jit_get_emit_hooks()
torch._C._jit_set_emit_hooks(None, None)
def __exit__(self, *args):
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
def _script_if_tracing(fn):
"""
Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
has a non-negligible start up time when it is first called due to
lazy-initializations of many compiler builtins. Therefore you should not use
it in library code. However, you may want to have parts of your library work
in tracing even if they use control flow. In these cases, you should use
``@torch.jit._script_if_tracing`` to substitute for
``torch.jit.script``.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if not is_tracing():
# Not tracing, don't do anything
return fn(*args, **kwargs)
compiled_fn = script(wrapper.__original_fn)
return compiled_fn(*args, **kwargs)
wrapper.__original_fn = fn
wrapper.__script_if_tracing_wrapper = True
return wrapper
def _unwrap_optional(x):
assert x is not None, "Unwrapping null optional"
return x
_register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
_register_builtin(_wait, 'aten::wait')
_register_builtin(wait, 'aten::wait')
_register_builtin(is_scripting, 'aten::is_scripting')
# torch.jit.Error
Error = torch._C.JITException
@ -309,53 +62,11 @@ set_module(Error, "torch.jit")
Error.__name__ = "Error"
Error.__qualname__ = "Error"
def _get_named_tuple_properties(obj):
assert issubclass(obj, tuple) and hasattr(obj, '_fields')
fields = list(obj._fields)
annotations = []
has_annotations = hasattr(obj, '__annotations__')
for field in fields:
if has_annotations and field in obj.__annotations__:
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
return type(obj).__name__, fields, annotations
def _create_named_tuple(t, unqual_name, field_names):
TupleType = collections.namedtuple(unqual_name, field_names)
return TupleType(*t)
class _disable_tracing(object):
def __enter__(self):
self.state = torch._C._get_tracing_state()
torch._C._set_tracing_state(None)
def __exit__(self, *args):
torch._C._set_tracing_state(self.state)
self.state = None
# for use in python if using annotate
def annotate(the_type, the_value):
# noop in python
return the_value
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
def _graph_for(self, *args, **kwargs):
self(*args, **kwargs)
return last_executed_optimized_graph()
torch._C.ScriptMethod.graph_for = _graph_for
torch._C.ScriptFunction.graph_for = _graph_for
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")
if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")

View File

@ -9,6 +9,12 @@ functionalities in `torch.jit`.
import torch
from torch.utils import set_module
from torch.jit._builtins import _register_builtin
from torch._jit_internal import Future
set_module(Future, "torch.jit")
def fork(func, *args, **kwargs):
"""
@ -84,3 +90,6 @@ def wait(future):
`T`: the return value of the the completed task
"""
return torch._C.wait(future)
_register_builtin(wait, "aten::wait")

101
torch/jit/_freeze.py Normal file
View File

@ -0,0 +1,101 @@
"""Freezing
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
from typing import Optional, List
import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule
def freeze(mod, preserved_attrs: Optional[List[str]] = None):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
By default, `forward` will be preserved, as well as attributes & methods specified in
`preserved_attrs`. Additionally, any attribute that is modified within a preserved
method will be preserved.
Freezing currently only accepts ScriptModules that are in eval mode.
Arguments:
mod (:class:`ScriptModule`): a module to be frozen
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
Attributes modified in preserved methods will also be preserved.
Returns:
Frozen :class:`ScriptModule`.
Example (Freezing a simple module with a Parameter):
.. testcode::
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
self.linear = torch.nn.Linear(N, M)
def forward(self, input):
output = self.weight.mm(input)
output = self.linear(output)
return output
scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)
Example (Freezing a module with preserved attributes)
.. testcode::
import torch
class MyModule2(torch.nn.Module):
def __init__(self):
super(MyModule2, self).__init__()
self.modified_tensor = torch.tensor(10.)
self.version = 1
def forward(self, input):
self.modified_tensor += 1
return input + self.modified_tensor
scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)
Note:
If you're not sure why an attribute is not being inlined as a constant, you can run
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
attribute is being modified.
"""
if not isinstance(mod, ScriptModule):
raise RuntimeError(
"Freezing expects a ScriptModule as input. "
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
)
if mod.training:
raise RuntimeError(
"Freezing is currently only implemented for modules in eval mode. "
"Please call .eval() on your module before freezing."
)
preserved_attrs = preserved_attrs if preserved_attrs is not None else []
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
return out

70
torch/jit/_fuser.py Normal file
View File

@ -0,0 +1,70 @@
import contextlib
import torch
@contextlib.contextmanager
def optimized_execution(should_optimize):
"""
A context manager that controls whether the JIT's executor will run
optimizations before executing a function.
"""
stored_flag = torch._C._get_graph_executor_optimize()
torch._C._set_graph_executor_optimize(should_optimize)
try:
yield
finally:
torch._C._set_graph_executor_optimize(stored_flag)
@contextlib.contextmanager
def fuser(name):
"""
A context manager that facilitates switching between
backend fusers.
Valid names:
* ``fuser0`` - enables only legacy fuser
* ``fuser1`` - enables only NNC
* ``fuser2`` - enables only nvFuser
"""
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
if name == 'fuser0': # legacy fuser
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
elif name == 'fuser1': # NNC
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
elif name == 'fuser2': # nvFuser
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
else:
raise Exception("unrecognized fuser option")
try:
yield
finally:
if name == 'fuser1': # NNC
torch._C._jit_set_profiling_executor(old_profiling_executor)
torch._C._jit_set_profiling_mode(old_profiling_mode)
# recover the previous values
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
def _graph_for(self, *args, **kwargs):
self(*args, **kwargs)
return last_executed_optimized_graph()

View File

@ -609,7 +609,7 @@ def compile_unbound_method(concrete_type, fn):
if _jit_internal.is_ignored_fn(fn):
return None
stub = make_stub(fn, fn.__name__)
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
# We don't want to call the hooks here since the graph that is calling
# this function is not yet complete
create_methods_from_stubs(concrete_type, (stub,))

View File

@ -15,12 +15,15 @@ import warnings
import torch
import torch._jit_internal as _jit_internal
from torch.utils import set_module
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
from torch.nn import Module
from torch.jit._state import _enabled
from torch.jit._builtins import _register_builtin
from torch._six import with_metaclass, get_function_from_type
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
from torch._jit_internal import _qualified_name
from torch.jit._fuser import _graph_for
from torch.jit._state import (
_try_get_jit_cached_function,
_try_get_jit_cached_overloads,
@ -28,6 +31,16 @@ from torch.jit._state import (
_set_jit_overload_cache,
)
torch._C.ScriptMethod.graph_for = _graph_for
torch._C.ScriptFunction.graph_for = _graph_for
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")
if _enabled:
Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:
@ -1053,3 +1066,32 @@ def _recursive_compile_class(obj, loc):
error_stack = torch._C.CallStack(_qual_name, loc)
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
_compile_and_register_class(obj, rcb, _qual_name)
_register_builtin(is_scripting, "aten::is_scripting")
class CompilationUnit(object):
def __init__(self, lang=None, _frames_up=0):
self._c = torch._C.CompilationUnit()
if lang is not None:
self.define(lang, _frames_up=_frames_up + 1)
def define(self, lang, rcb=None, _frames_up=0):
if not rcb:
rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
self._c.define(lang, rcb)
def __getattr__(self, attr):
r = self._c.find_function(attr)
if r is None:
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
return r
def _unwrap_optional(x):
assert x is not None, "Unwrapping null optional"
return x
_register_builtin(_unwrap_optional, "aten::_unwrap_optional")

View File

@ -11,12 +11,13 @@ import torch
import os
import contextlib
import functools
import warnings
import inspect
import re
from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward
from torch.jit._script import ScriptModule, _CachedForward, script
from torch._jit_internal import _qualified_name
from torch.autograd import function
from torch import _jit_internal
@ -1077,3 +1078,70 @@ class TopLevelTracedModule(TracedModule):
cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
"""
self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
def _script_if_tracing(fn):
"""
Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
has a non-negligible start up time when it is first called due to
lazy-initializations of many compiler builtins. Therefore you should not use
it in library code. However, you may want to have parts of your library work
in tracing even if they use control flow. In these cases, you should use
``@torch.jit._script_if_tracing`` to substitute for
``torch.jit.script``.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if not is_tracing():
# Not tracing, don't do anything
return fn(*args, **kwargs)
compiled_fn = script(wrapper.__original_fn)
return compiled_fn(*args, **kwargs)
wrapper.__original_fn = fn
wrapper.__script_if_tracing_wrapper = True
return wrapper
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
return_inputs=False, _return_inputs_states=False):
"""
.. warning::
This function is internal-only and should only be used by the ONNX
exporter. If you are trying to get a graph through tracing, please go
through the public API instead::
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
trace_graph = trace.graph
Trace a function or model, returning a tuple consisting of the both the
*trace* of an execution, as well as the original return value. If return_inputs,
also returns the trace inputs as part of the tuple
Tracing is guaranteed not to change the semantics of the function/module
that is traced.
Arguments:
f (torch.nn.Module or function): the function or module
to be traced.
args (tuple or Tensor): the positional arguments to pass to the
function/module to be traced. A non-tuple is assumed to
be a single positional argument to be passed to the model.
kwargs (dict): the keyword arguments to pass to the function/module
to be traced.
Example (trace a cell):
.. testcode::
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
"""
if kwargs is None:
kwargs = {}
if not isinstance(args, tuple):
args = (args,)
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
return outs

View File

@ -1,4 +1,5 @@
import torch.jit
from torch.jit._builtins import _find_builtin
import inspect
import textwrap
# this file is for generating documentation using sphinx autodoc
@ -92,7 +93,7 @@ def _get_nn_functional_ops():
for mod in torch.jit._builtins._modules_containing_builtins:
name = mod.__name__
for elem in dir(mod):
builtin = torch.jit._find_builtin(getattr(mod, elem))
builtin = _find_builtin(getattr(mod, elem))
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
@ -133,7 +134,7 @@ def _get_torchscript_builtins():
# Iterate over the specially added builtins
for fn, _builtin_name in builtins:
mod = inspect.getmodule(fn)
builtin = torch.jit._find_builtin(fn)
builtin = _find_builtin(fn)
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
@ -150,7 +151,7 @@ def _get_math_builtins():
# Iterate over the specially added builtins
for fn, _builtin_name in builtins:
mod = inspect.getmodule(fn)
builtin = torch.jit._find_builtin(fn)
builtin = _find_builtin(fn)
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:

View File

@ -1222,7 +1222,7 @@ class RpcTest(RpcAgentTestFixture):
events = prof.function_events
rpc_mul_event = get_function_event(
events, torch.jit._find_builtin(torch.mul)
events, torch.jit._builtins._find_builtin(torch.mul)
)
remote_events = {

View File

@ -358,7 +358,7 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
f_args_variable = (self_variable,) + args_variable
f_args_tensor = (self_tensor,) + args_tensor
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
return script_fn, inputs

View File

@ -159,7 +159,7 @@ class JitTestCase(TestCase):
return code_files, debug_files
# disable the hook while we parse code, otherwise we will re-enter the hook
with torch.jit._disable_emit_hooks():
with torch._jit_internal._disable_emit_hooks():
try:
# short-circuit if this is an empty function or module
if len(m.code) == 0: