# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from jit.test_module_interface import TestModuleInterface  # noqa: F401


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")


class TestMisc(JitTestCase):
    def test_joined_str(self):
        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = 'hi'
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4., requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass
            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, 'hello')

        self.assertEqual(sm(n_tokens=3, device_name='hello'), (3, 'hello'))
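
    # Tuples are immutable in TorchScript, so both subscripted assignment and
    # augmented assignment to a tuple element should be rejected at compile time.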
    def test_tuple_subscripted_assign(self):
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]
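
    # The test_subexpression_* tests below check that Future[int] annotations
    # nested inside container types survive into the compiled graph's IR types,
    # by running FileCheck over the printed graph.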
    def test_subexpression_List_Future(self):
        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check('Future[int]').check('Future[int]').run(fn.graph)

    def test_subexpression_Future_annotate(self):
        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        @torch.jit.script
        def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph)

    def test_subexpression_Dict_int_Future(self):
        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph)

    def test_subexpression_Optional(self):
        @torch.jit.script
        def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check('Dict(int, Future(int))?').run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different
        types early from each branch when the return
        type of the function is Any.
        """
        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))
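
    # Note: the `.hacked_twin` overloads exercised below appear to be
    # JIT-registered variants of the indexing ops that take `Tensor[]` indices
    # instead of `Tensor?[]`, sidestepping optional-tensor handling.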
    def test_hacked_twin(self):
        def gen_data():
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()
        out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value, accumulate=False)
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)
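
    # The `_unsafe_*` variants are, as the name suggests, versions of these ops
    # that skip some input validation; here we only check that they agree with
    # the regular ops and that they can be scripted.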
    def test_unsafe_hacked_twin(self):
        def gen_data():
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()
        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
        torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)

        input2, index2, value2 = gen_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(input2.clone(), index2, value2)
        actual = script_index_put_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        def index_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(input2.clone(), index2, value2)
        actual = script_index_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)
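
    # export_opnames collects the ops a mobile build needs. When a submodule is
    # typed as a module interface, different concrete implementations can be
    # swapped in, so the op list has to follow the actual implementation; the
    # test swaps `sub` from BarMod to FooMod and checks both op sets.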
    def test_export_opnames_interface(self):
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self):
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())
        self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
            set(torch.jit.export_opnames(scripted_M_mod))))

        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
            set(torch.jit.export_opnames(scripted_M_mod))))

    def test_math_inf(self):
        from math import inf

        def foo():
            return inf

        self.checkScript(foo, ())
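
    # An empty list literal has no element type of its own: TorchScript assumes
    # List[Tensor] by default (see test_return below), but a literal passed
    # directly as an argument is refined to the callee's parameter type. An
    # explicit annotation or a prior assignment pins the type instead, so the
    # two failing cases are expected.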
    def test_list_literal_infer(self):
        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[torch.Tensor]))

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        @torch.jit.script
        def test_return():
            return []

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager = empty_overload()
        jit = torch.jit.script(empty_overload)()
        eager[:] = 1
        jit[:] = 1
        self.assertEqual(eager, jit)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        self.assertTrue(torch.jit.script(sum_i)(4) == 8)
        self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.)

    def test_parse_ir_annotate(self):
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertTrue(ret == [])
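
    # The second argument to parse_ir enables parsing tensor constants, so the
    # inline `{0}` / `{-17}` payloads below round-trip into real single-element
    # tensors rather than being rejected.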
    def test_parse_ir_single_element_tensor_positive(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertTrue(ret.numel() == 1)
        self.assertTrue(len(ret.size()) == 1)

    def test_parse_ir_single_element_tensor_negative(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertTrue(ret.numel() == 1)
        self.assertTrue(len(ret.size()) == 1)
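
    # Scripting parses the function's source text, so the frontend has to cope
    # with (and strip) an arbitrary stack of decorators above the `def`.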
    def test_script_many_decorators(self):
        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1,)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p ** gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device='cuda')

        ref = fn(x)

        script_fn = torch.jit.script(fn)
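        # Run the scripted function several times, presumably so the profiling
        # executor reaches its optimized graph; the linked regression only
        # surfaced once optimization kicked in.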
        for i in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload, not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertFalse("complex" in str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        op, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [i for i, name in enumerate(override_names) if name == "complex"]
        Scalar_indices = [i for i, name in enumerate(override_names) if name == "Scalar"]

        self.assertTrue(len(complex_indices) > 0)
        self.assertTrue(len(Scalar_indices) > 0)
        self.assertTrue(complex_indices[0] > Scalar_indices[0])