mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
move misc implementation out of jit/__init__.py (#41154)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41154 Test Plan: Imported from OSS Reviewed By: ailzhang Differential Revision: D22445213 Pulled By: suo fbshipit-source-id: 200545715c5ef13beb1437f49e01efb21498ddb7
This commit is contained in:
parent
6392713584
commit
ca1b8ebbcb
|
|
@ -487,7 +487,7 @@ class TestModels(JitTestCase):
|
||||||
return self.seq.forward(input)
|
return self.seq.forward(input)
|
||||||
|
|
||||||
# disabled due to a jitter issues that will be fixed by using load/store in the compiler
|
# disabled due to a jitter issues that will be fixed by using load/store in the compiler
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
# TODO: toggle export_import once above issues are fixed
|
# TODO: toggle export_import once above issues are fixed
|
||||||
self.checkTrace(Traced(), (torch.rand(3, 4),),
|
self.checkTrace(Traced(), (torch.rand(3, 4),),
|
||||||
export_import=False)
|
export_import=False)
|
||||||
|
|
|
||||||
|
|
@ -1817,7 +1817,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
||||||
def weight(self, w):
|
def weight(self, w):
|
||||||
self._packed_weight = torch.ops.quantized.linear_prepack(w)
|
self._packed_weight = torch.ops.quantized.linear_prepack(w)
|
||||||
|
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
x = torch.jit.script(Linear(10, 10))
|
x = torch.jit.script(Linear(10, 10))
|
||||||
torch._C._jit_pass_erase_shape_information(x.graph)
|
torch._C._jit_pass_erase_shape_information(x.graph)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2159,7 +2159,7 @@ graph(%Ra, %Rb):
|
||||||
self.assertExpected(cu.foo.code)
|
self.assertExpected(cu.foo.code)
|
||||||
|
|
||||||
def test_import_method(self):
|
def test_import_method(self):
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
class Foo(torch.jit.ScriptModule):
|
class Foo(torch.jit.ScriptModule):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Foo, self).__init__()
|
super(Foo, self).__init__()
|
||||||
|
|
@ -3596,7 +3596,7 @@ def foo(x):
|
||||||
mod.ninf = float("-inf")
|
mod.ninf = float("-inf")
|
||||||
mod.nan = float("nan")
|
mod.nan = float("nan")
|
||||||
|
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
class Foo(torch.jit.ScriptModule):
|
class Foo(torch.jit.ScriptModule):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Foo, self).__init__()
|
super(Foo, self).__init__()
|
||||||
|
|
@ -9122,7 +9122,7 @@ a")
|
||||||
x[seq_lens[b]:, b, :] = 0
|
x[seq_lens[b]:, b, :] = 0
|
||||||
|
|
||||||
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
|
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
|
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
|
||||||
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
|
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
|
||||||
self.assertEqual(eager_seq, script_seq)
|
self.assertEqual(eager_seq, script_seq)
|
||||||
|
|
@ -9145,7 +9145,7 @@ a")
|
||||||
|
|
||||||
lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
|
lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
|
||||||
|
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
self.checkModule(lstm, [torch.ones(2, 2)])
|
self.checkModule(lstm, [torch.ones(2, 2)])
|
||||||
|
|
||||||
def test_script_pad_sequence_pack_sequence(self):
|
def test_script_pad_sequence_pack_sequence(self):
|
||||||
|
|
@ -9165,7 +9165,7 @@ a")
|
||||||
tensor1 = torch.tensor([1, 2, 3])
|
tensor1 = torch.tensor([1, 2, 3])
|
||||||
tensor2 = torch.tensor([4, 5])
|
tensor2 = torch.tensor([4, 5])
|
||||||
tensor3 = torch.tensor([6])
|
tensor3 = torch.tensor([6])
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
self.checkScript(pad_sequence_func,
|
self.checkScript(pad_sequence_func,
|
||||||
([ones3, ones4, ones5],))
|
([ones3, ones4, ones5],))
|
||||||
self.checkScript(pad_sequence_func,
|
self.checkScript(pad_sequence_func,
|
||||||
|
|
@ -9361,7 +9361,7 @@ a")
|
||||||
|
|
||||||
def test_tuples(self):
|
def test_tuples(self):
|
||||||
# TODO: jitter issue.
|
# TODO: jitter issue.
|
||||||
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
|
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||||
def foo(i):
|
def foo(i):
|
||||||
a = (i + 4, i * 2)
|
a = (i + 4, i * 2)
|
||||||
c = a
|
c = a
|
||||||
|
|
@ -12613,7 +12613,7 @@ a")
|
||||||
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
||||||
|
|
||||||
def test_bool_dispatch(self):
|
def test_bool_dispatch(self):
|
||||||
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
|
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||||
def kwarg_false(x):
|
def kwarg_false(x):
|
||||||
# type: (Tensor) -> Tensor
|
# type: (Tensor) -> Tensor
|
||||||
return F.max_pool1d(x, 1, 1, return_indices=False)
|
return F.max_pool1d(x, 1, 1, return_indices=False)
|
||||||
|
|
@ -14237,7 +14237,7 @@ a")
|
||||||
# type: (str) -> Tensor
|
# type: (str) -> Tensor
|
||||||
return self.table[key] + self.x
|
return self.table[key] + self.x
|
||||||
|
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
# TODO: re-enable module hook when Python printing of attributes is
|
# TODO: re-enable module hook when Python printing of attributes is
|
||||||
# supported
|
# supported
|
||||||
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
|
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
|
||||||
|
|
@ -15393,7 +15393,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(),
|
||||||
self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
|
self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
|
||||||
|
|
||||||
if test_name in EXCLUDE_PYTHON_PRINT:
|
if test_name in EXCLUDE_PYTHON_PRINT:
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
run_test()
|
run_test()
|
||||||
else:
|
else:
|
||||||
run_test()
|
run_test()
|
||||||
|
|
|
||||||
|
|
@ -474,7 +474,7 @@ class TestFuser(JitTestCase):
|
||||||
|
|
||||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
|
||||||
@torch.jit._disable_emit_hooks_decorator
|
@torch._jit_internal._disable_emit_hooks_decorator
|
||||||
@_inline_everything
|
@_inline_everything
|
||||||
def test_fuse_decompose_normalization(self):
|
def test_fuse_decompose_normalization(self):
|
||||||
class ResLike(torch.jit.ScriptModule):
|
class ResLike(torch.jit.ScriptModule):
|
||||||
|
|
|
||||||
|
|
@ -507,7 +507,7 @@ class TestFuser(JitTestCase):
|
||||||
|
|
||||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
|
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
|
||||||
@torch.jit._disable_emit_hooks_decorator
|
@torch._jit_internal._disable_emit_hooks_decorator
|
||||||
@_inline_everything
|
@_inline_everything
|
||||||
def test_fuse_decompose_normalization(self):
|
def test_fuse_decompose_normalization(self):
|
||||||
class ResLike(torch.jit.ScriptModule):
|
class ResLike(torch.jit.ScriptModule):
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ can be used in other places in torch/ (namely torch.nn) without running into
|
||||||
circular dependency problems
|
circular dependency problems
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import weakref
|
import weakref
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -767,3 +769,45 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
|
||||||
|
|
||||||
def fake_range():
|
def fake_range():
|
||||||
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
|
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_get_dispatched_fn(fn):
|
||||||
|
if not callable(fn):
|
||||||
|
return None
|
||||||
|
return boolean_dispatched.get(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_named_tuple_properties(obj):
|
||||||
|
assert issubclass(obj, tuple) and hasattr(obj, '_fields')
|
||||||
|
fields = list(obj._fields)
|
||||||
|
annotations = []
|
||||||
|
has_annotations = hasattr(obj, '__annotations__')
|
||||||
|
for field in fields:
|
||||||
|
if has_annotations and field in obj.__annotations__:
|
||||||
|
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
|
||||||
|
annotations.append(the_type)
|
||||||
|
else:
|
||||||
|
annotations.append(torch._C.TensorType.get())
|
||||||
|
return type(obj).__name__, fields, annotations
|
||||||
|
|
||||||
|
|
||||||
|
def _create_named_tuple(t, unqual_name, field_names):
|
||||||
|
TupleType = collections.namedtuple(unqual_name, field_names)
|
||||||
|
return TupleType(*t)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _disable_emit_hooks():
|
||||||
|
hooks = torch._C._jit_get_emit_hooks()
|
||||||
|
torch._C._jit_set_emit_hooks(None, None)
|
||||||
|
yield
|
||||||
|
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
|
||||||
|
|
||||||
|
|
||||||
|
def _disable_emit_hooks_decorator(_DecoratorContextManager): # noqa: F811
|
||||||
|
def __enter__(self):
|
||||||
|
self.hooks = torch._C._jit_get_emit_hooks()
|
||||||
|
torch._C._jit_set_emit_hooks(None, None)
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ class _OpNamespace(types.ModuleType):
|
||||||
op = torch._C._jit_get_operation(qualified_op_name)
|
op = torch._C._jit_get_operation(qualified_op_name)
|
||||||
# let the script frontend know that op is identical to the builtin op
|
# let the script frontend know that op is identical to the builtin op
|
||||||
# with qualified_op_name
|
# with qualified_op_name
|
||||||
torch.jit._register_builtin(op, qualified_op_name)
|
torch.jit._builtins._register_builtin(op, qualified_op_name)
|
||||||
setattr(self, op_name, op)
|
setattr(self, op_name, op)
|
||||||
op.__module__ = self.__module__ + "." + self.name
|
op.__module__ = self.__module__ + "." + self.name
|
||||||
return op
|
return op
|
||||||
|
|
|
||||||
|
|
@ -839,7 +839,7 @@ inline py::object toPyObject(IValue ivalue) {
|
||||||
auto fieldNames = fmap(
|
auto fieldNames = fmap(
|
||||||
tuple->type()->schema()->arguments(),
|
tuple->type()->schema()->arguments(),
|
||||||
[](const Argument& arg) { return arg.name(); });
|
[](const Argument& arg) { return arg.name(); });
|
||||||
return py::module::import("torch.jit")
|
return py::module::import("torch._jit_internal")
|
||||||
.attr("_create_named_tuple")(t, unqualName, fieldNames);
|
.attr("_create_named_tuple")(t, unqualName, fieldNames);
|
||||||
} else {
|
} else {
|
||||||
return std::move(t);
|
return std::move(t);
|
||||||
|
|
|
||||||
|
|
@ -686,8 +686,8 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object props =
|
py::object props = py::module::import("torch._jit_internal")
|
||||||
py::module::import("torch.jit").attr("_get_named_tuple_properties")(obj);
|
.attr("_get_named_tuple_properties")(obj);
|
||||||
std::string unqualName;
|
std::string unqualName;
|
||||||
std::vector<std::string> fields;
|
std::vector<std::string> fields;
|
||||||
std::vector<TypePtr> annotations;
|
std::vector<TypePtr> annotations;
|
||||||
|
|
@ -788,7 +788,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object builtin_name =
|
py::object builtin_name =
|
||||||
py::module::import("torch.jit").attr("_find_builtin")(obj);
|
py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
|
||||||
if (!builtin_name.is_none()) {
|
if (!builtin_name.is_none()) {
|
||||||
return std::make_shared<BuiltinFunction>(
|
return std::make_shared<BuiltinFunction>(
|
||||||
Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
|
Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
|
||||||
|
|
@ -801,8 +801,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object dispatched_fn =
|
py::object dispatched_fn = py::module::import("torch._jit_internal")
|
||||||
py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
|
.attr("_try_get_dispatched_fn")(obj);
|
||||||
if (!dispatched_fn.is_none()) {
|
if (!dispatched_fn.is_none()) {
|
||||||
return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
|
return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ def _wait_all_workers():
|
||||||
|
|
||||||
is_leader_worker = leader_worker_name == self_worker_name
|
is_leader_worker = leader_worker_name == self_worker_name
|
||||||
# Set a long enough timeout for all shutdown messages to be processed.
|
# Set a long enough timeout for all shutdown messages to be processed.
|
||||||
timeout = 5 # seconds
|
timeout = 5 # second
|
||||||
|
|
||||||
# Phase 1: Followers send intents.
|
# Phase 1: Followers send intents.
|
||||||
# All followers report intents to the leader.
|
# All followers report intents to the leader.
|
||||||
|
|
@ -522,7 +522,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||||
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
|
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
|
||||||
>>> rpc.shutdown()
|
>>> rpc.shutdown()
|
||||||
"""
|
"""
|
||||||
qualified_name = torch.jit._find_builtin(func)
|
qualified_name = torch.jit._builtins._find_builtin(func)
|
||||||
dst_worker_info = _to_worker_info(to)
|
dst_worker_info = _to_worker_info(to)
|
||||||
should_profile = torch.autograd._profiler_enabled()
|
should_profile = torch.autograd._profiler_enabled()
|
||||||
|
|
||||||
|
|
@ -594,7 +594,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
raise TypeError("function should be callable.")
|
raise TypeError("function should be callable.")
|
||||||
|
|
||||||
qualified_name = torch.jit._find_builtin(func)
|
qualified_name = torch.jit._builtins._find_builtin(func)
|
||||||
dst_worker_info = _to_worker_info(to)
|
dst_worker_info = _to_worker_info(to)
|
||||||
|
|
||||||
# TODO: profiling logic does not really belong in invoke_rpc, it should be
|
# TODO: profiling logic does not really belong in invoke_rpc, it should be
|
||||||
|
|
|
||||||
|
|
@ -1,93 +1,52 @@
|
||||||
import torch._C
|
import torch._C
|
||||||
import torch._jit_internal as _jit_internal
|
|
||||||
|
|
||||||
from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin # noqa
|
|
||||||
from torch._jit_internal import Future
|
|
||||||
from torch.nn import Module
|
|
||||||
from torch.utils import set_module
|
from torch.utils import set_module
|
||||||
from torch.autograd.grad_mode import _DecoratorContextManager
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import collections
|
|
||||||
import contextlib
|
|
||||||
import functools
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
# These are imported so users can access them from the `torch.jit` module
|
# These are imported so users can access them from the `torch.jit` module
|
||||||
from torch._jit_internal import Final, _overload, _overload_method
|
from torch._jit_internal import (
|
||||||
from torch._jit_internal import ignore, export, unused
|
Final,
|
||||||
from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
|
Future,
|
||||||
RecursiveScriptModule, ScriptWarning, interface
|
_overload,
|
||||||
from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
|
_overload_method,
|
||||||
is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
|
ignore,
|
||||||
|
export,
|
||||||
|
unused,
|
||||||
|
)
|
||||||
|
from torch.jit._script import (
|
||||||
|
script,
|
||||||
|
Attribute,
|
||||||
|
ScriptModule,
|
||||||
|
is_scripting,
|
||||||
|
script_method,
|
||||||
|
RecursiveScriptModule,
|
||||||
|
ScriptWarning,
|
||||||
|
interface,
|
||||||
|
CompilationUnit,
|
||||||
|
ScriptFunction,
|
||||||
|
_unwrap_optional,
|
||||||
|
)
|
||||||
|
from torch.jit._trace import (
|
||||||
|
trace,
|
||||||
|
trace_module,
|
||||||
|
TracedModule,
|
||||||
|
TracerWarning,
|
||||||
|
TracingCheckError,
|
||||||
|
is_tracing,
|
||||||
|
ONNXTracedModule,
|
||||||
|
TopLevelTracedModule,
|
||||||
|
_unique_state_dict,
|
||||||
|
_flatten,
|
||||||
|
_script_if_tracing,
|
||||||
|
_get_trace_graph,
|
||||||
|
)
|
||||||
from torch.jit._async import fork, wait
|
from torch.jit._async import fork, wait
|
||||||
from torch.jit._serialization import save, load
|
from torch.jit._serialization import save, load
|
||||||
|
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
|
||||||
set_module(Future, "torch.jit")
|
|
||||||
|
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
_fork = fork
|
_fork = fork
|
||||||
_wait = wait
|
_wait = wait
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def optimized_execution(should_optimize):
|
|
||||||
"""
|
|
||||||
A context manager that controls whether the JIT's executor will run
|
|
||||||
optimizations before executing a function.
|
|
||||||
"""
|
|
||||||
stored_flag = torch._C._get_graph_executor_optimize()
|
|
||||||
torch._C._set_graph_executor_optimize(should_optimize)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
torch._C._set_graph_executor_optimize(stored_flag)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def fuser(name):
|
|
||||||
"""
|
|
||||||
A context manager that facilitates switching between
|
|
||||||
backend fusers.
|
|
||||||
|
|
||||||
Valid names:
|
|
||||||
* ``fuser0`` - enables only legacy fuser
|
|
||||||
* ``fuser1`` - enables only NNC
|
|
||||||
* ``fuser2`` - enables only nvFuser
|
|
||||||
"""
|
|
||||||
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
|
|
||||||
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
|
|
||||||
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
|
||||||
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
|
|
||||||
if name == 'fuser0': # legacy fuser
|
|
||||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
||||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
|
||||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
|
||||||
torch._C._jit_set_nvfuser_enabled(False)
|
|
||||||
elif name == 'fuser1': # NNC
|
|
||||||
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
|
||||||
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
|
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
|
||||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
|
||||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
|
||||||
torch._C._jit_set_nvfuser_enabled(False)
|
|
||||||
elif name == 'fuser2': # nvFuser
|
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
|
||||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
|
||||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
|
||||||
torch._C._jit_set_nvfuser_enabled(True)
|
|
||||||
else:
|
|
||||||
raise Exception("unrecognized fuser option")
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if name == 'fuser1': # NNC
|
|
||||||
torch._C._jit_set_profiling_executor(old_profiling_executor)
|
|
||||||
torch._C._jit_set_profiling_mode(old_profiling_mode)
|
|
||||||
# recover the previous values
|
|
||||||
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
|
|
||||||
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
|
|
||||||
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
|
|
||||||
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
|
|
||||||
|
|
||||||
def export_opnames(m):
|
def export_opnames(m):
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -95,212 +54,6 @@ def export_opnames(m):
|
||||||
"""
|
"""
|
||||||
return torch._C._export_opnames(m._c)
|
return torch._C._export_opnames(m._c)
|
||||||
|
|
||||||
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
|
|
||||||
return_inputs=False, _return_inputs_states=False):
|
|
||||||
"""
|
|
||||||
.. warning::
|
|
||||||
This function is internal-only and should only be used by the ONNX
|
|
||||||
exporter. If you are trying to get a graph through tracing, please go
|
|
||||||
through the public API instead::
|
|
||||||
|
|
||||||
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
|
|
||||||
trace_graph = trace.graph
|
|
||||||
|
|
||||||
Trace a function or model, returning a tuple consisting of the both the
|
|
||||||
*trace* of an execution, as well as the original return value. If return_inputs,
|
|
||||||
also returns the trace inputs as part of the tuple
|
|
||||||
|
|
||||||
Tracing is guaranteed not to change the semantics of the function/module
|
|
||||||
that is traced.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
f (torch.nn.Module or function): the function or module
|
|
||||||
to be traced.
|
|
||||||
args (tuple or Tensor): the positional arguments to pass to the
|
|
||||||
function/module to be traced. A non-tuple is assumed to
|
|
||||||
be a single positional argument to be passed to the model.
|
|
||||||
kwargs (dict): the keyword arguments to pass to the function/module
|
|
||||||
to be traced.
|
|
||||||
|
|
||||||
Example (trace a cell):
|
|
||||||
|
|
||||||
.. testcode::
|
|
||||||
|
|
||||||
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
|
|
||||||
"""
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
if not isinstance(args, tuple):
|
|
||||||
args = (args,)
|
|
||||||
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
|
|
||||||
return outs
|
|
||||||
|
|
||||||
|
|
||||||
def freeze(mod, preserved_attrs : Optional[List[str]] = None):
|
|
||||||
r"""
|
|
||||||
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
|
|
||||||
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
|
|
||||||
By default, `forward` will be preserved, as well as attributes & methods specified in
|
|
||||||
`preserved_attrs`. Additionally, any attribute that is modified within a preserved
|
|
||||||
method will be preserved.
|
|
||||||
|
|
||||||
Freezing currently only accepts ScriptModules that are in eval mode.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
mod (:class:`ScriptModule`): a module to be frozen
|
|
||||||
|
|
||||||
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
|
|
||||||
Attributes modified in preserved methods will also be preserved.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Frozen :class:`ScriptModule`.
|
|
||||||
|
|
||||||
Example (Freezing a simple module with a Parameter):
|
|
||||||
|
|
||||||
.. testcode::
|
|
||||||
import torch
|
|
||||||
class MyModule(torch.nn.Module):
|
|
||||||
def __init__(self, N, M):
|
|
||||||
super(MyModule, self).__init__()
|
|
||||||
self.weight = torch.nn.Parameter(torch.rand(N, M))
|
|
||||||
self.linear = torch.nn.Linear(N, M)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
output = self.weight.mm(input)
|
|
||||||
output = self.linear(output)
|
|
||||||
return output
|
|
||||||
|
|
||||||
scripted_module = torch.jit.script(MyModule(2, 3).eval())
|
|
||||||
frozen_module = torch.jit.freeze(scripted_module)
|
|
||||||
# parameters have been removed and inlined into the Graph as constants
|
|
||||||
assert len(list(frozen_module.named_parameters())) == 0
|
|
||||||
# See the compiled graph as Python code
|
|
||||||
print(frozen_module.code)
|
|
||||||
|
|
||||||
Example (Freezing a module with preserved attributes)
|
|
||||||
|
|
||||||
.. testcode::
|
|
||||||
import torch
|
|
||||||
class MyModule2(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MyModule2, self).__init__()
|
|
||||||
self.modified_tensor = torch.tensor(10.)
|
|
||||||
self.version = 1
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
self.modified_tensor += 1
|
|
||||||
return input + self.modified_tensor
|
|
||||||
|
|
||||||
scripted_module = torch.jit.script(MyModule2().eval())
|
|
||||||
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
|
|
||||||
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
|
|
||||||
assert frozen_module.version == 1
|
|
||||||
frozen_module.version = 2
|
|
||||||
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
|
|
||||||
# it to retain model semantics
|
|
||||||
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
|
|
||||||
# now that we've run it once, the next result will be incremented by one
|
|
||||||
assert frozen_module(torch.tensor(1)) == torch.tensor(13)
|
|
||||||
|
|
||||||
Note:
|
|
||||||
If you're not sure why an attribute is not being inlined as a constant, you can run
|
|
||||||
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
|
|
||||||
attribute is being modified.
|
|
||||||
"""
|
|
||||||
if not isinstance(mod, ScriptModule):
|
|
||||||
raise RuntimeError("Freezing expects a ScriptModule as input. "
|
|
||||||
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
|
|
||||||
|
|
||||||
if mod.training:
|
|
||||||
raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
|
|
||||||
"Please call .eval() on your module before freezing.")
|
|
||||||
|
|
||||||
preserved_attrs = preserved_attrs if preserved_attrs is not None else []
|
|
||||||
|
|
||||||
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
|
|
||||||
RecursiveScriptModule._finalize_scriptmodule(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class CompilationUnit(object):
|
|
||||||
def __init__(self, lang=None, _frames_up=0):
|
|
||||||
self._c = torch._C.CompilationUnit()
|
|
||||||
if lang is not None:
|
|
||||||
self.define(lang, _frames_up=_frames_up + 1)
|
|
||||||
|
|
||||||
def define(self, lang, rcb=None, _frames_up=0):
|
|
||||||
if not rcb:
|
|
||||||
rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
|
|
||||||
self._c.define(lang, rcb)
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
r = self._c.find_function(attr)
|
|
||||||
if r is None:
|
|
||||||
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
def _try_get_dispatched_fn(fn):
|
|
||||||
if not callable(fn):
|
|
||||||
return None
|
|
||||||
return _jit_internal.boolean_dispatched.get(fn)
|
|
||||||
|
|
||||||
|
|
||||||
def _try_get_overloaded_fn(mod, field):
|
|
||||||
return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _disable_emit_hooks():
|
|
||||||
hooks = torch._C._jit_get_emit_hooks()
|
|
||||||
torch._C._jit_set_emit_hooks(None, None)
|
|
||||||
yield
|
|
||||||
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
|
|
||||||
|
|
||||||
|
|
||||||
def _disable_emit_hooks_decorator(_DecoratorContextManager): # noqa: F811
|
|
||||||
def __enter__(self):
|
|
||||||
self.hooks = torch._C._jit_get_emit_hooks()
|
|
||||||
torch._C._jit_set_emit_hooks(None, None)
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
|
|
||||||
|
|
||||||
|
|
||||||
def _script_if_tracing(fn):
|
|
||||||
"""
|
|
||||||
Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
|
|
||||||
has a non-negligible start up time when it is first called due to
|
|
||||||
lazy-initializations of many compiler builtins. Therefore you should not use
|
|
||||||
it in library code. However, you may want to have parts of your library work
|
|
||||||
in tracing even if they use control flow. In these cases, you should use
|
|
||||||
``@torch.jit._script_if_tracing`` to substitute for
|
|
||||||
``torch.jit.script``.
|
|
||||||
"""
|
|
||||||
@functools.wraps(fn)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
if not is_tracing():
|
|
||||||
# Not tracing, don't do anything
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
compiled_fn = script(wrapper.__original_fn)
|
|
||||||
return compiled_fn(*args, **kwargs)
|
|
||||||
|
|
||||||
wrapper.__original_fn = fn
|
|
||||||
wrapper.__script_if_tracing_wrapper = True
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def _unwrap_optional(x):
|
|
||||||
assert x is not None, "Unwrapping null optional"
|
|
||||||
return x
|
|
||||||
|
|
||||||
_register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
|
|
||||||
_register_builtin(_wait, 'aten::wait')
|
|
||||||
_register_builtin(wait, 'aten::wait')
|
|
||||||
_register_builtin(is_scripting, 'aten::is_scripting')
|
|
||||||
|
|
||||||
|
|
||||||
# torch.jit.Error
|
# torch.jit.Error
|
||||||
Error = torch._C.JITException
|
Error = torch._C.JITException
|
||||||
|
|
@ -309,53 +62,11 @@ set_module(Error, "torch.jit")
|
||||||
Error.__name__ = "Error"
|
Error.__name__ = "Error"
|
||||||
Error.__qualname__ = "Error"
|
Error.__qualname__ = "Error"
|
||||||
|
|
||||||
def _get_named_tuple_properties(obj):
|
|
||||||
assert issubclass(obj, tuple) and hasattr(obj, '_fields')
|
|
||||||
fields = list(obj._fields)
|
|
||||||
annotations = []
|
|
||||||
has_annotations = hasattr(obj, '__annotations__')
|
|
||||||
for field in fields:
|
|
||||||
if has_annotations and field in obj.__annotations__:
|
|
||||||
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range())
|
|
||||||
annotations.append(the_type)
|
|
||||||
else:
|
|
||||||
annotations.append(torch._C.TensorType.get())
|
|
||||||
return type(obj).__name__, fields, annotations
|
|
||||||
|
|
||||||
def _create_named_tuple(t, unqual_name, field_names):
|
|
||||||
TupleType = collections.namedtuple(unqual_name, field_names)
|
|
||||||
return TupleType(*t)
|
|
||||||
|
|
||||||
class _disable_tracing(object):
|
|
||||||
def __enter__(self):
|
|
||||||
self.state = torch._C._get_tracing_state()
|
|
||||||
torch._C._set_tracing_state(None)
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
torch._C._set_tracing_state(self.state)
|
|
||||||
self.state = None
|
|
||||||
|
|
||||||
|
|
||||||
# for use in python if using annotate
|
# for use in python if using annotate
|
||||||
def annotate(the_type, the_value):
|
def annotate(the_type, the_value):
|
||||||
# noop in python
|
# noop in python
|
||||||
return the_value
|
return the_value
|
||||||
|
|
||||||
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
|
|
||||||
|
|
||||||
|
|
||||||
def _graph_for(self, *args, **kwargs):
|
|
||||||
self(*args, **kwargs)
|
|
||||||
return last_executed_optimized_graph()
|
|
||||||
|
|
||||||
torch._C.ScriptMethod.graph_for = _graph_for
|
|
||||||
torch._C.ScriptFunction.graph_for = _graph_for
|
|
||||||
ScriptFunction = torch._C.ScriptFunction
|
|
||||||
ScriptFunction.__doc__ = """
|
|
||||||
Functionally equivalent to a :class:`ScriptModule`, but represents a single
|
|
||||||
function and does not have any attributes or Parameters.
|
|
||||||
"""
|
|
||||||
set_module(ScriptFunction, "torch.jit")
|
|
||||||
|
|
||||||
if not torch._C._jit_init():
|
if not torch._C._jit_init():
|
||||||
raise RuntimeError("JIT initialization failed")
|
raise RuntimeError("JIT initialization failed")
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,12 @@ functionalities in `torch.jit`.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from torch.utils import set_module
|
||||||
|
from torch.jit._builtins import _register_builtin
|
||||||
|
from torch._jit_internal import Future
|
||||||
|
|
||||||
|
set_module(Future, "torch.jit")
|
||||||
|
|
||||||
|
|
||||||
def fork(func, *args, **kwargs):
|
def fork(func, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
@ -84,3 +90,6 @@ def wait(future):
|
||||||
`T`: the return value of the the completed task
|
`T`: the return value of the the completed task
|
||||||
"""
|
"""
|
||||||
return torch._C.wait(future)
|
return torch._C.wait(future)
|
||||||
|
|
||||||
|
|
||||||
|
_register_builtin(wait, "aten::wait")
|
||||||
|
|
|
||||||
101
torch/jit/_freeze.py
Normal file
101
torch/jit/_freeze.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
"""Freezing
|
||||||
|
|
||||||
|
This is not intended to be imported directly; please use the exposed
|
||||||
|
functionalities in `torch.jit`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.jit._script import RecursiveScriptModule, ScriptModule
|
||||||
|
|
||||||
|
|
||||||
|
def freeze(mod, preserved_attrs: Optional[List[str]] = None):
|
||||||
|
r"""
|
||||||
|
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
|
||||||
|
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
|
||||||
|
By default, `forward` will be preserved, as well as attributes & methods specified in
|
||||||
|
`preserved_attrs`. Additionally, any attribute that is modified within a preserved
|
||||||
|
method will be preserved.
|
||||||
|
|
||||||
|
Freezing currently only accepts ScriptModules that are in eval mode.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
mod (:class:`ScriptModule`): a module to be frozen
|
||||||
|
|
||||||
|
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
|
||||||
|
Attributes modified in preserved methods will also be preserved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Frozen :class:`ScriptModule`.
|
||||||
|
|
||||||
|
Example (Freezing a simple module with a Parameter):
|
||||||
|
|
||||||
|
.. testcode::
|
||||||
|
import torch
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self, N, M):
|
||||||
|
super(MyModule, self).__init__()
|
||||||
|
self.weight = torch.nn.Parameter(torch.rand(N, M))
|
||||||
|
self.linear = torch.nn.Linear(N, M)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = self.weight.mm(input)
|
||||||
|
output = self.linear(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
scripted_module = torch.jit.script(MyModule(2, 3).eval())
|
||||||
|
frozen_module = torch.jit.freeze(scripted_module)
|
||||||
|
# parameters have been removed and inlined into the Graph as constants
|
||||||
|
assert len(list(frozen_module.named_parameters())) == 0
|
||||||
|
# See the compiled graph as Python code
|
||||||
|
print(frozen_module.code)
|
||||||
|
|
||||||
|
Example (Freezing a module with preserved attributes)
|
||||||
|
|
||||||
|
.. testcode::
|
||||||
|
import torch
|
||||||
|
class MyModule2(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModule2, self).__init__()
|
||||||
|
self.modified_tensor = torch.tensor(10.)
|
||||||
|
self.version = 1
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
self.modified_tensor += 1
|
||||||
|
return input + self.modified_tensor
|
||||||
|
|
||||||
|
scripted_module = torch.jit.script(MyModule2().eval())
|
||||||
|
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
|
||||||
|
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
|
||||||
|
assert frozen_module.version == 1
|
||||||
|
frozen_module.version = 2
|
||||||
|
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
|
||||||
|
# it to retain model semantics
|
||||||
|
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
|
||||||
|
# now that we've run it once, the next result will be incremented by one
|
||||||
|
assert frozen_module(torch.tensor(1)) == torch.tensor(13)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If you're not sure why an attribute is not being inlined as a constant, you can run
|
||||||
|
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
|
||||||
|
attribute is being modified.
|
||||||
|
"""
|
||||||
|
if not isinstance(mod, ScriptModule):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Freezing expects a ScriptModule as input. "
|
||||||
|
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if mod.training:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Freezing is currently only implemented for modules in eval mode. "
|
||||||
|
"Please call .eval() on your module before freezing."
|
||||||
|
)
|
||||||
|
|
||||||
|
preserved_attrs = preserved_attrs if preserved_attrs is not None else []
|
||||||
|
|
||||||
|
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
|
||||||
|
RecursiveScriptModule._finalize_scriptmodule(out)
|
||||||
|
|
||||||
|
return out
|
||||||
70
torch/jit/_fuser.py
Normal file
70
torch/jit/_fuser.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def optimized_execution(should_optimize):
|
||||||
|
"""
|
||||||
|
A context manager that controls whether the JIT's executor will run
|
||||||
|
optimizations before executing a function.
|
||||||
|
"""
|
||||||
|
stored_flag = torch._C._get_graph_executor_optimize()
|
||||||
|
torch._C._set_graph_executor_optimize(should_optimize)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch._C._set_graph_executor_optimize(stored_flag)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def fuser(name):
|
||||||
|
"""
|
||||||
|
A context manager that facilitates switching between
|
||||||
|
backend fusers.
|
||||||
|
|
||||||
|
Valid names:
|
||||||
|
* ``fuser0`` - enables only legacy fuser
|
||||||
|
* ``fuser1`` - enables only NNC
|
||||||
|
* ``fuser2`` - enables only nvFuser
|
||||||
|
"""
|
||||||
|
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
|
||||||
|
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
|
||||||
|
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
||||||
|
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
|
||||||
|
if name == 'fuser0': # legacy fuser
|
||||||
|
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||||
|
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||||
|
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||||
|
torch._C._jit_set_nvfuser_enabled(False)
|
||||||
|
elif name == 'fuser1': # NNC
|
||||||
|
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
||||||
|
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
|
||||||
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
|
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||||
|
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||||
|
torch._C._jit_set_nvfuser_enabled(False)
|
||||||
|
elif name == 'fuser2': # nvFuser
|
||||||
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
|
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||||
|
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||||
|
torch._C._jit_set_nvfuser_enabled(True)
|
||||||
|
else:
|
||||||
|
raise Exception("unrecognized fuser option")
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if name == 'fuser1': # NNC
|
||||||
|
torch._C._jit_set_profiling_executor(old_profiling_executor)
|
||||||
|
torch._C._jit_set_profiling_mode(old_profiling_mode)
|
||||||
|
# recover the previous values
|
||||||
|
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
|
||||||
|
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
|
||||||
|
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
|
||||||
|
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
|
||||||
|
|
||||||
|
|
||||||
|
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
|
||||||
|
|
||||||
|
|
||||||
|
def _graph_for(self, *args, **kwargs):
|
||||||
|
self(*args, **kwargs)
|
||||||
|
return last_executed_optimized_graph()
|
||||||
|
|
@ -609,7 +609,7 @@ def compile_unbound_method(concrete_type, fn):
|
||||||
if _jit_internal.is_ignored_fn(fn):
|
if _jit_internal.is_ignored_fn(fn):
|
||||||
return None
|
return None
|
||||||
stub = make_stub(fn, fn.__name__)
|
stub = make_stub(fn, fn.__name__)
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
# We don't want to call the hooks here since the graph that is calling
|
# We don't want to call the hooks here since the graph that is calling
|
||||||
# this function is not yet complete
|
# this function is not yet complete
|
||||||
create_methods_from_stubs(concrete_type, (stub,))
|
create_methods_from_stubs(concrete_type, (stub,))
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,15 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._jit_internal as _jit_internal
|
import torch._jit_internal as _jit_internal
|
||||||
|
from torch.utils import set_module
|
||||||
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
|
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.jit._state import _enabled
|
from torch.jit._state import _enabled
|
||||||
|
from torch.jit._builtins import _register_builtin
|
||||||
from torch._six import with_metaclass, get_function_from_type
|
from torch._six import with_metaclass, get_function_from_type
|
||||||
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
|
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
|
||||||
from torch._jit_internal import _qualified_name
|
from torch._jit_internal import _qualified_name
|
||||||
|
from torch.jit._fuser import _graph_for
|
||||||
from torch.jit._state import (
|
from torch.jit._state import (
|
||||||
_try_get_jit_cached_function,
|
_try_get_jit_cached_function,
|
||||||
_try_get_jit_cached_overloads,
|
_try_get_jit_cached_overloads,
|
||||||
|
|
@ -28,6 +31,16 @@ from torch.jit._state import (
|
||||||
_set_jit_overload_cache,
|
_set_jit_overload_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch._C.ScriptMethod.graph_for = _graph_for
|
||||||
|
torch._C.ScriptFunction.graph_for = _graph_for
|
||||||
|
ScriptFunction = torch._C.ScriptFunction
|
||||||
|
ScriptFunction.__doc__ = """
|
||||||
|
Functionally equivalent to a :class:`ScriptModule`, but represents a single
|
||||||
|
function and does not have any attributes or Parameters.
|
||||||
|
"""
|
||||||
|
set_module(ScriptFunction, "torch.jit")
|
||||||
|
|
||||||
|
|
||||||
if _enabled:
|
if _enabled:
|
||||||
Attribute = collections.namedtuple("Attribute", ["value", "type"])
|
Attribute = collections.namedtuple("Attribute", ["value", "type"])
|
||||||
else:
|
else:
|
||||||
|
|
@ -1053,3 +1066,32 @@ def _recursive_compile_class(obj, loc):
|
||||||
error_stack = torch._C.CallStack(_qual_name, loc)
|
error_stack = torch._C.CallStack(_qual_name, loc)
|
||||||
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
|
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
|
||||||
_compile_and_register_class(obj, rcb, _qual_name)
|
_compile_and_register_class(obj, rcb, _qual_name)
|
||||||
|
|
||||||
|
|
||||||
|
_register_builtin(is_scripting, "aten::is_scripting")
|
||||||
|
|
||||||
|
|
||||||
|
class CompilationUnit(object):
|
||||||
|
def __init__(self, lang=None, _frames_up=0):
|
||||||
|
self._c = torch._C.CompilationUnit()
|
||||||
|
if lang is not None:
|
||||||
|
self.define(lang, _frames_up=_frames_up + 1)
|
||||||
|
|
||||||
|
def define(self, lang, rcb=None, _frames_up=0):
|
||||||
|
if not rcb:
|
||||||
|
rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
|
||||||
|
self._c.define(lang, rcb)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
r = self._c.find_function(attr)
|
||||||
|
if r is None:
|
||||||
|
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_optional(x):
|
||||||
|
assert x is not None, "Unwrapping null optional"
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
_register_builtin(_unwrap_optional, "aten::_unwrap_optional")
|
||||||
|
|
|
||||||
|
|
@ -11,12 +11,13 @@ import torch
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
import warnings
|
import warnings
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from torch.jit._state import _python_cu, _enabled
|
from torch.jit._state import _python_cu, _enabled
|
||||||
from torch.jit._script import ScriptModule, _CachedForward
|
from torch.jit._script import ScriptModule, _CachedForward, script
|
||||||
from torch._jit_internal import _qualified_name
|
from torch._jit_internal import _qualified_name
|
||||||
from torch.autograd import function
|
from torch.autograd import function
|
||||||
from torch import _jit_internal
|
from torch import _jit_internal
|
||||||
|
|
@ -1077,3 +1078,70 @@ class TopLevelTracedModule(TracedModule):
|
||||||
cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
|
cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
|
||||||
"""
|
"""
|
||||||
self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
|
self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
|
||||||
|
|
||||||
|
|
||||||
|
def _script_if_tracing(fn):
|
||||||
|
"""
|
||||||
|
Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
|
||||||
|
has a non-negligible start up time when it is first called due to
|
||||||
|
lazy-initializations of many compiler builtins. Therefore you should not use
|
||||||
|
it in library code. However, you may want to have parts of your library work
|
||||||
|
in tracing even if they use control flow. In these cases, you should use
|
||||||
|
``@torch.jit._script_if_tracing`` to substitute for
|
||||||
|
``torch.jit.script``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not is_tracing():
|
||||||
|
# Not tracing, don't do anything
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
compiled_fn = script(wrapper.__original_fn)
|
||||||
|
return compiled_fn(*args, **kwargs)
|
||||||
|
|
||||||
|
wrapper.__original_fn = fn
|
||||||
|
wrapper.__script_if_tracing_wrapper = True
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
|
||||||
|
return_inputs=False, _return_inputs_states=False):
|
||||||
|
"""
|
||||||
|
.. warning::
|
||||||
|
This function is internal-only and should only be used by the ONNX
|
||||||
|
exporter. If you are trying to get a graph through tracing, please go
|
||||||
|
through the public API instead::
|
||||||
|
|
||||||
|
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
|
||||||
|
trace_graph = trace.graph
|
||||||
|
|
||||||
|
Trace a function or model, returning a tuple consisting of the both the
|
||||||
|
*trace* of an execution, as well as the original return value. If return_inputs,
|
||||||
|
also returns the trace inputs as part of the tuple
|
||||||
|
|
||||||
|
Tracing is guaranteed not to change the semantics of the function/module
|
||||||
|
that is traced.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
f (torch.nn.Module or function): the function or module
|
||||||
|
to be traced.
|
||||||
|
args (tuple or Tensor): the positional arguments to pass to the
|
||||||
|
function/module to be traced. A non-tuple is assumed to
|
||||||
|
be a single positional argument to be passed to the model.
|
||||||
|
kwargs (dict): the keyword arguments to pass to the function/module
|
||||||
|
to be traced.
|
||||||
|
|
||||||
|
Example (trace a cell):
|
||||||
|
|
||||||
|
.. testcode::
|
||||||
|
|
||||||
|
trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
|
||||||
|
"""
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
if not isinstance(args, tuple):
|
||||||
|
args = (args,)
|
||||||
|
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
|
||||||
|
return outs
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import torch.jit
|
import torch.jit
|
||||||
|
from torch.jit._builtins import _find_builtin
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
# this file is for generating documentation using sphinx autodoc
|
# this file is for generating documentation using sphinx autodoc
|
||||||
|
|
@ -92,7 +93,7 @@ def _get_nn_functional_ops():
|
||||||
for mod in torch.jit._builtins._modules_containing_builtins:
|
for mod in torch.jit._builtins._modules_containing_builtins:
|
||||||
name = mod.__name__
|
name = mod.__name__
|
||||||
for elem in dir(mod):
|
for elem in dir(mod):
|
||||||
builtin = torch.jit._find_builtin(getattr(mod, elem))
|
builtin = _find_builtin(getattr(mod, elem))
|
||||||
if builtin is not None:
|
if builtin is not None:
|
||||||
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
|
|
@ -133,7 +134,7 @@ def _get_torchscript_builtins():
|
||||||
# Iterate over the specially added builtins
|
# Iterate over the specially added builtins
|
||||||
for fn, _builtin_name in builtins:
|
for fn, _builtin_name in builtins:
|
||||||
mod = inspect.getmodule(fn)
|
mod = inspect.getmodule(fn)
|
||||||
builtin = torch.jit._find_builtin(fn)
|
builtin = _find_builtin(fn)
|
||||||
if builtin is not None:
|
if builtin is not None:
|
||||||
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
|
|
@ -150,7 +151,7 @@ def _get_math_builtins():
|
||||||
# Iterate over the specially added builtins
|
# Iterate over the specially added builtins
|
||||||
for fn, _builtin_name in builtins:
|
for fn, _builtin_name in builtins:
|
||||||
mod = inspect.getmodule(fn)
|
mod = inspect.getmodule(fn)
|
||||||
builtin = torch.jit._find_builtin(fn)
|
builtin = _find_builtin(fn)
|
||||||
if builtin is not None:
|
if builtin is not None:
|
||||||
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
|
|
|
||||||
|
|
@ -1222,7 +1222,7 @@ class RpcTest(RpcAgentTestFixture):
|
||||||
events = prof.function_events
|
events = prof.function_events
|
||||||
|
|
||||||
rpc_mul_event = get_function_event(
|
rpc_mul_event = get_function_event(
|
||||||
events, torch.jit._find_builtin(torch.mul)
|
events, torch.jit._builtins._find_builtin(torch.mul)
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_events = {
|
remote_events = {
|
||||||
|
|
|
||||||
|
|
@ -358,7 +358,7 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
|
||||||
|
|
||||||
f_args_variable = (self_variable,) + args_variable
|
f_args_variable = (self_variable,) + args_variable
|
||||||
f_args_tensor = (self_tensor,) + args_tensor
|
f_args_tensor = (self_tensor,) + args_tensor
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
|
script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
|
||||||
return script_fn, inputs
|
return script_fn, inputs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ class JitTestCase(TestCase):
|
||||||
return code_files, debug_files
|
return code_files, debug_files
|
||||||
|
|
||||||
# disable the hook while we parse code, otherwise we will re-enter the hook
|
# disable the hook while we parse code, otherwise we will re-enter the hook
|
||||||
with torch.jit._disable_emit_hooks():
|
with torch._jit_internal._disable_emit_hooks():
|
||||||
try:
|
try:
|
||||||
# short-circuit if this is an empty function or module
|
# short-circuit if this is an empty function or module
|
||||||
if len(m.code) == 0:
|
if len(m.code) == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user