diff --git a/test/jit/test_models.py b/test/jit/test_models.py index 14d9efc1d2e..b470ad95137 100644 --- a/test/jit/test_models.py +++ b/test/jit/test_models.py @@ -487,7 +487,7 @@ class TestModels(JitTestCase): return self.seq.forward(input) # disabled due to a jitter issues that will be fixed by using load/store in the compiler - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): # TODO: toggle export_import once above issues are fixed self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False) diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index ab02e875478..db25bcc3a22 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -1817,7 +1817,7 @@ class TestDeprecatedJitQuantized(JitTestCase): def weight(self, w): self._packed_weight = torch.ops.quantized.linear_prepack(w) - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): x = torch.jit.script(Linear(10, 10)) torch._C._jit_pass_erase_shape_information(x.graph) diff --git a/test/test_jit.py b/test/test_jit.py index bd2a058b1b5..151bb857aaa 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2159,7 +2159,7 @@ graph(%Ra, %Rb): self.assertExpected(cu.foo.code) def test_import_method(self): - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): class Foo(torch.jit.ScriptModule): def __init__(self): super(Foo, self).__init__() @@ -3596,7 +3596,7 @@ def foo(x): mod.ninf = float("-inf") mod.nan = float("nan") - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): class Foo(torch.jit.ScriptModule): def __init__(self): super(Foo, self).__init__() @@ -9122,7 +9122,7 @@ a") x[seq_lens[b]:, b, :] = 0 eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens) - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script) script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens) self.assertEqual(eager_seq, script_seq) @@ -9145,7 +9145,7 @@ a") lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2) - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): self.checkModule(lstm, [torch.ones(2, 2)]) def test_script_pad_sequence_pack_sequence(self): @@ -9165,7 +9165,7 @@ a") tensor1 = torch.tensor([1, 2, 3]) tensor2 = torch.tensor([4, 5]) tensor3 = torch.tensor([6]) - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): self.checkScript(pad_sequence_func, ([ones3, ones4, ones5],)) self.checkScript(pad_sequence_func, @@ -9361,7 +9361,7 @@ a") def test_tuples(self): # TODO: jitter issue. 
- with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list + with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list def foo(i): a = (i + 4, i * 2) c = a @@ -12613,7 +12613,7 @@ a") self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) def test_bool_dispatch(self): - with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list + with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list def kwarg_false(x): # type: (Tensor) -> Tensor return F.max_pool1d(x, 1, 1, return_indices=False) @@ -14237,7 +14237,7 @@ a") # type: (str) -> Tensor return self.table[key] + self.x - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): # TODO: re-enable module hook when Python printing of attributes is # supported m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) @@ -15393,7 +15393,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(), self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes) if test_name in EXCLUDE_PYTHON_PRINT: - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): run_test() else: run_test() diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 0a2589e553c..45023a78d9a 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -474,7 +474,7 @@ class TestFuser(JitTestCase): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") - @torch.jit._disable_emit_hooks_decorator + @torch._jit_internal._disable_emit_hooks_decorator @_inline_everything def test_fuse_decompose_normalization(self): class ResLike(torch.jit.ScriptModule): diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c9810169cb8..39c431cc91b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -507,7 +507,7 @@ class TestFuser(JitTestCase): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") - @torch.jit._disable_emit_hooks_decorator + @torch._jit_internal._disable_emit_hooks_decorator @_inline_everything def test_fuse_decompose_normalization(self): class ResLike(torch.jit.ScriptModule): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index d7cfde3e730..4490df49976 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -4,6 +4,8 @@ can be used in other places in torch/ (namely torch.nn) without running into circular dependency problems """ +import contextlib +import collections import inspect import weakref import warnings @@ -767,3 +769,45 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): def fake_range(): return SourceContext('', None, 0, 0).make_raw_range(0, 1) + + +def _try_get_dispatched_fn(fn): + if not callable(fn): + return None + return boolean_dispatched.get(fn) + + +def _get_named_tuple_properties(obj): + assert issubclass(obj, tuple) and hasattr(obj, '_fields') + fields = list(obj._fields) + annotations = [] + has_annotations = hasattr(obj, '__annotations__') + for field in fields: + if has_annotations and field in obj.__annotations__: + the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) + annotations.append(the_type) + else: + annotations.append(torch._C.TensorType.get()) + return type(obj).__name__, fields, annotations + + +def 
_create_named_tuple(t, unqual_name, field_names):
+    TupleType = collections.namedtuple(unqual_name, field_names)
+    return TupleType(*t)
+
+
+@contextlib.contextmanager
+def _disable_emit_hooks():
+    hooks = torch._C._jit_get_emit_hooks()
+    torch._C._jit_set_emit_hooks(None, None)
+    yield
+    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+
+
+def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
diff --git a/torch/_ops.py b/torch/_ops.py
index 0161a65eb52..70edc282202 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -61,7 +61,7 @@ class _OpNamespace(types.ModuleType):
         op = torch._C._jit_get_operation(qualified_op_name)
         # let the script frontend know that op is identical to the builtin op
         # with qualified_op_name
-        torch.jit._register_builtin(op, qualified_op_name)
+        torch.jit._builtins._register_builtin(op, qualified_op_name)
         setattr(self, op_name, op)
         op.__module__ = self.__module__ + "." + self.name
         return op
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 4302868ed07..5cbf14cef89 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -839,7 +839,7 @@ inline py::object toPyObject(IValue ivalue) {
       auto fieldNames = fmap(
           tuple->type()->schema()->arguments(),
           [](const Argument& arg) { return arg.name(); });
-      return py::module::import("torch.jit")
+      return py::module::import("torch._jit_internal")
           .attr("_create_named_tuple")(t, unqualName, fieldNames);
     } else {
       return std::move(t);
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 1cb42833e70..5e72fc04bb5 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -686,8 +686,8 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
     }
   }
 
-  py::object props =
-      py::module::import("torch.jit").attr("_get_named_tuple_properties")(obj);
+  py::object props = py::module::import("torch._jit_internal")
+                         .attr("_get_named_tuple_properties")(obj);
   std::string unqualName;
   std::vector<std::string> fields;
   std::vector<TypePtr> annotations;
@@ -788,7 +788,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
  }
   py::object builtin_name =
-      py::module::import("torch.jit").attr("_find_builtin")(obj);
+      py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
   if (!builtin_name.is_none()) {
     return std::make_shared<BuiltinFunction>(
         Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
   }
@@ -801,8 +801,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }
 
-  py::object dispatched_fn =
-      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
+  py::object dispatched_fn = py::module::import("torch._jit_internal")
+                                 .attr("_try_get_dispatched_fn")(obj);
   if (!dispatched_fn.is_none()) {
     return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
   }
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index d3458d27588..bd2c73b367f 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -160,7 +160,7 @@ def _wait_all_workers():
     is_leader_worker = leader_worker_name == self_worker_name
 
     # Set a long enough timeout for all shutdown messages to be processed.
-    timeout = 5 # seconds
+    timeout = 5  # seconds
 
     # Phase 1: Followers send intents.
     # All followers report intents to the leader.
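
Note: the emit-hook helpers moved into `torch._jit_internal` above keep their old semantics; only the import path changes, which is what the test-file hunks in this patch update. A minimal usage sketch under that assumption (`Adder` is an illustrative stand-in, not part of this patch):

    import torch
    import torch._jit_internal

    class Adder(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # Emit hooks (used by test harnesses to intercept compiled code)
    # are suspended inside the block and restored afterwards.
    with torch._jit_internal._disable_emit_hooks():
        scripted = torch.jit.script(Adder())
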
@@ -522,7 +522,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): >>> rpc.init_rpc("worker1", rank=1, world_size=2) >>> rpc.shutdown() """ - qualified_name = torch.jit._find_builtin(func) + qualified_name = torch.jit._builtins._find_builtin(func) dst_worker_info = _to_worker_info(to) should_profile = torch.autograd._profiler_enabled() @@ -594,7 +594,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP if not callable(func): raise TypeError("function should be callable.") - qualified_name = torch.jit._find_builtin(func) + qualified_name = torch.jit._builtins._find_builtin(func) dst_worker_info = _to_worker_info(to) # TODO: profiling logic does not really belong in invoke_rpc, it should be diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index c1bdd45a92c..ec926e1afb1 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1,93 +1,52 @@ import torch._C -import torch._jit_internal as _jit_internal -from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin # noqa -from torch._jit_internal import Future -from torch.nn import Module from torch.utils import set_module -from torch.autograd.grad_mode import _DecoratorContextManager -from typing import Optional, List - -import collections -import contextlib -import functools -import os -import pathlib # These are imported so users can access them from the `torch.jit` module -from torch._jit_internal import Final, _overload, _overload_method -from torch._jit_internal import ignore, export, unused -from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \ - RecursiveScriptModule, ScriptWarning, interface -from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \ - is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule +from torch._jit_internal import ( + Final, + Future, + _overload, + _overload_method, + ignore, + export, + unused, +) +from torch.jit._script import ( + script, + Attribute, + ScriptModule, + is_scripting, + script_method, + RecursiveScriptModule, + ScriptWarning, + interface, + CompilationUnit, + ScriptFunction, + _unwrap_optional, +) +from torch.jit._trace import ( + trace, + trace_module, + TracedModule, + TracerWarning, + TracingCheckError, + is_tracing, + ONNXTracedModule, + TopLevelTracedModule, + _unique_state_dict, + _flatten, + _script_if_tracing, + _get_trace_graph, +) from torch.jit._async import fork, wait from torch.jit._serialization import save, load - -set_module(Future, "torch.jit") +from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph # For backwards compatibility _fork = fork _wait = wait -@contextlib.contextmanager -def optimized_execution(should_optimize): - """ - A context manager that controls whether the JIT's executor will run - optimizations before executing a function. - """ - stored_flag = torch._C._get_graph_executor_optimize() - torch._C._set_graph_executor_optimize(should_optimize) - try: - yield - finally: - torch._C._set_graph_executor_optimize(stored_flag) - -@contextlib.contextmanager -def fuser(name): - """ - A context manager that facilitates switching between - backend fusers. 
- - Valid names: - * ``fuser0`` - enables only legacy fuser - * ``fuser1`` - enables only NNC - * ``fuser2`` - enables only nvFuser - """ - old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() - old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() - old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() - old_nvfuser_state = torch._C._jit_nvfuser_enabled() - if name == 'fuser0': # legacy fuser - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(False) - elif name == 'fuser1': # NNC - old_profiling_executor = torch._C._jit_set_profiling_executor(True) - old_profiling_mode = torch._C._jit_set_profiling_mode(True) - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(True) - torch._C._jit_set_nvfuser_enabled(False) - elif name == 'fuser2': # nvFuser - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(True) - else: - raise Exception("unrecognized fuser option") - try: - yield - finally: - if name == 'fuser1': # NNC - torch._C._jit_set_profiling_executor(old_profiling_executor) - torch._C._jit_set_profiling_mode(old_profiling_mode) - # recover the previous values - torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse) - torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse) - torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state) - torch._C._jit_set_nvfuser_enabled(old_nvfuser_state) def export_opnames(m): r""" @@ -95,212 +54,6 @@ def export_opnames(m): """ return torch._C._export_opnames(m._c) -def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, - return_inputs=False, _return_inputs_states=False): - """ - .. warning:: - This function is internal-only and should only be used by the ONNX - exporter. If you are trying to get a graph through tracing, please go - through the public API instead:: - - trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) - trace_graph = trace.graph - - Trace a function or model, returning a tuple consisting of the both the - *trace* of an execution, as well as the original return value. If return_inputs, - also returns the trace inputs as part of the tuple - - Tracing is guaranteed not to change the semantics of the function/module - that is traced. - - Arguments: - f (torch.nn.Module or function): the function or module - to be traced. - args (tuple or Tensor): the positional arguments to pass to the - function/module to be traced. A non-tuple is assumed to - be a single positional argument to be passed to the model. - kwargs (dict): the keyword arguments to pass to the function/module - to be traced. - - Example (trace a cell): - - .. testcode:: - - trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) - """ - if kwargs is None: - kwargs = {} - if not isinstance(args, tuple): - args = (args,) - outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) - return outs - - -def freeze(mod, preserved_attrs : Optional[List[str]] = None): - r""" - Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned - module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. - By default, `forward` will be preserved, as well as attributes & methods specified in - `preserved_attrs`. 
Additionally, any attribute that is modified within a preserved - method will be preserved. - - Freezing currently only accepts ScriptModules that are in eval mode. - - Arguments: - mod (:class:`ScriptModule`): a module to be frozen - - preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. - Attributes modified in preserved methods will also be preserved. - - Returns: - Frozen :class:`ScriptModule`. - - Example (Freezing a simple module with a Parameter): - - .. testcode:: - import torch - class MyModule(torch.nn.Module): - def __init__(self, N, M): - super(MyModule, self).__init__() - self.weight = torch.nn.Parameter(torch.rand(N, M)) - self.linear = torch.nn.Linear(N, M) - - def forward(self, input): - output = self.weight.mm(input) - output = self.linear(output) - return output - - scripted_module = torch.jit.script(MyModule(2, 3).eval()) - frozen_module = torch.jit.freeze(scripted_module) - # parameters have been removed and inlined into the Graph as constants - assert len(list(frozen_module.named_parameters())) == 0 - # See the compiled graph as Python code - print(frozen_module.code) - - Example (Freezing a module with preserved attributes) - - .. testcode:: - import torch - class MyModule2(torch.nn.Module): - def __init__(self): - super(MyModule2, self).__init__() - self.modified_tensor = torch.tensor(10.) - self.version = 1 - - def forward(self, input): - self.modified_tensor += 1 - return input + self.modified_tensor - - scripted_module = torch.jit.script(MyModule2().eval()) - frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) - # we've manually preserved `version`, so it still exists on the frozen module and can be modified - assert frozen_module.version == 1 - frozen_module.version = 2 - # `modified_tensor` is detected as being mutated in the forward, so freezing preserves - # it to retain model semantics - assert frozen_module(torch.tensor(1)) == torch.tensor(12) - # now that we've run it once, the next result will be incremented by one - assert frozen_module(torch.tensor(1)) == torch.tensor(13) - - Note: - If you're not sure why an attribute is not being inlined as a constant, you can run - `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the - attribute is being modified. - """ - if not isinstance(mod, ScriptModule): - raise RuntimeError("Freezing expects a ScriptModule as input. " - "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.") - - if mod.training: - raise RuntimeError("Freezing is currently only implemented for modules in eval mode. 
" - "Please call .eval() on your module before freezing.") - - preserved_attrs = preserved_attrs if preserved_attrs is not None else [] - - out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs)) - RecursiveScriptModule._finalize_scriptmodule(out) - - return out - - -class CompilationUnit(object): - def __init__(self, lang=None, _frames_up=0): - self._c = torch._C.CompilationUnit() - if lang is not None: - self.define(lang, _frames_up=_frames_up + 1) - - def define(self, lang, rcb=None, _frames_up=0): - if not rcb: - rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) - self._c.define(lang, rcb) - - def __getattr__(self, attr): - r = self._c.find_function(attr) - if r is None: - raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr)) - return r - - -def _try_get_dispatched_fn(fn): - if not callable(fn): - return None - return _jit_internal.boolean_dispatched.get(fn) - - -def _try_get_overloaded_fn(mod, field): - return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None - - -@contextlib.contextmanager -def _disable_emit_hooks(): - hooks = torch._C._jit_get_emit_hooks() - torch._C._jit_set_emit_hooks(None, None) - yield - torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) - - -def _disable_emit_hooks_decorator(_DecoratorContextManager): # noqa: F811 - def __enter__(self): - self.hooks = torch._C._jit_get_emit_hooks() - torch._C._jit_set_emit_hooks(None, None) - - def __exit__(self, *args): - torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) - - -def _script_if_tracing(fn): - """ - Compiles ``fn`` when it is first called during tracing. ``torch.jit.script`` - has a non-negligible start up time when it is first called due to - lazy-initializations of many compiler builtins. Therefore you should not use - it in library code. However, you may want to have parts of your library work - in tracing even if they use control flow. In these cases, you should use - ``@torch.jit._script_if_tracing`` to substitute for - ``torch.jit.script``. 
- """ - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if not is_tracing(): - # Not tracing, don't do anything - return fn(*args, **kwargs) - - compiled_fn = script(wrapper.__original_fn) - return compiled_fn(*args, **kwargs) - - wrapper.__original_fn = fn - wrapper.__script_if_tracing_wrapper = True - - return wrapper - -def _unwrap_optional(x): - assert x is not None, "Unwrapping null optional" - return x - -_register_builtin(_unwrap_optional, 'aten::_unwrap_optional') -_register_builtin(_wait, 'aten::wait') -_register_builtin(wait, 'aten::wait') -_register_builtin(is_scripting, 'aten::is_scripting') - # torch.jit.Error Error = torch._C.JITException @@ -309,53 +62,11 @@ set_module(Error, "torch.jit") Error.__name__ = "Error" Error.__qualname__ = "Error" -def _get_named_tuple_properties(obj): - assert issubclass(obj, tuple) and hasattr(obj, '_fields') - fields = list(obj._fields) - annotations = [] - has_annotations = hasattr(obj, '__annotations__') - for field in fields: - if has_annotations and field in obj.__annotations__: - the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range()) - annotations.append(the_type) - else: - annotations.append(torch._C.TensorType.get()) - return type(obj).__name__, fields, annotations - -def _create_named_tuple(t, unqual_name, field_names): - TupleType = collections.namedtuple(unqual_name, field_names) - return TupleType(*t) - -class _disable_tracing(object): - def __enter__(self): - self.state = torch._C._get_tracing_state() - torch._C._set_tracing_state(None) - - def __exit__(self, *args): - torch._C._set_tracing_state(self.state) - self.state = None - - # for use in python if using annotate def annotate(the_type, the_value): # noop in python return the_value -last_executed_optimized_graph = torch._C._last_executed_optimized_graph - - -def _graph_for(self, *args, **kwargs): - self(*args, **kwargs) - return last_executed_optimized_graph() - -torch._C.ScriptMethod.graph_for = _graph_for -torch._C.ScriptFunction.graph_for = _graph_for -ScriptFunction = torch._C.ScriptFunction -ScriptFunction.__doc__ = """ -Functionally equivalent to a :class:`ScriptModule`, but represents a single -function and does not have any attributes or Parameters. -""" -set_module(ScriptFunction, "torch.jit") if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/jit/_async.py b/torch/jit/_async.py index df69db0928b..5e67167bd41 100644 --- a/torch/jit/_async.py +++ b/torch/jit/_async.py @@ -9,6 +9,12 @@ functionalities in `torch.jit`. import torch +from torch.utils import set_module +from torch.jit._builtins import _register_builtin +from torch._jit_internal import Future + +set_module(Future, "torch.jit") + def fork(func, *args, **kwargs): """ @@ -84,3 +90,6 @@ def wait(future): `T`: the return value of the the completed task """ return torch._C.wait(future) + + +_register_builtin(wait, "aten::wait") diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py new file mode 100644 index 00000000000..5c217ea17c1 --- /dev/null +++ b/torch/jit/_freeze.py @@ -0,0 +1,101 @@ +"""Freezing + +This is not intended to be imported directly; please use the exposed +functionalities in `torch.jit`. 
+""" + +from typing import Optional, List + +import torch +from torch.jit._script import RecursiveScriptModule, ScriptModule + + +def freeze(mod, preserved_attrs: Optional[List[str]] = None): + r""" + Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned + module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. + By default, `forward` will be preserved, as well as attributes & methods specified in + `preserved_attrs`. Additionally, any attribute that is modified within a preserved + method will be preserved. + + Freezing currently only accepts ScriptModules that are in eval mode. + + Arguments: + mod (:class:`ScriptModule`): a module to be frozen + + preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. + Attributes modified in preserved methods will also be preserved. + + Returns: + Frozen :class:`ScriptModule`. + + Example (Freezing a simple module with a Parameter): + + .. testcode:: + import torch + class MyModule(torch.nn.Module): + def __init__(self, N, M): + super(MyModule, self).__init__() + self.weight = torch.nn.Parameter(torch.rand(N, M)) + self.linear = torch.nn.Linear(N, M) + + def forward(self, input): + output = self.weight.mm(input) + output = self.linear(output) + return output + + scripted_module = torch.jit.script(MyModule(2, 3).eval()) + frozen_module = torch.jit.freeze(scripted_module) + # parameters have been removed and inlined into the Graph as constants + assert len(list(frozen_module.named_parameters())) == 0 + # See the compiled graph as Python code + print(frozen_module.code) + + Example (Freezing a module with preserved attributes) + + .. testcode:: + import torch + class MyModule2(torch.nn.Module): + def __init__(self): + super(MyModule2, self).__init__() + self.modified_tensor = torch.tensor(10.) + self.version = 1 + + def forward(self, input): + self.modified_tensor += 1 + return input + self.modified_tensor + + scripted_module = torch.jit.script(MyModule2().eval()) + frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) + # we've manually preserved `version`, so it still exists on the frozen module and can be modified + assert frozen_module.version == 1 + frozen_module.version = 2 + # `modified_tensor` is detected as being mutated in the forward, so freezing preserves + # it to retain model semantics + assert frozen_module(torch.tensor(1)) == torch.tensor(12) + # now that we've run it once, the next result will be incremented by one + assert frozen_module(torch.tensor(1)) == torch.tensor(13) + + Note: + If you're not sure why an attribute is not being inlined as a constant, you can run + `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the + attribute is being modified. + """ + if not isinstance(mod, ScriptModule): + raise RuntimeError( + "Freezing expects a ScriptModule as input. " + "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'." + ) + + if mod.training: + raise RuntimeError( + "Freezing is currently only implemented for modules in eval mode. " + "Please call .eval() on your module before freezing." 
+ ) + + preserved_attrs = preserved_attrs if preserved_attrs is not None else [] + + out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs)) + RecursiveScriptModule._finalize_scriptmodule(out) + + return out diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py new file mode 100644 index 00000000000..5de317cd935 --- /dev/null +++ b/torch/jit/_fuser.py @@ -0,0 +1,70 @@ +import contextlib + +import torch + +@contextlib.contextmanager +def optimized_execution(should_optimize): + """ + A context manager that controls whether the JIT's executor will run + optimizations before executing a function. + """ + stored_flag = torch._C._get_graph_executor_optimize() + torch._C._set_graph_executor_optimize(should_optimize) + try: + yield + finally: + torch._C._set_graph_executor_optimize(stored_flag) + +@contextlib.contextmanager +def fuser(name): + """ + A context manager that facilitates switching between + backend fusers. + + Valid names: + * ``fuser0`` - enables only legacy fuser + * ``fuser1`` - enables only NNC + * ``fuser2`` - enables only nvFuser + """ + old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() + old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() + old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() + old_nvfuser_state = torch._C._jit_nvfuser_enabled() + if name == 'fuser0': # legacy fuser + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) + elif name == 'fuser1': # NNC + old_profiling_executor = torch._C._jit_set_profiling_executor(True) + old_profiling_mode = torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(True) + torch._C._jit_set_nvfuser_enabled(False) + elif name == 'fuser2': # nvFuser + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + else: + raise Exception("unrecognized fuser option") + try: + yield + finally: + if name == 'fuser1': # NNC + torch._C._jit_set_profiling_executor(old_profiling_executor) + torch._C._jit_set_profiling_mode(old_profiling_mode) + # recover the previous values + torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse) + torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse) + torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state) + torch._C._jit_set_nvfuser_enabled(old_nvfuser_state) + + +last_executed_optimized_graph = torch._C._last_executed_optimized_graph + + +def _graph_for(self, *args, **kwargs): + self(*args, **kwargs) + return last_executed_optimized_graph() diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d7f7be1f89a..0adface922f 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -609,7 +609,7 @@ def compile_unbound_method(concrete_type, fn): if _jit_internal.is_ignored_fn(fn): return None stub = make_stub(fn, fn.__name__) - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): # We don't want to call the hooks here since the graph that is calling # this function is not yet complete create_methods_from_stubs(concrete_type, (stub,)) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index f1d8f7440d2..1c14379c305 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -15,12 +15,15 @@ import warnings import torch import torch._jit_internal as 
_jit_internal +from torch.utils import set_module from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module from torch.nn import Module from torch.jit._state import _enabled +from torch.jit._builtins import _register_builtin from torch._six import with_metaclass, get_function_from_type from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def from torch._jit_internal import _qualified_name +from torch.jit._fuser import _graph_for from torch.jit._state import ( _try_get_jit_cached_function, _try_get_jit_cached_overloads, @@ -28,6 +31,16 @@ from torch.jit._state import ( _set_jit_overload_cache, ) +torch._C.ScriptMethod.graph_for = _graph_for +torch._C.ScriptFunction.graph_for = _graph_for +ScriptFunction = torch._C.ScriptFunction +ScriptFunction.__doc__ = """ +Functionally equivalent to a :class:`ScriptModule`, but represents a single +function and does not have any attributes or Parameters. +""" +set_module(ScriptFunction, "torch.jit") + + if _enabled: Attribute = collections.namedtuple("Attribute", ["value", "type"]) else: @@ -1053,3 +1066,32 @@ def _recursive_compile_class(obj, loc): error_stack = torch._C.CallStack(_qual_name, loc) rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) _compile_and_register_class(obj, rcb, _qual_name) + + +_register_builtin(is_scripting, "aten::is_scripting") + + +class CompilationUnit(object): + def __init__(self, lang=None, _frames_up=0): + self._c = torch._C.CompilationUnit() + if lang is not None: + self.define(lang, _frames_up=_frames_up + 1) + + def define(self, lang, rcb=None, _frames_up=0): + if not rcb: + rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) + self._c.define(lang, rcb) + + def __getattr__(self, attr): + r = self._c.find_function(attr) + if r is None: + raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr)) + return r + + +def _unwrap_optional(x): + assert x is not None, "Unwrapping null optional" + return x + + +_register_builtin(_unwrap_optional, "aten::_unwrap_optional") diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 6a18ebcd589..3b5a651a87c 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -11,12 +11,13 @@ import torch import os import contextlib +import functools import warnings import inspect import re from torch.jit._state import _python_cu, _enabled -from torch.jit._script import ScriptModule, _CachedForward +from torch.jit._script import ScriptModule, _CachedForward, script from torch._jit_internal import _qualified_name from torch.autograd import function from torch import _jit_internal @@ -1077,3 +1078,70 @@ class TopLevelTracedModule(TracedModule): cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. """ self.__dict__["_actual_script_module"]._reconstruct(cpp_module) + + +def _script_if_tracing(fn): + """ + Compiles ``fn`` when it is first called during tracing. ``torch.jit.script`` + has a non-negligible start up time when it is first called due to + lazy-initializations of many compiler builtins. Therefore you should not use + it in library code. However, you may want to have parts of your library work + in tracing even if they use control flow. In these cases, you should use + ``@torch.jit._script_if_tracing`` to substitute for + ``torch.jit.script``. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not is_tracing(): + # Not tracing, don't do anything + return fn(*args, **kwargs) + + compiled_fn = script(wrapper.__original_fn) + return compiled_fn(*args, **kwargs) + + wrapper.__original_fn = fn + wrapper.__script_if_tracing_wrapper = True + + return wrapper + + +def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, + return_inputs=False, _return_inputs_states=False): + """ + .. warning:: + This function is internal-only and should only be used by the ONNX + exporter. If you are trying to get a graph through tracing, please go + through the public API instead:: + + trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) + trace_graph = trace.graph + + Trace a function or model, returning a tuple consisting of the both the + *trace* of an execution, as well as the original return value. If return_inputs, + also returns the trace inputs as part of the tuple + + Tracing is guaranteed not to change the semantics of the function/module + that is traced. + + Arguments: + f (torch.nn.Module or function): the function or module + to be traced. + args (tuple or Tensor): the positional arguments to pass to the + function/module to be traced. A non-tuple is assumed to + be a single positional argument to be passed to the model. + kwargs (dict): the keyword arguments to pass to the function/module + to be traced. + + Example (trace a cell): + + .. testcode:: + + trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) + """ + if kwargs is None: + kwargs = {} + if not isinstance(args, tuple): + args = (args,) + outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + return outs diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 4ca3079ec9b..da79ffb4164 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -1,4 +1,5 @@ import torch.jit +from torch.jit._builtins import _find_builtin import inspect import textwrap # this file is for generating documentation using sphinx autodoc @@ -92,7 +93,7 @@ def _get_nn_functional_ops(): for mod in torch.jit._builtins._modules_containing_builtins: name = mod.__name__ for elem in dir(mod): - builtin = torch.jit._find_builtin(getattr(mod, elem)) + builtin = _find_builtin(getattr(mod, elem)) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) for schema in schemas: @@ -133,7 +134,7 @@ def _get_torchscript_builtins(): # Iterate over the specially added builtins for fn, _builtin_name in builtins: mod = inspect.getmodule(fn) - builtin = torch.jit._find_builtin(fn) + builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) for schema in schemas: @@ -150,7 +151,7 @@ def _get_math_builtins(): # Iterate over the specially added builtins for fn, _builtin_name in builtins: mod = inspect.getmodule(fn) - builtin = torch.jit._find_builtin(fn) + builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) for schema in schemas: diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index db8ae4960d0..cb6f53b50d5 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -1222,7 +1222,7 @@ class RpcTest(RpcAgentTestFixture): events = prof.function_events rpc_mul_event = get_function_event( - events, 
torch.jit._find_builtin(torch.mul) + events, torch.jit._builtins._find_builtin(torch.mul) ) remote_events = { diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index c1728f36cd0..08b4081819e 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -358,7 +358,7 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name f_args_variable = (self_variable,) + args_variable f_args_tensor = (self_tensor,) + args_tensor - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable) return script_fn, inputs diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 86c75bdd7fd..2f0f530891f 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -159,7 +159,7 @@ class JitTestCase(TestCase): return code_files, debug_files # disable the hook while we parse code, otherwise we will re-enter the hook - with torch.jit._disable_emit_hooks(): + with torch._jit_internal._disable_emit_hooks(): try: # short-circuit if this is an empty function or module if len(m.code) == 0:
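
Note: `fuser` and `optimized_execution` moved to `torch/jit/_fuser.py` but remain re-exported from `torch.jit` (see the `from torch.jit._fuser import ...` line in `torch/jit/__init__.py`), so existing call sites are unaffected. A sketch of the documented usage, picking `fuser1` (NNC) from the docstring's list of valid names; `pointwise` is an arbitrary example function:

    import torch

    @torch.jit.script
    def pointwise(a, b):
        return a * b + b

    # Runs under the NNC fuser; the previous fuser flags are restored on exit.
    with torch.jit.fuser('fuser1'):
        out = pointwise(torch.randn(8), torch.randn(8))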
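
Note: `_script_if_tracing` now lives in `torch/jit/_trace.py`, again with its `torch.jit` re-export preserved. A sketch of the pattern its docstring describes, for library code with control flow that must survive tracing; the helper and its names here are hypothetical:

    import torch

    @torch.jit._script_if_tracing
    def repeat_add(x: torch.Tensor, times: int) -> torch.Tensor:
        # This loop is only compiled (and therefore preserved) when a
        # surrounding trace first calls the helper; in eager mode the
        # plain Python function runs.
        total = torch.zeros_like(x)
        for _ in range(times):
            total = total + x
        return total

    def model(x):
        return repeat_add(x, 3)

    traced = torch.jit.trace(model, (torch.ones(2),))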
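
Note: `CompilationUnit` moved into `torch/jit/_script.py` and is likewise still exported as `torch.jit.CompilationUnit`. It continues to compile free-standing TorchScript source; a minimal sketch:

    import torch

    cu = torch.jit.CompilationUnit("def add_one(x: Tensor) -> Tensor:\n    return x + 1\n")

    # define() compiled `add_one`; __getattr__ looks up the ScriptFunction.
    print(cu.add_one(torch.zeros(3)))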