Merge script and _script_pdt API (#62420)

Summary:
Merge the `torch.jit.script` and `torch.jit._script_pdt` APIs. This PR merges profile-directed typing into the public `torch.jit.script` API: `script` gains the `example_inputs` argument, and the private `_script_pdt` entry point is removed.
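
For illustration, a minimal sketch of the call-site change (a toy `add` function in the pattern of the updated tests below; assumes MonkeyType is installed, which profile-directed typing requires):

    import torch

    def add(a, b):
        return a + b

    # Before this PR (private API):
    #   scripted = torch.jit._script_pdt(add, example_inputs=[(3, 4)])
    # After this PR, the same profile-directed typing goes through the public API:
    scripted = torch.jit.script(add, example_inputs=[(3, 4)])  # a and b inferred as int
    assert scripted(10, 2) == add(10, 2)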

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62420

Reviewed By: iramazanli

Differential Revision: D30579015

Pulled By: nikithamalgifb

fbshipit-source-id: 99ba6839d235d61b2dd0144b466b2063a53ccece
nikithamalgi authored on 2021-08-26 18:54:51 -07:00; committed by Facebook GitHub Bot
parent 0e8c3c51d9
commit 510d2ece81
3 changed files with 124 additions and 93 deletions

[changed file 1 of 3]

@@ -40,7 +40,7 @@ class TestPDT(JitTestCase):
make_global(TestPDTModel)
pdt_model = TestPDTModel()
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
-scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
+scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
@@ -67,7 +67,7 @@ class TestPDT(JitTestCase):
inner_pdt_model = NestedPDTInner()
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
-scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
+scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
@@ -95,7 +95,7 @@ class TestPDT(JitTestCase):
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
-scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
+scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
outer_pdt_model: outer_input, })
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
@@ -119,7 +119,7 @@ class TestPDT(JitTestCase):
make_global(NestedFunctionInForward)
pdt_model = NestedFunctionInForward()
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
-scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
+scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
self.assertEqual(scripted_pdt_model(True), pdt_model(True))
@@ -142,7 +142,7 @@ class TestPDT(JitTestCase):
make_global(TestModelWithExport)
pdt_model = TestModelWithExport()
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
-scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp})
+scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
@@ -155,7 +155,7 @@ class TestPDT(JitTestCase):
make_global(PDTModel)
pdt_model = PDTModel()
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
-scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp})
+scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
@@ -174,7 +174,7 @@ class TestPDT(JitTestCase):
pdt_model = PDTModelWithManyMethods()
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
-scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
+scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
@@ -195,8 +195,8 @@ class TestPDT(JitTestCase):
pdt_model_two = PDTModelTwo()
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
-scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
-scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
+scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
+scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
@@ -209,28 +209,28 @@ class TestPDT(JitTestCase):
return a + b
make_global(test_sum)
-scripted_fn_add = torch.jit._script_pdt(test_sum, example_inputs=[(3, 4)])
+scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
def test_sub(a, b):
return a - b
make_global(test_sub)
-scripted_fn_sub = torch.jit._script_pdt(test_sub, example_inputs=[(3.9, 4.10)])
+scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
def test_mul(a, b):
return a * b
make_global(test_mul)
-scripted_fn_mul = torch.jit._script_pdt(test_mul, example_inputs=[(-10, 9)])
+scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
def test_args_complex(real, img):
return torch.complex(real, img)
make_global(test_args_complex)
-scripted_fn_complex = torch.jit._script_pdt(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
+scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
@@ -241,7 +241,7 @@ class TestPDT(JitTestCase):
return 0
make_global(test_bool)
-scripted_fn_bool = torch.jit._script_pdt(test_bool, example_inputs=[(True,)])
+scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
self.assertEqual(scripted_fn_bool(True), test_bool(True))
def test_str(a):
@@ -251,7 +251,7 @@ class TestPDT(JitTestCase):
return True
make_global(test_str)
-scripted_fn_str = torch.jit._script_pdt(test_str, example_inputs=[("",)])
+scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
def test_pdt_list_and_tuple(self):
@@ -260,24 +260,24 @@ class TestPDT(JitTestCase):
make_global(test_list_and_tuple)
-scripted_fn_float_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
+scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
-scripted_fn_bool_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([True, False, True],)])
+scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
-scripted_fn_int_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
+scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
-scripted_fn_float_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
+scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
-scripted_fn_bool_tuple_input = torch.jit._script_pdt(test_list_and_tuple,
+scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
example_inputs=[((True, False, True),)])
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
test_list_and_tuple((True, True, True)))
-scripted_fn_int_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
+scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
def test_nested_list_and_tuple(self):
@@ -295,22 +295,22 @@ class TestPDT(JitTestCase):
make_global(test_nested_list, test_nested_tuple)
list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
-scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
+scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
list_inp = ([1, 2, 3, ], [5, 6, 7, ])
-scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
+scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
-scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
+scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
tup_inp = ((True, False, True, ), (False, False, False, ))
-scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
+scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = ((True, True, True, ), (False, False, True, ))
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
@@ -324,11 +324,11 @@ class TestPDT(JitTestCase):
make_global(test_dict, test_dict_int_list)
str_bool_inp = {'foo' : True, 'bar': False}
-scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)])
+scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
str_list_inp = {0 : [True, False], 1: [False, True]}
-scripted_fn = torch.jit._script_pdt(test_dict_int_list, example_inputs=[(str_list_inp,)])
+scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
@@ -349,13 +349,13 @@ class TestPDT(JitTestCase):
make_global(test_multiple_types, test_multiple_type_refinement)
-scripted_fn = torch.jit._script_pdt(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
+scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
self.assertEqual(scripted_fn(10), test_multiple_types(10))
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
-scripted_fn = torch.jit._script_pdt(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
+scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
([3, 4, 5],), (True, ), ({"a": True}, ), ])
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
@@ -381,7 +381,7 @@ class TestPDT(JitTestCase):
make_global(UserDefinedClass, test_model)
user_class = UserDefinedClass()
-scripted_fn = torch.jit._script_pdt(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
+scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
@@ -403,7 +403,7 @@ class TestPDT(JitTestCase):
make_global(ClassWithArgs, test_model_with_args)
user_class = ClassWithArgs(False)
-scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
+scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
def test_nn_parameter_as_arg(self):
@@ -420,7 +420,7 @@ class TestPDT(JitTestCase):
make_global(TestNNParameter)
pdt_model = TestNNParameter()
-scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [(10, ), ], })
+scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
self.assertEqual(scripted_fn(20), pdt_model(20))
def test_fx_tracing_with_typing(self):
@@ -434,7 +434,7 @@ class TestPDT(JitTestCase):
make_global(FXModel, FXModelOutput)
pdt_model = FXModel()
-scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
+scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
self.assertEqual(scripted_fn([20]), pdt_model([20]))
def test_nonetype_as_optional_of_type(self):
@@ -446,11 +446,11 @@ class TestPDT(JitTestCase):
make_global(test_none)
-scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10.6, )])
+scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
-scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10, )])
+scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
self.assertEqual(scripted_fn(2, ), test_none(2, ))
-scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
+scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
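
The tests above exercise both accepted shapes of `example_inputs`: a `List[Tuple]` of argument tuples for a free function, and a `Dict[Callable, List[Tuple]]` keyed by a module or method for an `nn.Module`. A condensed sketch of the module form, using a hypothetical toy module in the same pattern as the tests above (again assuming MonkeyType is installed):

    import torch
    from typing import Any, List, Tuple

    class Doubler(torch.nn.Module):  # hypothetical module, for illustration only
        def forward(self, x):
            return x + x

    model = Doubler()
    # Each tuple is one eager call to forward; here x is traced as int
    inp: List[Tuple[Any, ...]] = [(20, ), (50, ), ]
    scripted = torch.jit.script(model, example_inputs={model: inp})
    assert scripted(5) == model(5)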

[changed file 2 of 3]

@@ -20,7 +20,6 @@ from torch._jit_internal import (
)
from torch.jit._script import (
    script,
-    _script_pdt,
    Attribute,
    ScriptModule,
    script_method,

[changed file 3 of 3]

@@ -984,57 +984,6 @@ def call_prepare_scriptable_func(obj):
    memo: Dict[int, torch.nn.Module] = {}
    return call_prepare_scriptable_func_impl(obj, memo)
-def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None,
-                example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
-    # This is a private API, intended for internal use only. Usage of this API
-    # is for experimental purposes only and is highly discouraged.
-    global type_trace_db
-    if not _enabled:
-        return obj
-    if optimize is not None:
-        warnings.warn(
-            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution()` instead"
-        )
-    # No-op for modules and functions that are already scripted
-    if isinstance(obj, ScriptModule):
-        return obj
-    if isinstance(obj, ScriptFunction):
-        return obj
-    if example_inputs:
-        # If MonkeyType is installed, enable profile directed type annotation
-        # Check if example_inputs are defined and generate call traces
-        # for the method by running eager mode version of the method with
-        # the provided example inputs. This logs all the traces in type_trace_db
-        type_trace_db = JitTypeTraceStore()
-        if monkeytype_trace:
-            monkeytype_config = JitTypeTraceConfig(type_trace_db)
-            with monkeytype_trace(monkeytype_config):
-                if isinstance(example_inputs, Dict):
-                    # If the obj is an nn.Module or a class, then each method is
-                    # executed with the arguments provided in the example inputs.
-                    # example inputs here will be of type Dict(class.method, (arguments))
-                    # This is used to infer type annotations for those methods
-                    # which are not called directly under the hood of monkeytype.
-                    for module, example_input in example_inputs.items():
-                        for example in example_input:
-                            module(*example)
-                elif isinstance(example_inputs, List):
-                    for examples in example_inputs:
-                        obj(*examples)
-                else:
-                    warnings.warn("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
-                                  " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
-        else:
-            warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
-                          "to enable Profile-Directed Typing in TorchScript. Refer to "
-                          "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
-    return script(obj, optimize, _frames_up, _rcb)
def create_script_dict(obj):
"""
Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
@@ -1065,7 +1014,8 @@ def create_script_list(obj, type_hint=None):
    return torch._C.ScriptList(obj)  # type: ignore[attr-defined]
-def script(obj, optimize=None, _frames_up=0, _rcb=None):
+def script(obj, optimize=None, _frames_up=0, _rcb=None,
+           example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
r"""
Scripting a function or ``nn.Module`` will inspect the source code, compile
it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
@@ -1083,6 +1033,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
    Args:
        obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type,
            dictionary, or list to compile.
+        example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
+            to annotate the arguments for a function or ``nn.Module``.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
@@ -1124,6 +1076,34 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
...
+    **Scripting a function using example_inputs**
+        Example inputs can be used to annotate a function's arguments.
+
+        Example (annotating a function before scripting):
+
+        .. testcode::
+
+            import torch
+
+            def test_sum(a, b):
+                return a + b
+
+            # Annotate the arguments to be int
+            scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])
+
+            print(type(scripted_fn))  # torch.jit.ScriptFunction
+
+            # See the compiled graph as Python code
+            print(scripted_fn.code)
+
+            # Call the function using the TorchScript interpreter
+            scripted_fn(20, 100)
+
+        .. testoutput::
+            :hide:
+
+            ...
    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
@@ -1210,7 +1190,30 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
        scripted_module = torch.jit.script(MyModule())
        print(scripted_module.some_entry_point(torch.randn(2, 2)))
        print(scripted_module(torch.randn(2, 2)))
+    Example (Annotating forward of nn.Module using example_inputs)::
+
+        import torch
+        import torch.nn as nn
+        from typing import List, NamedTuple
+
+        class MyModule(NamedTuple):
+            result: List[int]
+
+        class TestNNModule(torch.nn.Module):
+            def forward(self, a) -> MyModule:
+                result = MyModule(result=a)
+                return result
+
+        pdt_model = TestNNModule()
+
+        # Runs pdt_model in eager mode with the inputs provided and annotates the arguments of forward
+        scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
+
+        # Run the scripted_model with actual inputs
+        print(scripted_model([20]))
"""
global type_trace_db
if not _enabled:
return obj
@@ -1227,6 +1230,35 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
    if isinstance(obj, ScriptFunction):
        return obj
+    if example_inputs:
+        # If MonkeyType is installed, enable profile directed type annotation
+        # Check if example_inputs are defined and generate call traces
+        # for the method by running eager mode version of the method with
+        # the provided example inputs. This logs all the traces in type_trace_db
+        type_trace_db = JitTypeTraceStore()
+        if monkeytype_trace:
+            monkeytype_config = JitTypeTraceConfig(type_trace_db)
+            with monkeytype_trace(monkeytype_config):
+                if isinstance(example_inputs, Dict):
+                    # If the obj is an nn.Module or a class, then each method is
+                    # executed with the arguments provided in the example inputs.
+                    # example inputs here will be of type Dict(class.method, (arguments))
+                    # This is used to infer type annotations for those methods
+                    # which are not called directly under the hood of monkeytype.
+                    for module, example_input in example_inputs.items():
+                        for example in example_input:
+                            module(*example)
+                elif isinstance(example_inputs, List):
+                    for examples in example_inputs:
+                        obj(*examples)
+                else:
+                    raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
+                                     " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
+        else:
+            warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
+                          "to enable Profile-Directed Typing in TorchScript. Refer to "
+                          "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(