mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Merge script and _script_pdt API (#62420)
Summary: Merge `torch.jit.script` and `torch.jit._script_pdt` API. This PR merges profile directed typing with script api Pull Request resolved: https://github.com/pytorch/pytorch/pull/62420 Reviewed By: iramazanli Differential Revision: D30579015 Pulled By: nikithamalgifb fbshipit-source-id: 99ba6839d235d61b2dd0144b466b2063a53ccece
This commit is contained in:
parent
0e8c3c51d9
commit
510d2ece81
|
|
@ -40,7 +40,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(TestPDTModel)
|
make_global(TestPDTModel)
|
||||||
pdt_model = TestPDTModel()
|
pdt_model = TestPDTModel()
|
||||||
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
|
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
|
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
|
||||||
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
|
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
|
||||||
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
|
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
|
||||||
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
|
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
|
||||||
|
|
@ -67,7 +67,7 @@ class TestPDT(JitTestCase):
|
||||||
inner_pdt_model = NestedPDTInner()
|
inner_pdt_model = NestedPDTInner()
|
||||||
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
|
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
|
||||||
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
|
scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
|
||||||
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
|
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
|
||||||
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
|
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
|
||||||
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
|
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
|
||||||
|
|
@ -95,7 +95,7 @@ class TestPDT(JitTestCase):
|
||||||
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
|
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
|
||||||
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
|
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
|
||||||
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
|
scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
|
||||||
outer_pdt_model: outer_input, })
|
outer_pdt_model: outer_input, })
|
||||||
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
|
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
|
||||||
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
|
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
|
||||||
|
|
@ -119,7 +119,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(NestedFunctionInForward)
|
make_global(NestedFunctionInForward)
|
||||||
pdt_model = NestedFunctionInForward()
|
pdt_model = NestedFunctionInForward()
|
||||||
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
|
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
|
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
|
||||||
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
|
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
|
||||||
self.assertEqual(scripted_pdt_model(True), pdt_model(True))
|
self.assertEqual(scripted_pdt_model(True), pdt_model(True))
|
||||||
|
|
||||||
|
|
@ -142,7 +142,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(TestModelWithExport)
|
make_global(TestModelWithExport)
|
||||||
pdt_model = TestModelWithExport()
|
pdt_model = TestModelWithExport()
|
||||||
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
|
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp})
|
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
|
||||||
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
|
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
|
||||||
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
|
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
|
||||||
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
|
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
|
||||||
|
|
@ -155,7 +155,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(PDTModel)
|
make_global(PDTModel)
|
||||||
pdt_model = PDTModel()
|
pdt_model = PDTModel()
|
||||||
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
|
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp})
|
scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
|
||||||
script_model = scripted_pdt_model()
|
script_model = scripted_pdt_model()
|
||||||
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
|
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
|
||||||
|
|
||||||
|
|
@ -174,7 +174,7 @@ class TestPDT(JitTestCase):
|
||||||
pdt_model = PDTModelWithManyMethods()
|
pdt_model = PDTModelWithManyMethods()
|
||||||
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
|
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
|
||||||
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
|
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
|
||||||
scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
|
scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
|
||||||
pdt_model.test_substring: str_inp})
|
pdt_model.test_substring: str_inp})
|
||||||
script_model = scripted_pdt_model()
|
script_model = scripted_pdt_model()
|
||||||
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
|
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
|
||||||
|
|
@ -195,8 +195,8 @@ class TestPDT(JitTestCase):
|
||||||
pdt_model_two = PDTModelTwo()
|
pdt_model_two = PDTModelTwo()
|
||||||
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
|
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
|
||||||
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
|
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
|
||||||
scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
|
scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
|
||||||
scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
|
scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
|
||||||
|
|
||||||
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
|
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
|
||||||
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
|
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
|
||||||
|
|
@ -209,28 +209,28 @@ class TestPDT(JitTestCase):
|
||||||
return a + b
|
return a + b
|
||||||
|
|
||||||
make_global(test_sum)
|
make_global(test_sum)
|
||||||
scripted_fn_add = torch.jit._script_pdt(test_sum, example_inputs=[(3, 4)])
|
scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
|
||||||
self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
|
self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
|
||||||
|
|
||||||
def test_sub(a, b):
|
def test_sub(a, b):
|
||||||
return a - b
|
return a - b
|
||||||
|
|
||||||
make_global(test_sub)
|
make_global(test_sub)
|
||||||
scripted_fn_sub = torch.jit._script_pdt(test_sub, example_inputs=[(3.9, 4.10)])
|
scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
|
||||||
self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
|
self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
|
||||||
|
|
||||||
def test_mul(a, b):
|
def test_mul(a, b):
|
||||||
return a * b
|
return a * b
|
||||||
|
|
||||||
make_global(test_mul)
|
make_global(test_mul)
|
||||||
scripted_fn_mul = torch.jit._script_pdt(test_mul, example_inputs=[(-10, 9)])
|
scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
|
||||||
self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
|
self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
|
||||||
|
|
||||||
def test_args_complex(real, img):
|
def test_args_complex(real, img):
|
||||||
return torch.complex(real, img)
|
return torch.complex(real, img)
|
||||||
|
|
||||||
make_global(test_args_complex)
|
make_global(test_args_complex)
|
||||||
scripted_fn_complex = torch.jit._script_pdt(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
|
scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
|
||||||
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
|
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
|
||||||
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
|
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
|
||||||
|
|
||||||
|
|
@ -241,7 +241,7 @@ class TestPDT(JitTestCase):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
make_global(test_bool)
|
make_global(test_bool)
|
||||||
scripted_fn_bool = torch.jit._script_pdt(test_bool, example_inputs=[(True,)])
|
scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
|
||||||
self.assertEqual(scripted_fn_bool(True), test_bool(True))
|
self.assertEqual(scripted_fn_bool(True), test_bool(True))
|
||||||
|
|
||||||
def test_str(a):
|
def test_str(a):
|
||||||
|
|
@ -251,7 +251,7 @@ class TestPDT(JitTestCase):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
make_global(test_str)
|
make_global(test_str)
|
||||||
scripted_fn_str = torch.jit._script_pdt(test_str, example_inputs=[("",)])
|
scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
|
||||||
self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
|
self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
|
||||||
|
|
||||||
def test_pdt_list_and_tuple(self):
|
def test_pdt_list_and_tuple(self):
|
||||||
|
|
@ -260,24 +260,24 @@ class TestPDT(JitTestCase):
|
||||||
|
|
||||||
make_global(test_list_and_tuple)
|
make_global(test_list_and_tuple)
|
||||||
|
|
||||||
scripted_fn_float_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
|
scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
|
||||||
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
|
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
|
||||||
|
|
||||||
scripted_fn_bool_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([True, False, True],)])
|
scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
|
||||||
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
|
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
|
||||||
|
|
||||||
scripted_fn_int_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
|
scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
|
||||||
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
|
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
|
||||||
|
|
||||||
scripted_fn_float_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
|
scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
|
||||||
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
|
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
|
||||||
|
|
||||||
scripted_fn_bool_tuple_input = torch.jit._script_pdt(test_list_and_tuple,
|
scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
|
||||||
example_inputs=[((True, False, True),)])
|
example_inputs=[((True, False, True),)])
|
||||||
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
|
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
|
||||||
test_list_and_tuple((True, True, True)))
|
test_list_and_tuple((True, True, True)))
|
||||||
|
|
||||||
scripted_fn_int_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
|
scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
|
||||||
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
|
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
|
||||||
|
|
||||||
def test_nested_list_and_tuple(self):
|
def test_nested_list_and_tuple(self):
|
||||||
|
|
@ -295,22 +295,22 @@ class TestPDT(JitTestCase):
|
||||||
make_global(test_nested_list, test_nested_tuple)
|
make_global(test_nested_list, test_nested_tuple)
|
||||||
|
|
||||||
list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
|
list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
|
||||||
scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
|
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
|
||||||
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
|
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
|
||||||
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
||||||
|
|
||||||
list_inp = ([1, 2, 3, ], [5, 6, 7, ])
|
list_inp = ([1, 2, 3, ], [5, 6, 7, ])
|
||||||
scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
|
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
|
||||||
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
|
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
|
||||||
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
||||||
|
|
||||||
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
|
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
|
||||||
scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
||||||
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
|
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
|
||||||
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
||||||
|
|
||||||
tup_inp = ((True, False, True, ), (False, False, False, ))
|
tup_inp = ((True, False, True, ), (False, False, False, ))
|
||||||
scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
||||||
inp = ((True, True, True, ), (False, False, True, ))
|
inp = ((True, True, True, ), (False, False, True, ))
|
||||||
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
||||||
|
|
||||||
|
|
@ -324,11 +324,11 @@ class TestPDT(JitTestCase):
|
||||||
make_global(test_dict, test_dict_int_list)
|
make_global(test_dict, test_dict_int_list)
|
||||||
|
|
||||||
str_bool_inp = {'foo' : True, 'bar': False}
|
str_bool_inp = {'foo' : True, 'bar': False}
|
||||||
scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)])
|
scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
|
||||||
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
|
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
|
||||||
|
|
||||||
str_list_inp = {0 : [True, False], 1: [False, True]}
|
str_list_inp = {0 : [True, False], 1: [False, True]}
|
||||||
scripted_fn = torch.jit._script_pdt(test_dict_int_list, example_inputs=[(str_list_inp,)])
|
scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
|
||||||
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
|
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
|
||||||
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
|
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
|
||||||
|
|
||||||
|
|
@ -349,13 +349,13 @@ class TestPDT(JitTestCase):
|
||||||
|
|
||||||
make_global(test_multiple_types, test_multiple_type_refinement)
|
make_global(test_multiple_types, test_multiple_type_refinement)
|
||||||
|
|
||||||
scripted_fn = torch.jit._script_pdt(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
|
scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
|
||||||
self.assertEqual(scripted_fn(10), test_multiple_types(10))
|
self.assertEqual(scripted_fn(10), test_multiple_types(10))
|
||||||
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
|
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
|
||||||
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
|
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
|
||||||
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
|
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
|
||||||
|
|
||||||
scripted_fn = torch.jit._script_pdt(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
|
scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
|
||||||
([3, 4, 5],), (True, ), ({"a": True}, ), ])
|
([3, 4, 5],), (True, ), ({"a": True}, ), ])
|
||||||
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
|
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
|
||||||
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
|
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
|
||||||
|
|
@ -381,7 +381,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(UserDefinedClass, test_model)
|
make_global(UserDefinedClass, test_model)
|
||||||
|
|
||||||
user_class = UserDefinedClass()
|
user_class = UserDefinedClass()
|
||||||
scripted_fn = torch.jit._script_pdt(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
||||||
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
|
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
|
||||||
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
|
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
|
||||||
|
|
||||||
|
|
@ -403,7 +403,7 @@ class TestPDT(JitTestCase):
|
||||||
make_global(ClassWithArgs, test_model_with_args)
|
make_global(ClassWithArgs, test_model_with_args)
|
||||||
|
|
||||||
user_class = ClassWithArgs(False)
|
user_class = ClassWithArgs(False)
|
||||||
scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
||||||
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
|
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
|
||||||
|
|
||||||
def test_nn_parameter_as_arg(self):
|
def test_nn_parameter_as_arg(self):
|
||||||
|
|
@ -420,7 +420,7 @@ class TestPDT(JitTestCase):
|
||||||
|
|
||||||
make_global(TestNNParameter)
|
make_global(TestNNParameter)
|
||||||
pdt_model = TestNNParameter()
|
pdt_model = TestNNParameter()
|
||||||
scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [(10, ), ], })
|
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
|
||||||
self.assertEqual(scripted_fn(20), pdt_model(20))
|
self.assertEqual(scripted_fn(20), pdt_model(20))
|
||||||
|
|
||||||
def test_fx_tracing_with_typing(self):
|
def test_fx_tracing_with_typing(self):
|
||||||
|
|
@ -434,7 +434,7 @@ class TestPDT(JitTestCase):
|
||||||
|
|
||||||
make_global(FXModel, FXModelOutput)
|
make_global(FXModel, FXModelOutput)
|
||||||
pdt_model = FXModel()
|
pdt_model = FXModel()
|
||||||
scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
|
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
|
||||||
self.assertEqual(scripted_fn([20]), pdt_model([20]))
|
self.assertEqual(scripted_fn([20]), pdt_model([20]))
|
||||||
|
|
||||||
def test_nonetype_as_optional_of_type(self):
|
def test_nonetype_as_optional_of_type(self):
|
||||||
|
|
@ -446,11 +446,11 @@ class TestPDT(JitTestCase):
|
||||||
|
|
||||||
make_global(test_none)
|
make_global(test_none)
|
||||||
|
|
||||||
scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10.6, )])
|
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
|
||||||
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
|
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
|
||||||
|
|
||||||
scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10, )])
|
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
|
||||||
self.assertEqual(scripted_fn(2, ), test_none(2, ))
|
self.assertEqual(scripted_fn(2, ), test_none(2, ))
|
||||||
|
|
||||||
scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
|
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
|
||||||
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
|
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ from torch._jit_internal import (
|
||||||
)
|
)
|
||||||
from torch.jit._script import (
|
from torch.jit._script import (
|
||||||
script,
|
script,
|
||||||
_script_pdt,
|
|
||||||
Attribute,
|
Attribute,
|
||||||
ScriptModule,
|
ScriptModule,
|
||||||
script_method,
|
script_method,
|
||||||
|
|
|
||||||
|
|
@ -984,57 +984,6 @@ def call_prepare_scriptable_func(obj):
|
||||||
memo: Dict[int, torch.nn.Module] = {}
|
memo: Dict[int, torch.nn.Module] = {}
|
||||||
return call_prepare_scriptable_func_impl(obj, memo)
|
return call_prepare_scriptable_func_impl(obj, memo)
|
||||||
|
|
||||||
|
|
||||||
def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None,
|
|
||||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
|
|
||||||
# This is a private API, intended for internal use only. Usage of this API is only for experimental
|
|
||||||
# purposes only and is highly discouraged.
|
|
||||||
global type_trace_db
|
|
||||||
if not _enabled:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
if optimize is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
|
|
||||||
)
|
|
||||||
|
|
||||||
# No-op for modules and functions that are already scripted
|
|
||||||
if isinstance(obj, ScriptModule):
|
|
||||||
return obj
|
|
||||||
if isinstance(obj, ScriptFunction):
|
|
||||||
return obj
|
|
||||||
|
|
||||||
if example_inputs:
|
|
||||||
# If MonkeyType is installed, enable profile directed type annotation
|
|
||||||
# Check if example_inputs are defined and generate call traces
|
|
||||||
# for the method by running eager mode version of the method with
|
|
||||||
# the provide example inputs. This logs all the traces in type_trace_db
|
|
||||||
type_trace_db = JitTypeTraceStore()
|
|
||||||
if monkeytype_trace:
|
|
||||||
monkeytype_config = JitTypeTraceConfig(type_trace_db)
|
|
||||||
with monkeytype_trace(monkeytype_config):
|
|
||||||
if isinstance(example_inputs, Dict):
|
|
||||||
# If the obj is an nn.Module or a class, then each method is
|
|
||||||
# executed with the arguments provided in the example inputs.
|
|
||||||
# example inputs here will be of type Dict(class.method, (arguments))
|
|
||||||
# This is used to infer type annotations for those methods
|
|
||||||
# which are not called directly under the hood of monkeytype.
|
|
||||||
for module, example_input in example_inputs.items():
|
|
||||||
for example in example_input:
|
|
||||||
module(*example)
|
|
||||||
elif isinstance(example_inputs, List):
|
|
||||||
for examples in example_inputs:
|
|
||||||
obj(*examples)
|
|
||||||
else:
|
|
||||||
warnings.warn("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
|
|
||||||
" or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
|
|
||||||
else:
|
|
||||||
warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
|
|
||||||
"to enable Profile-Directed Typing in TorchScript. Refer to "
|
|
||||||
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
|
|
||||||
return script(obj, optimize, _frames_up, _rcb)
|
|
||||||
|
|
||||||
|
|
||||||
def create_script_dict(obj):
|
def create_script_dict(obj):
|
||||||
"""
|
"""
|
||||||
Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
|
Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
|
||||||
|
|
@ -1065,7 +1014,8 @@ def create_script_list(obj, type_hint=None):
|
||||||
return torch._C.ScriptList(obj) # type: ignore[attr-defined]
|
return torch._C.ScriptList(obj) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
def script(obj, optimize=None, _frames_up=0, _rcb=None,
|
||||||
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
|
||||||
r"""
|
r"""
|
||||||
Scripting a function or ``nn.Module`` will inspect the source code, compile
|
Scripting a function or ``nn.Module`` will inspect the source code, compile
|
||||||
it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
|
it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
|
||||||
|
|
@ -1083,6 +1033,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
||||||
Args:
|
Args:
|
||||||
obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type,
|
obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type,
|
||||||
dictionary, or list to compile.
|
dictionary, or list to compile.
|
||||||
|
example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
|
||||||
|
to annotate the arguments for a function or ``nn.Module``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If ``obj`` is ``nn.Module``, ``script`` returns
|
If ``obj`` is ``nn.Module``, ``script`` returns
|
||||||
|
|
@ -1124,6 +1076,34 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
****Scripting a function using example_inputs**
|
||||||
|
Example inputs can be used to annotate a function arguments.
|
||||||
|
|
||||||
|
Example (annotating a function before scripting):
|
||||||
|
|
||||||
|
.. testcode::
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def test_sum(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
# Annotate the arguments to be int
|
||||||
|
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])
|
||||||
|
|
||||||
|
print(type(scripted_fn)) # torch.jit.ScriptFunction
|
||||||
|
|
||||||
|
# See the compiled graph as Python code
|
||||||
|
print(scripted_fn.code)
|
||||||
|
|
||||||
|
# Call the function using the TorchScript interpreter
|
||||||
|
scripted_fn(20, 100)
|
||||||
|
|
||||||
|
.. testoutput::
|
||||||
|
:hide:
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
**Scripting an nn.Module**
|
**Scripting an nn.Module**
|
||||||
Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
|
Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
|
||||||
compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
|
compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
|
||||||
|
|
@ -1210,7 +1190,30 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
||||||
scripted_module = torch.jit.script(MyModule())
|
scripted_module = torch.jit.script(MyModule())
|
||||||
print(scripted_module.some_entry_point(torch.randn(2, 2)))
|
print(scripted_module.some_entry_point(torch.randn(2, 2)))
|
||||||
print(scripted_module(torch.randn(2, 2)))
|
print(scripted_module(torch.randn(2, 2)))
|
||||||
|
|
||||||
|
Example ( Annotating forward of nn.Module using example_inputs)::
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
class MyModule(NamedTuple):
|
||||||
|
result: List[int]
|
||||||
|
|
||||||
|
class TestNNModule(torch.nn.Module):
|
||||||
|
def forward(self, a) -> MyModule:
|
||||||
|
result = MyModule(result=a)
|
||||||
|
return result
|
||||||
|
|
||||||
|
pdt_model = TestNNModule()
|
||||||
|
|
||||||
|
# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
|
||||||
|
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
|
||||||
|
|
||||||
|
# Run the scripted_model with actual inputs
|
||||||
|
print(scripted_model([20]))
|
||||||
"""
|
"""
|
||||||
|
global type_trace_db
|
||||||
if not _enabled:
|
if not _enabled:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -1227,6 +1230,35 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
||||||
if isinstance(obj, ScriptFunction):
|
if isinstance(obj, ScriptFunction):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
if example_inputs:
|
||||||
|
# If MonkeyType is installed, enable profile directed type annotation
|
||||||
|
# Check if example_inputs are defined and generate call traces
|
||||||
|
# for the method by running eager mode version of the method with
|
||||||
|
# the provide example inputs. This logs all the traces in type_trace_db
|
||||||
|
type_trace_db = JitTypeTraceStore()
|
||||||
|
if monkeytype_trace:
|
||||||
|
monkeytype_config = JitTypeTraceConfig(type_trace_db)
|
||||||
|
with monkeytype_trace(monkeytype_config):
|
||||||
|
if isinstance(example_inputs, Dict):
|
||||||
|
# If the obj is an nn.Module or a class, then each method is
|
||||||
|
# executed with the arguments provided in the example inputs.
|
||||||
|
# example inputs here will be of type Dict(class.method, (arguments))
|
||||||
|
# This is used to infer type annotations for those methods
|
||||||
|
# which are not called directly under the hood of monkeytype.
|
||||||
|
for module, example_input in example_inputs.items():
|
||||||
|
for example in example_input:
|
||||||
|
module(*example)
|
||||||
|
elif isinstance(example_inputs, List):
|
||||||
|
for examples in example_inputs:
|
||||||
|
obj(*examples)
|
||||||
|
else:
|
||||||
|
raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
|
||||||
|
" or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
|
||||||
|
else:
|
||||||
|
warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
|
||||||
|
"to enable Profile-Directed Typing in TorchScript. Refer to "
|
||||||
|
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
|
||||||
|
|
||||||
if isinstance(obj, torch.nn.Module):
|
if isinstance(obj, torch.nn.Module):
|
||||||
obj = call_prepare_scriptable_func(obj)
|
obj = call_prepare_scriptable_func(obj)
|
||||||
return torch.jit._recursive.create_script_module(
|
return torch.jit._recursive.create_script_module(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user