Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
move misc implementation out of jit/__init__.py (#41154)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41154
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D22445213
Pulled By: suo
fbshipit-source-id: 200545715c5ef13beb1437f49e01efb21498ddb7
This commit is contained in: parent 6392713584, commit ca1b8ebbcb
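Note: the mechanical change repeated across the hunks below is an import-path move; the helpers keep their behavior, they just leave torch/jit/__init__.py. A minimal sketch of the new internal spellings (the printed value is illustrative):

import torch
import torch._jit_internal
from torch.jit._builtins import _find_builtin

# internal helpers no longer live in torch/jit/__init__.py
print(_find_builtin(torch.mul))  # e.g. 'aten::mul'
with torch._jit_internal._disable_emit_hooks():
    pass  # compile things without triggering registered emit hooks

# the public torch.jit surface is unchanged
assert torch.jit.script is not None and torch.jit.trace is not None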
@@ -487,7 +487,7 @@ class TestModels(JitTestCase):
             return self.seq.forward(input)

         # disabled due to a jitter issues that will be fixed by using load/store in the compiler
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
             # TODO: toggle export_import once above issues are fixed
             self.checkTrace(Traced(), (torch.rand(3, 4),),
                             export_import=False)
@@ -1817,7 +1817,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
             def weight(self, w):
                 self._packed_weight = torch.ops.quantized.linear_prepack(w)

-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
             x = torch.jit.script(Linear(10, 10))
         torch._C._jit_pass_erase_shape_information(x.graph)
@@ -2159,7 +2159,7 @@ graph(%Ra, %Rb):
         self.assertExpected(cu.foo.code)

     def test_import_method(self):
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
             class Foo(torch.jit.ScriptModule):
                 def __init__(self):
                     super(Foo, self).__init__()
@@ -3596,7 +3596,7 @@ def foo(x):
         mod.ninf = float("-inf")
         mod.nan = float("nan")

-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
             class Foo(torch.jit.ScriptModule):
                 def __init__(self):
                     super(Foo, self).__init__()
@ -9122,7 +9122,7 @@ a")
|
|||
x[seq_lens[b]:, b, :] = 0
|
||||
|
||||
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
|
||||
with torch.jit._disable_emit_hooks():
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
|
||||
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
|
||||
self.assertEqual(eager_seq, script_seq)
|
||||
|
|
@ -9145,7 +9145,7 @@ a")
|
|||
|
||||
lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
|
||||
|
||||
with torch.jit._disable_emit_hooks():
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
self.checkModule(lstm, [torch.ones(2, 2)])
|
||||
|
||||
def test_script_pad_sequence_pack_sequence(self):
|
||||
|
|
@ -9165,7 +9165,7 @@ a")
|
|||
tensor1 = torch.tensor([1, 2, 3])
|
||||
tensor2 = torch.tensor([4, 5])
|
||||
tensor3 = torch.tensor([6])
|
||||
with torch.jit._disable_emit_hooks():
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
self.checkScript(pad_sequence_func,
|
||||
([ones3, ones4, ones5],))
|
||||
self.checkScript(pad_sequence_func,
|
||||
|
|
@ -9361,7 +9361,7 @@ a")
|
|||
|
||||
def test_tuples(self):
|
||||
# TODO: jitter issue.
|
||||
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||
def foo(i):
|
||||
a = (i + 4, i * 2)
|
||||
c = a
|
||||
|
|
@ -12613,7 +12613,7 @@ a")
|
|||
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
||||
|
||||
def test_bool_dispatch(self):
|
||||
with torch.jit._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||
with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
|
||||
def kwarg_false(x):
|
||||
# type: (Tensor) -> Tensor
|
||||
return F.max_pool1d(x, 1, 1, return_indices=False)
|
||||
|
|
@ -14237,7 +14237,7 @@ a")
|
|||
# type: (str) -> Tensor
|
||||
return self.table[key] + self.x
|
||||
|
||||
with torch.jit._disable_emit_hooks():
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
# TODO: re-enable module hook when Python printing of attributes is
|
||||
# supported
|
||||
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
|
||||
|
|
@@ -15393,7 +15393,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(),
             self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)

         if test_name in EXCLUDE_PYTHON_PRINT:
-            with torch.jit._disable_emit_hooks():
+            with torch._jit_internal._disable_emit_hooks():
                 run_test()
         else:
             run_test()
@@ -474,7 +474,7 @@ class TestFuser(JitTestCase):

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
-    @torch.jit._disable_emit_hooks_decorator
+    @torch._jit_internal._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -507,7 +507,7 @@ class TestFuser(JitTestCase):

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
-    @torch.jit._disable_emit_hooks_decorator
+    @torch._jit_internal._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -4,6 +4,8 @@ can be used in other places in torch/ (namely torch.nn) without running into
 circular dependency problems
 """

+import contextlib
+import collections
 import inspect
 import weakref
 import warnings
@@ -767,3 +769,45 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):

 def fake_range():
     return SourceContext('', None, 0, 0).make_raw_range(0, 1)
+
+
+def _try_get_dispatched_fn(fn):
+    if not callable(fn):
+        return None
+    return boolean_dispatched.get(fn)
+
+
+def _get_named_tuple_properties(obj):
+    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
+    fields = list(obj._fields)
+    annotations = []
+    has_annotations = hasattr(obj, '__annotations__')
+    for field in fields:
+        if has_annotations and field in obj.__annotations__:
+            the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
+            annotations.append(the_type)
+        else:
+            annotations.append(torch._C.TensorType.get())
+    return type(obj).__name__, fields, annotations
+
+
+def _create_named_tuple(t, unqual_name, field_names):
+    TupleType = collections.namedtuple(unqual_name, field_names)
+    return TupleType(*t)
+
+
+@contextlib.contextmanager
+def _disable_emit_hooks():
+    hooks = torch._C._jit_get_emit_hooks()
+    torch._C._jit_set_emit_hooks(None, None)
+    yield
+    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+
+
+def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
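Note: _disable_emit_hooks saves the current emit hooks, clears them for the duration of the block, and restores them on exit. A minimal usage sketch, assuming a hypothetical MyModule used only for illustration:

import torch
import torch._jit_internal

class MyModule(torch.nn.Module):  # hypothetical, for illustration only
    def forward(self, x):
        return x + 1

# inside the block the hooks are (None, None), so scripting does not
# re-enter any registered module/function emit hook
with torch._jit_internal._disable_emit_hooks():
    scripted = torch.jit.script(MyModule())
print(scripted(torch.ones(2)))  # tensor([2., 2.])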
@@ -61,7 +61,7 @@ class _OpNamespace(types.ModuleType):
         op = torch._C._jit_get_operation(qualified_op_name)
         # let the script frontend know that op is identical to the builtin op
         # with qualified_op_name
-        torch.jit._register_builtin(op, qualified_op_name)
+        torch.jit._builtins._register_builtin(op, qualified_op_name)
         setattr(self, op_name, op)
         op.__module__ = self.__module__ + "." + self.name
         return op
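Note: _OpNamespace.__getattr__ resolves an operator lazily, registers it with the script frontend (now via torch.jit._builtins), and caches it with setattr. What that backs, from the user's side:

import torch

# the first attribute access triggers _jit_get_operation + _register_builtin;
# later accesses hit the attribute cached by setattr()
add_op = torch.ops.aten.add
print(add_op(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])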
@@ -839,7 +839,7 @@ inline py::object toPyObject(IValue ivalue) {
       auto fieldNames = fmap(
           tuple->type()->schema()->arguments(),
           [](const Argument& arg) { return arg.name(); });
-      return py::module::import("torch.jit")
+      return py::module::import("torch._jit_internal")
           .attr("_create_named_tuple")(t, unqualName, fieldNames);
     } else {
       return std::move(t);
@@ -686,8 +686,8 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
     }
   }

-  py::object props =
-      py::module::import("torch.jit").attr("_get_named_tuple_properties")(obj);
+  py::object props = py::module::import("torch._jit_internal")
+                         .attr("_get_named_tuple_properties")(obj);
   std::string unqualName;
   std::vector<std::string> fields;
   std::vector<TypePtr> annotations;
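Note: both C++ call sites round-trip NamedTuple values across the Python/TorchScript boundary: registerNamedTuple reads fields and annotations via _get_named_tuple_properties, and toPyObject rebuilds the value via _create_named_tuple. A small script exercising this path (the Point class is illustrative):

import torch
from typing import NamedTuple

class Point(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

@torch.jit.script
def make_point(t: torch.Tensor) -> Point:
    # compiling this looks up Point via _get_named_tuple_properties;
    # the returned value is rebuilt in Python via _create_named_tuple
    return Point(t, t + 1)

p = make_point(torch.zeros(2))
print(p.x, p.y)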
@ -788,7 +788,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
|||
}
|
||||
|
||||
py::object builtin_name =
|
||||
py::module::import("torch.jit").attr("_find_builtin")(obj);
|
||||
py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
|
||||
if (!builtin_name.is_none()) {
|
||||
return std::make_shared<BuiltinFunction>(
|
||||
Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
|
||||
|
|
@ -801,8 +801,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
|||
}
|
||||
}
|
||||
|
||||
py::object dispatched_fn =
|
||||
py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
|
||||
py::object dispatched_fn = py::module::import("torch._jit_internal")
|
||||
.attr("_try_get_dispatched_fn")(obj);
|
||||
if (!dispatched_fn.is_none()) {
|
||||
return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -160,7 +160,7 @@ def _wait_all_workers():

     is_leader_worker = leader_worker_name == self_worker_name
     # Set a long enough timeout for all shutdown messages to be processed.
-    timeout = 5  # seconds
+    timeout = 5  # second

     # Phase 1: Followers send intents.
     # All followers report intents to the leader.
@@ -522,7 +522,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
         >>> rpc.shutdown()
     """
-    qualified_name = torch.jit._find_builtin(func)
+    qualified_name = torch.jit._builtins._find_builtin(func)
     dst_worker_info = _to_worker_info(to)
     should_profile = torch.autograd._profiler_enabled()
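Note: _find_builtin is what lets the RPC layer recognize ATen builtins such as torch.add and dispatch them by qualified name rather than pickling the callable. A sketch under the assumption that rpc.init_rpc has already set up workers "worker0"/"worker1":

import torch
import torch.distributed.rpc as rpc

# assumes rpc.init_rpc("worker0", rank=0, world_size=2) has run elsewhere,
# with a peer named "worker1"
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))
print(rref.to_here())  # tensor([2., 2.])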
@@ -594,7 +594,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
     if not callable(func):
         raise TypeError("function should be callable.")

-    qualified_name = torch.jit._find_builtin(func)
+    qualified_name = torch.jit._builtins._find_builtin(func)
     dst_worker_info = _to_worker_info(to)

     # TODO: profiling logic does not really belong in invoke_rpc, it should be
@@ -1,93 +1,52 @@
 import torch._C
 import torch._jit_internal as _jit_internal

 from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin  # noqa
-from torch._jit_internal import Future
 from torch.nn import Module
 from torch.utils import set_module
 from torch.autograd.grad_mode import _DecoratorContextManager
 from typing import Optional, List

 import collections
 import contextlib
 import functools
 import os
 import pathlib

 # These are imported so users can access them from the torch.jit module
-from torch._jit_internal import Final, _overload, _overload_method
-from torch._jit_internal import ignore, export, unused
-from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
-    RecursiveScriptModule, ScriptWarning, interface
-from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
-    is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
+from torch._jit_internal import (
+    Final,
+    Future,
+    _overload,
+    _overload_method,
+    ignore,
+    export,
+    unused,
+)
+from torch.jit._script import (
+    script,
+    Attribute,
+    ScriptModule,
+    is_scripting,
+    script_method,
+    RecursiveScriptModule,
+    ScriptWarning,
+    interface,
+    CompilationUnit,
+    ScriptFunction,
+    _unwrap_optional,
+)
+from torch.jit._trace import (
+    trace,
+    trace_module,
+    TracedModule,
+    TracerWarning,
+    TracingCheckError,
+    is_tracing,
+    ONNXTracedModule,
+    TopLevelTracedModule,
+    _unique_state_dict,
+    _flatten,
+    _script_if_tracing,
+    _get_trace_graph,
+)
 from torch.jit._async import fork, wait
 from torch.jit._serialization import save, load

-set_module(Future, "torch.jit")
+from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph

 # For backwards compatibility
 _fork = fork
 _wait = wait

-@contextlib.contextmanager
-def optimized_execution(should_optimize):
-    """
-    A context manager that controls whether the JIT's executor will run
-    optimizations before executing a function.
-    """
-    stored_flag = torch._C._get_graph_executor_optimize()
-    torch._C._set_graph_executor_optimize(should_optimize)
-    try:
-        yield
-    finally:
-        torch._C._set_graph_executor_optimize(stored_flag)
-
-@contextlib.contextmanager
-def fuser(name):
-    """
-    A context manager that facilitates switching between
-    backend fusers.
-
-    Valid names:
-        * ``fuser0`` - enables only legacy fuser
-        * ``fuser1`` - enables only NNC
-        * ``fuser2`` - enables only nvFuser
-    """
-    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
-    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
-    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
-    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
-    if name == 'fuser0':  # legacy fuser
-        torch._C._jit_override_can_fuse_on_cpu(True)
-        torch._C._jit_override_can_fuse_on_gpu(True)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(False)
-    elif name == 'fuser1':  # NNC
-        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
-        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_override_can_fuse_on_cpu(False)
-        torch._C._jit_override_can_fuse_on_gpu(False)
-        torch._C._jit_set_texpr_fuser_enabled(True)
-        torch._C._jit_set_nvfuser_enabled(False)
-    elif name == 'fuser2':  # nvFuser
-        torch._C._jit_override_can_fuse_on_cpu(False)
-        torch._C._jit_override_can_fuse_on_gpu(False)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(True)
-    else:
-        raise Exception("unrecognized fuser option")
-    try:
-        yield
-    finally:
-        if name == 'fuser1':  # NNC
-            torch._C._jit_set_profiling_executor(old_profiling_executor)
-            torch._C._jit_set_profiling_mode(old_profiling_mode)
-        # recover the previous values
-        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
-        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
-        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
-        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
-
 def export_opnames(m):
     r"""
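Note: after this hunk, torch/jit/__init__.py is essentially a re-export layer; implementations live in the submodules. A sketch of the pattern (not the verbatim file), using import paths taken from the diff:

# torch/jit/__init__.py, abridged
from torch.jit._script import script, ScriptModule, CompilationUnit
from torch.jit._trace import trace, trace_module
from torch.jit._fuser import optimized_execution, fuser

# so existing user code such as torch.jit.script(fn) or
# `with torch.jit.fuser('fuser1'): ...` keeps working unchanged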
@@ -95,212 +54,6 @@ def export_opnames(m):
     """
     return torch._C._export_opnames(m._c)

-def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
-                     return_inputs=False, _return_inputs_states=False):
-    """
-    .. warning::
-        This function is internal-only and should only be used by the ONNX
-        exporter. If you are trying to get a graph through tracing, please go
-        through the public API instead::
-
-            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-            trace_graph = trace.graph
-
-    Trace a function or model, returning a tuple consisting of the both the
-    *trace* of an execution, as well as the original return value. If return_inputs,
-    also returns the trace inputs as part of the tuple
-
-    Tracing is guaranteed not to change the semantics of the function/module
-    that is traced.
-
-    Arguments:
-        f (torch.nn.Module or function): the function or module
-            to be traced.
-        args (tuple or Tensor): the positional arguments to pass to the
-            function/module to be traced.  A non-tuple is assumed to
-            be a single positional argument to be passed to the model.
-        kwargs (dict): the keyword arguments to pass to the function/module
-            to be traced.
-
-    Example (trace a cell):
-
-    .. testcode::
-
-        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-    """
-    if kwargs is None:
-        kwargs = {}
-    if not isinstance(args, tuple):
-        args = (args,)
-    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
-    return outs
-
-
-def freeze(mod, preserved_attrs : Optional[List[str]] = None):
-    r"""
-    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
-    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
-    By default, `forward` will be preserved, as well as attributes & methods specified in
-    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
-    method will be preserved.
-
-    Freezing currently only accepts ScriptModules that are in eval mode.
-
-    Arguments:
-        mod (:class:`ScriptModule`): a module to be frozen
-
-        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
-        Attributes modified in preserved methods will also be preserved.
-
-    Returns:
-        Frozen :class:`ScriptModule`.
-
-    Example (Freezing a simple module with a Parameter):
-
-    .. testcode::
-        import torch
-        class MyModule(torch.nn.Module):
-            def __init__(self, N, M):
-                super(MyModule, self).__init__()
-                self.weight = torch.nn.Parameter(torch.rand(N, M))
-                self.linear = torch.nn.Linear(N, M)
-
-            def forward(self, input):
-                output = self.weight.mm(input)
-                output = self.linear(output)
-                return output
-
-        scripted_module = torch.jit.script(MyModule(2, 3).eval())
-        frozen_module = torch.jit.freeze(scripted_module)
-        # parameters have been removed and inlined into the Graph as constants
-        assert len(list(frozen_module.named_parameters())) == 0
-        # See the compiled graph as Python code
-        print(frozen_module.code)
-
-    Example (Freezing a module with preserved attributes)
-
-    .. testcode::
-        import torch
-        class MyModule2(torch.nn.Module):
-            def __init__(self):
-                super(MyModule2, self).__init__()
-                self.modified_tensor = torch.tensor(10.)
-                self.version = 1
-
-            def forward(self, input):
-                self.modified_tensor += 1
-                return input + self.modified_tensor
-
-        scripted_module = torch.jit.script(MyModule2().eval())
-        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
-        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
-        assert frozen_module.version == 1
-        frozen_module.version = 2
-        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
-        # it to retain model semantics
-        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
-        # now that we've run it once, the next result will be incremented by one
-        assert frozen_module(torch.tensor(1)) == torch.tensor(13)
-
-    Note:
-        If you're not sure why an attribute is not being inlined as a constant, you can run
-        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
-        attribute is being modified.
-    """
-    if not isinstance(mod, ScriptModule):
-        raise RuntimeError("Freezing expects a ScriptModule as input. "
-                           "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
-
-    if mod.training:
-        raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
-                           "Please call .eval() on your module before freezing.")
-
-    preserved_attrs = preserved_attrs if preserved_attrs is not None else []
-
-    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
-    RecursiveScriptModule._finalize_scriptmodule(out)
-
-    return out
-
-
-class CompilationUnit(object):
-    def __init__(self, lang=None, _frames_up=0):
-        self._c = torch._C.CompilationUnit()
-        if lang is not None:
-            self.define(lang, _frames_up=_frames_up + 1)
-
-    def define(self, lang, rcb=None, _frames_up=0):
-        if not rcb:
-            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
-        self._c.define(lang, rcb)
-
-    def __getattr__(self, attr):
-        r = self._c.find_function(attr)
-        if r is None:
-            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
-        return r
-
-
-def _try_get_dispatched_fn(fn):
-    if not callable(fn):
-        return None
-    return _jit_internal.boolean_dispatched.get(fn)
-
-
-def _try_get_overloaded_fn(mod, field):
-    return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
-
-
-@contextlib.contextmanager
-def _disable_emit_hooks():
-    hooks = torch._C._jit_get_emit_hooks()
-    torch._C._jit_set_emit_hooks(None, None)
-    yield
-    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
-
-
-def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
-    def __enter__(self):
-        self.hooks = torch._C._jit_get_emit_hooks()
-        torch._C._jit_set_emit_hooks(None, None)
-
-    def __exit__(self, *args):
-        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
-
-
-def _script_if_tracing(fn):
-    """
-    Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
-    has a non-negligible start up time when it is first called due to
-    lazy-initializations of many compiler builtins. Therefore you should not use
-    it in library code. However, you may want to have parts of your library work
-    in tracing even if they use control flow. In these cases, you should use
-    ``@torch.jit._script_if_tracing`` to substitute for
-    ``torch.jit.script``.
-    """
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        if not is_tracing():
-            # Not tracing, don't do anything
-            return fn(*args, **kwargs)
-
-        compiled_fn = script(wrapper.__original_fn)
-        return compiled_fn(*args, **kwargs)
-
-    wrapper.__original_fn = fn
-    wrapper.__script_if_tracing_wrapper = True
-
-    return wrapper
-
-
-def _unwrap_optional(x):
-    assert x is not None, "Unwrapping null optional"
-    return x
-
-
-_register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
-_register_builtin(_wait, 'aten::wait')
-_register_builtin(wait, 'aten::wait')
-_register_builtin(is_scripting, 'aten::is_scripting')

 # torch.jit.Error
 Error = torch._C.JITException
@@ -309,53 +62,11 @@ set_module(Error, "torch.jit")
 Error.__name__ = "Error"
 Error.__qualname__ = "Error"

-def _get_named_tuple_properties(obj):
-    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
-    fields = list(obj._fields)
-    annotations = []
-    has_annotations = hasattr(obj, '__annotations__')
-    for field in fields:
-        if has_annotations and field in obj.__annotations__:
-            the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range())
-            annotations.append(the_type)
-        else:
-            annotations.append(torch._C.TensorType.get())
-    return type(obj).__name__, fields, annotations
-
-def _create_named_tuple(t, unqual_name, field_names):
-    TupleType = collections.namedtuple(unqual_name, field_names)
-    return TupleType(*t)
-
-class _disable_tracing(object):
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
 # for use in python if using annotate
 def annotate(the_type, the_value):
     # noop in python
     return the_value

-last_executed_optimized_graph = torch._C._last_executed_optimized_graph
-
-
-def _graph_for(self, *args, **kwargs):
-    self(*args, **kwargs)
-    return last_executed_optimized_graph()
-
-torch._C.ScriptMethod.graph_for = _graph_for
-torch._C.ScriptFunction.graph_for = _graph_for
-ScriptFunction = torch._C.ScriptFunction
-ScriptFunction.__doc__ = """
-Functionally equivalent to a :class:`ScriptModule`, but represents a single
-function and does not have any attributes or Parameters.
-"""
-set_module(ScriptFunction, "torch.jit")
-
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")
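Note: annotate survives the cleanup and remains a Python no-op that only guides TorchScript type inference, e.g.:

import torch
from typing import List

@torch.jit.script
def empty_ints() -> List[int]:
    # without the annotation the empty list would be inferred as List[Tensor]
    xs = torch.jit.annotate(List[int], [])
    xs.append(1)
    return xs

print(empty_ints())  # [1]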
@@ -9,6 +9,12 @@ functionalities in `torch.jit`.

 import torch

+from torch.utils import set_module
+from torch.jit._builtins import _register_builtin
+from torch._jit_internal import Future
+
+set_module(Future, "torch.jit")
+

 def fork(func, *args, **kwargs):
     """
@@ -84,3 +90,6 @@ def wait(future):
         `T`: the return value of the the completed task
     """
     return torch._C.wait(future)
+
+
+_register_builtin(wait, "aten::wait")
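Note: registering `wait` as `aten::wait` now happens next to fork/wait themselves; usage is unchanged:

import torch

def slow_neg(x: torch.Tensor) -> torch.Tensor:
    return -x

@torch.jit.script
def run(x: torch.Tensor) -> torch.Tensor:
    fut = torch.jit.fork(slow_neg, x)  # schedule an async task
    return torch.jit.wait(fut)         # resolves to the registered aten::wait

print(run(torch.ones(2)))  # tensor([-1., -1.])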
torch/jit/_freeze.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+"""Freezing
+
+This is not intended to be imported directly; please use the exposed
+functionalities in `torch.jit`.
+"""
+
+from typing import Optional, List
+
+import torch
+from torch.jit._script import RecursiveScriptModule, ScriptModule
+
+
+def freeze(mod, preserved_attrs: Optional[List[str]] = None):
+    r"""
+    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
+    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
+    By default, `forward` will be preserved, as well as attributes & methods specified in
+    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
+    method will be preserved.
+
+    Freezing currently only accepts ScriptModules that are in eval mode.
+
+    Arguments:
+        mod (:class:`ScriptModule`): a module to be frozen
+
+        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
+        Attributes modified in preserved methods will also be preserved.
+
+    Returns:
+        Frozen :class:`ScriptModule`.
+
+    Example (Freezing a simple module with a Parameter):
+
+    .. testcode::
+        import torch
+        class MyModule(torch.nn.Module):
+            def __init__(self, N, M):
+                super(MyModule, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(N, M))
+                self.linear = torch.nn.Linear(N, M)
+
+            def forward(self, input):
+                output = self.weight.mm(input)
+                output = self.linear(output)
+                return output
+
+        scripted_module = torch.jit.script(MyModule(2, 3).eval())
+        frozen_module = torch.jit.freeze(scripted_module)
+        # parameters have been removed and inlined into the Graph as constants
+        assert len(list(frozen_module.named_parameters())) == 0
+        # See the compiled graph as Python code
+        print(frozen_module.code)
+
+    Example (Freezing a module with preserved attributes)
+
+    .. testcode::
+        import torch
+        class MyModule2(torch.nn.Module):
+            def __init__(self):
+                super(MyModule2, self).__init__()
+                self.modified_tensor = torch.tensor(10.)
+                self.version = 1
+
+            def forward(self, input):
+                self.modified_tensor += 1
+                return input + self.modified_tensor
+
+        scripted_module = torch.jit.script(MyModule2().eval())
+        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
+        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
+        assert frozen_module.version == 1
+        frozen_module.version = 2
+        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
+        # it to retain model semantics
+        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
+        # now that we've run it once, the next result will be incremented by one
+        assert frozen_module(torch.tensor(1)) == torch.tensor(13)
+
+    Note:
+        If you're not sure why an attribute is not being inlined as a constant, you can run
+        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
+        attribute is being modified.
+    """
+    if not isinstance(mod, ScriptModule):
+        raise RuntimeError(
+            "Freezing expects a ScriptModule as input. "
+            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
+        )
+
+    if mod.training:
+        raise RuntimeError(
+            "Freezing is currently only implemented for modules in eval mode. "
+            "Please call .eval() on your module before freezing."
+        )
+
+    preserved_attrs = preserved_attrs if preserved_attrs is not None else []
+
+    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
+    RecursiveScriptModule._finalize_scriptmodule(out)
+
+    return out
torch/jit/_fuser.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import contextlib
+
+import torch
+
+
+@contextlib.contextmanager
+def optimized_execution(should_optimize):
+    """
+    A context manager that controls whether the JIT's executor will run
+    optimizations before executing a function.
+    """
+    stored_flag = torch._C._get_graph_executor_optimize()
+    torch._C._set_graph_executor_optimize(should_optimize)
+    try:
+        yield
+    finally:
+        torch._C._set_graph_executor_optimize(stored_flag)
+
+
+@contextlib.contextmanager
+def fuser(name):
+    """
+    A context manager that facilitates switching between
+    backend fusers.
+
+    Valid names:
+    * ``fuser0`` - enables only legacy fuser
+    * ``fuser1`` - enables only NNC
+    * ``fuser2`` - enables only nvFuser
+    """
+    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
+    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
+    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
+    if name == 'fuser0':  # legacy fuser
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser1':  # NNC
+        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
+        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser2':  # nvFuser
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        raise Exception("unrecognized fuser option")
+    try:
+        yield
+    finally:
+        if name == 'fuser1':  # NNC
+            torch._C._jit_set_profiling_executor(old_profiling_executor)
+            torch._C._jit_set_profiling_mode(old_profiling_mode)
+        # recover the previous values
+        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
+        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
+        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
+        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
+
+
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+    self(*args, **kwargs)
+    return last_executed_optimized_graph()
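Note: fuser and optimized_execution are plain save/flip/restore context managers over the torch._C toggles above. A typical use (the warm-up run count is approximate; the profiling executor fuses after a couple of runs):

import torch

def f(a, b):
    return (a + b).relu()

scripted = torch.jit.script(f)
x, y = torch.randn(4), torch.randn(4)

# run under the NNC (TensorExpr) fuser only
with torch.jit.fuser('fuser1'):
    scripted(x, y)
    scripted(x, y)

# or disable graph-executor optimizations entirely
with torch.jit.optimized_execution(False):
    scripted(x, y)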
@@ -609,7 +609,7 @@ def compile_unbound_method(concrete_type, fn):
     if _jit_internal.is_ignored_fn(fn):
         return None
     stub = make_stub(fn, fn.__name__)
-    with torch.jit._disable_emit_hooks():
+    with torch._jit_internal._disable_emit_hooks():
         # We don't want to call the hooks here since the graph that is calling
         # this function is not yet complete
         create_methods_from_stubs(concrete_type, (stub,))
@@ -15,12 +15,15 @@ import warnings

 import torch
 import torch._jit_internal as _jit_internal
 from torch.utils import set_module
 from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
 from torch.nn import Module
 from torch.jit._state import _enabled
+from torch.jit._builtins import _register_builtin
 from torch._six import with_metaclass, get_function_from_type
 from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
 from torch._jit_internal import _qualified_name
+from torch.jit._fuser import _graph_for
 from torch.jit._state import (
     _try_get_jit_cached_function,
     _try_get_jit_cached_overloads,
@@ -28,6 +31,16 @@ from torch.jit._state import (
     _set_jit_overload_cache,
 )

+torch._C.ScriptMethod.graph_for = _graph_for
+torch._C.ScriptFunction.graph_for = _graph_for
+ScriptFunction = torch._C.ScriptFunction
+ScriptFunction.__doc__ = """
+Functionally equivalent to a :class:`ScriptModule`, but represents a single
+function and does not have any attributes or Parameters.
+"""
+set_module(ScriptFunction, "torch.jit")
+
+
 if _enabled:
     Attribute = collections.namedtuple("Attribute", ["value", "type"])
 else:
@@ -1053,3 +1066,32 @@ def _recursive_compile_class(obj, loc):
     error_stack = torch._C.CallStack(_qual_name, loc)
     rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
     _compile_and_register_class(obj, rcb, _qual_name)
+
+
+_register_builtin(is_scripting, "aten::is_scripting")
+
+
+class CompilationUnit(object):
+    def __init__(self, lang=None, _frames_up=0):
+        self._c = torch._C.CompilationUnit()
+        if lang is not None:
+            self.define(lang, _frames_up=_frames_up + 1)
+
+    def define(self, lang, rcb=None, _frames_up=0):
+        if not rcb:
+            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
+        self._c.define(lang, rcb)
+
+    def __getattr__(self, attr):
+        r = self._c.find_function(attr)
+        if r is None:
+            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
+        return r
+
+
+def _unwrap_optional(x):
+    assert x is not None, "Unwrapping null optional"
+    return x
+
+
+_register_builtin(_unwrap_optional, "aten::_unwrap_optional")
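Note: CompilationUnit moves here but stays re-exported as torch.jit.CompilationUnit; it compiles free-standing TorchScript source:

import torch

cu = torch.jit.CompilationUnit('''
def add_one(x):
    return x + 1
''')
print(cu.add_one(torch.zeros(2)))  # tensor([1., 1.])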
@@ -11,12 +11,13 @@ import torch

 import os
 import contextlib
+import functools
 import warnings
 import inspect
 import re

 from torch.jit._state import _python_cu, _enabled
-from torch.jit._script import ScriptModule, _CachedForward
+from torch.jit._script import ScriptModule, _CachedForward, script
 from torch._jit_internal import _qualified_name
 from torch.autograd import function
 from torch import _jit_internal
@@ -1077,3 +1078,70 @@ class TopLevelTracedModule(TracedModule):
             cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
         """
         self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
+
+
+def _script_if_tracing(fn):
+    """
+    Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
+    has a non-negligible start up time when it is first called due to
+    lazy-initializations of many compiler builtins. Therefore you should not use
+    it in library code. However, you may want to have parts of your library work
+    in tracing even if they use control flow. In these cases, you should use
+    ``@torch.jit._script_if_tracing`` to substitute for
+    ``torch.jit.script``.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not is_tracing():
+            # Not tracing, don't do anything
+            return fn(*args, **kwargs)
+
+        compiled_fn = script(wrapper.__original_fn)
+        return compiled_fn(*args, **kwargs)
+
+    wrapper.__original_fn = fn
+    wrapper.__script_if_tracing_wrapper = True
+
+    return wrapper
+
+
+def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
+                     return_inputs=False, _return_inputs_states=False):
+    """
+    .. warning::
+        This function is internal-only and should only be used by the ONNX
+        exporter. If you are trying to get a graph through tracing, please go
+        through the public API instead::
+
+            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
+            trace_graph = trace.graph
+
+    Trace a function or model, returning a tuple consisting of the both the
+    *trace* of an execution, as well as the original return value. If return_inputs,
+    also returns the trace inputs as part of the tuple
+
+    Tracing is guaranteed not to change the semantics of the function/module
+    that is traced.
+
+    Arguments:
+        f (torch.nn.Module or function): the function or module
+            to be traced.
+        args (tuple or Tensor): the positional arguments to pass to the
+            function/module to be traced.  A non-tuple is assumed to
+            be a single positional argument to be passed to the model.
+        kwargs (dict): the keyword arguments to pass to the function/module
+            to be traced.
+
+    Example (trace a cell):
+
+    .. testcode::
+
+        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
+    """
+    if kwargs is None:
+        kwargs = {}
+    if not isinstance(args, tuple):
+        args = (args,)
+    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
+    return outs
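Note: _script_if_tracing defers scripting until the wrapped function is first hit under tracing, keeping library import time cheap. A minimal sketch, assuming the internal import path introduced by this diff:

import torch
from torch.jit._trace import _script_if_tracing  # internal helper, per this diff

@_script_if_tracing
def count_down(x: torch.Tensor) -> torch.Tensor:
    # data-dependent loop: must be scripted, not traced
    while bool(x.sum() > 0):
        x = x - 1
    return x

def model(x):
    return count_down(x.relu())

# count_down is compiled with torch.jit.script on first hit inside the trace
traced = torch.jit.trace(model, torch.ones(3))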
@@ -1,4 +1,5 @@
 import torch.jit
+from torch.jit._builtins import _find_builtin
 import inspect
 import textwrap
 # this file is for generating documentation using sphinx autodoc
@@ -92,7 +93,7 @@ def _get_nn_functional_ops():
     for mod in torch.jit._builtins._modules_containing_builtins:
         name = mod.__name__
         for elem in dir(mod):
-            builtin = torch.jit._find_builtin(getattr(mod, elem))
+            builtin = _find_builtin(getattr(mod, elem))
             if builtin is not None:
                 schemas = torch._C._jit_get_schemas_for_operator(builtin)
                 for schema in schemas:
@@ -133,7 +134,7 @@ def _get_torchscript_builtins():
     # Iterate over the specially added builtins
     for fn, _builtin_name in builtins:
         mod = inspect.getmodule(fn)
-        builtin = torch.jit._find_builtin(fn)
+        builtin = _find_builtin(fn)
         if builtin is not None:
             schemas = torch._C._jit_get_schemas_for_operator(builtin)
             for schema in schemas:
@@ -150,7 +151,7 @@ def _get_math_builtins():
     # Iterate over the specially added builtins
     for fn, _builtin_name in builtins:
         mod = inspect.getmodule(fn)
-        builtin = torch.jit._find_builtin(fn)
+        builtin = _find_builtin(fn)
         if builtin is not None:
             schemas = torch._C._jit_get_schemas_for_operator(builtin)
             for schema in schemas:
@@ -1222,7 +1222,7 @@ class RpcTest(RpcAgentTestFixture):
         events = prof.function_events

         rpc_mul_event = get_function_event(
-            events, torch.jit._find_builtin(torch.mul)
+            events, torch.jit._builtins._find_builtin(torch.mul)
         )

         remote_events = {
@@ -358,7 +358,7 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name

     f_args_variable = (self_variable,) + args_variable
     f_args_tensor = (self_tensor,) + args_tensor
-    with torch.jit._disable_emit_hooks():
+    with torch._jit_internal._disable_emit_hooks():
         script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
     return script_fn, inputs
@@ -159,7 +159,7 @@ class JitTestCase(TestCase):
             return code_files, debug_files

         # disable the hook while we parse code, otherwise we will re-enter the hook
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
             try:
                 # short-circuit if this is an empty function or module
                 if len(m.code) == 0: