pytorch/test/typing/pass/jit.py
Randolf Scholz 32f50b7021 Improve type annotations for jit.script (#108782)
Fixes #108781

- [x] added `@overload` for `jit.script`
- [x] added a typing test in `test/typing/pass/jit.py`
    - NOTE: this test is currently not type-checked by mypy automatically when running lintrunner. (How to fix?)
- [x] used `stubgen` to create [torch/jit/_script.pyi](https://github.com/pytorch/pytorch/pull/108782/files#diff-738e66abee2523a952b3ddbaecf95e187cce559473cf8c1b3da7c247ee5d1132) and added the overloads there (adding them inside `_script.py` itself interfered with the JIT engine).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108782
Approved by: https://github.com/ezyang
2023-09-13 19:20:25 +00:00

40 lines
894 B
Python

from enum import Enum
from typing import Type, TypeVar
import pytest
from torch import jit, nn, ScriptDict, ScriptFunction, ScriptList
from typing_extensions import assert_never, assert_type, ParamSpec
# Type variables for a generic callable signature.
# NOTE(review): neither P nor R is referenced in the code visible here;
# presumably they were meant for annotating a generic ScriptFunction[P, R]
# (parameter spec P, return type R) — confirm the intent or remove them.
P = ParamSpec("P")
R = TypeVar("R", covariant=True)  # covariant, presumably because R only appears as a return type
class Color(Enum):
    """Small integer-valued Enum used to check how ``jit.script`` types an Enum class."""

    RED = 1
    GREEN = 2
    BLUE = 3
# Script Enum: scripting an Enum class is typed as returning that same class.
assert_type(jit.script(Color), Type[Color])

# ScriptDict: scripting a dict literal is typed as ScriptDict.
assert_type(jit.script({1: 1}), ScriptDict)

# ScriptList: scripting a list literal is typed as ScriptList.
assert_type(jit.script([0]), ScriptList)

# ScriptModule: scripting an nn.Module *instance* is typed as RecursiveScriptModule.
scripted_module = jit.script(nn.Linear(2, 2))
assert_type(scripted_module, jit.RecursiveScriptModule)

# ScriptFunction
# NOTE: assert_type can't be used here because of parameter names (presumably
#   an exact-type assertion would have to reproduce the scripted function's
#   parameter names), so a plain annotated assignment is used instead.
# NOTE: generic usage (presumably ScriptFunction[...] with type arguments) is
#   only possible with Python 3.9+.
relu: ScriptFunction = jit.script(nn.functional.relu)

# An nn.Module *class* (not an instance) can't be scripted: the call raises
# RuntimeError at runtime. Statically, assert_never only type-checks if the
# jit.script overloads type this call as Never.
with pytest.raises(RuntimeError):
    assert_never(jit.script(nn.Linear))