mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR adds an internal wrapper around the [beartype](https://github.com/beartype/beartype) library to perform runtime type checking in `torch.onnx`. The wrapper uses beartype when it is found in the environment and reduces to a no-op when it is not.

Setting the env var `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS` turns on the feature, and setting `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED` disables all checks.

Now when users call an API with invalid arguments, e.g.

```python
torch.onnx.export(conv, y, path, export_params=True, training=False)  # training should take TrainingMode, not bool
```

they get

```
Traceback (most recent call last):
  File "bisect_m1_error.py", line 63, in <module>
    main()
  File "bisect_m1_error.py", line 59, in main
    reveal_error()
  File "bisect_m1_error.py", line 32, in reveal_error
    torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False)
  File "<@beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export
  File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls( # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">.
```

When `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK` is not set and `beartype` is installed, a warning message is emitted instead:

```
>>> torch.onnx.export("foo", "bar", "f")
<stdin>:1: CallHintViolationWarning: Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings
    return beartyped(*args, **kwargs)
  File "<@beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls( # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings
    return func(*args, **kwargs)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export
    _export(
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export
    with exporter_context(model, training, verbose):
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context
    with select_model_mode_for_export(
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export
    originally_training = model.training
AttributeError: 'str' object has no attribute 'training'
```

The type mismatch is now reported at the call site, rather than only surfacing later as `AttributeError: 'str' object has no attribute 'training'`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83673
Approved by: https://github.com/BowenBao
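As a rough illustration of the switch described above, the mode selection can be modeled as a small shell function. This is a hypothetical sketch, not the actual implementation, which lives in Python (`torch/onnx/_internal/_beartype.py`, per the tracebacks). The function name `type_check_mode`, its `yes`/`no` flag for whether beartype is installed, and treating unrecognized values like the unset case are all assumptions. The test script that follows exports `ERRORS`, which corresponds to the "raise" branch.

```bash
# Hypothetical model of the runtime type-check switch (illustration only).
# $1: value of TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK (may be empty)
# $2: "yes" if beartype is installed, anything else otherwise
type_check_mode() {
  local env_value="$1"
  local beartype_installed="$2"

  # Without beartype, the wrapper reduces to a no-op regardless of the env var.
  if [[ "$beartype_installed" != "yes" ]]; then
    echo "no-op"
    return
  fi

  case "$env_value" in
    ERRORS)   echo "raise" ;;  # violations raise beartype exceptions
    DISABLED) echo "no-op" ;;  # all checks are skipped
    *)        echo "warn" ;;   # unset: warn (other values: assumed to warn too)
  esac
}

type_check_mode ERRORS yes    # prints: raise
type_check_mode DISABLED yes  # prints: no-op
type_check_mode "" yes        # prints: warn
type_check_mode ERRORS no     # prints: no-op
```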
96 lines
2.9 KiB
Bash
Executable File
#!/bin/bash

set -ex

UNKNOWN=()

# defaults
PARALLEL=1
export TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS

while [[ $# -gt 0 ]]
do
  arg="$1"
  case $arg in
    -p|--parallel)
      PARALLEL=1
      shift # past argument
      ;;
    *) # unknown option
      UNKNOWN+=("$1") # save it in an array for later
      shift # past argument
      ;;
  esac
done
set -- "${UNKNOWN[@]}" # keep only the unrecognized arguments as positional parameters

if [[ $PARALLEL == 1 ]]; then
  pip install pytest-xdist
fi

# pytest, scipy, hypothesis: these may not be necessary
# pytest-cov: installing since `coverage run -m pytest ..` doesn't work
# parameterized: parameterizing test class
pip install pytest scipy hypothesis pytest-cov parameterized
pip install -e tools/coverage_plugins_package # allows coverage to run w/o failing due to a missing plug-in

# realpath might not be available on macOS
script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}")
top_dir=$(dirname "$(dirname "$(dirname "$script_path")")")
test_paths=(
  "$top_dir/test/onnx"
)

args=()
args+=("-v")
args+=("--cov")
args+=("--cov-report")
args+=("xml:test/coverage.xml")
args+=("--cov-append")

args_parallel=()
if [[ $PARALLEL == 1 ]]; then
  args_parallel+=("-n")
  args_parallel+=("auto")
fi

# onnxruntime only supports py3
# "Python.h" not found in py2, needed by TorchScript custom op compilation.
if [[ "${SHARD_NUMBER}" == "1" ]]; then
  # These exclusions are for tests that take a long time / a lot of GPU
  # memory to run; they should be passing (and you will test them if you
  # run them locally).
  pytest "${args[@]}" "${args_parallel[@]}" \
    --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
    --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
    --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
    --ignore "$top_dir/test/onnx/test_custom_ops.py" \
    --ignore "$top_dir/test/onnx/test_utility_funs.py" \
    --ignore "$top_dir/test/onnx/test_models.py" \
    --ignore "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
    "${test_paths[@]}"

  # Heavy memory usage tests that cannot run in parallel.
  pytest "${args[@]}" \
    "$top_dir/test/onnx/test_custom_ops.py" \
    "$top_dir/test/onnx/test_utility_funs.py" \
    "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "not TestModelsONNXRuntime"
fi

if [[ "${SHARD_NUMBER}" == "2" ]]; then
  # Heavy memory usage tests that cannot run in parallel.
  # TODO(#79802): Parameterize test_models.py
  pytest "${args[@]}" \
    "$top_dir/test/onnx/test_models.py" \
    "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \
    "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "TestModelsONNXRuntime"

  pytest "${args[@]}" "${args_parallel[@]}" \
    "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py"
fi

# Our CI expects both coverage.xml and .coverage to be within test/
if [ -f .coverage ]; then
  mv .coverage test/.coverage
fi
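The option loop near the top of the script consumes the flags it recognizes and collects everything else in `UNKNOWN`; `set -- "${UNKNOWN[@]}"` then makes those leftovers the new positional parameters. Note that `PARALLEL` already defaults to 1 in the script, so `-p`/`--parallel` does not change its behavior as written. Below is a self-contained sketch of the same pattern, wrapped in a hypothetical `parse_args` function and defaulting to `PARALLEL=0` (an assumption) so that the flag has a visible effect:

```bash
# Hypothetical helper mirroring the script's option loop.
# PARALLEL defaults to 0 here (an assumption) so that -p is observable.
parse_args() {
  local PARALLEL=0
  local UNKNOWN=()
  while [[ $# -gt 0 ]]; do
    case "$1" in
      -p|--parallel)
        PARALLEL=1
        shift # past argument
        ;;
      *) # unknown option: keep it for later
        UNKNOWN+=("$1")
        shift # past argument
        ;;
    esac
  done
  set -- "${UNKNOWN[@]}" # leftovers become $1, $2, ...
  echo "PARALLEL=$PARALLEL remaining=$# first=${1:-}"
}

parse_args -p -x "a b"  # prints: PARALLEL=1 remaining=2 first=-x
parse_args foo          # prints: PARALLEL=0 remaining=1 first=foo
```

Quoting the expansion as `"${UNKNOWN[@]}"` keeps an argument containing spaces, such as `"a b"` above, as a single parameter, which is why the first call reports two remaining arguments rather than three.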
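The last block relocates coverage.py's data file, `.coverage`, into `test/`. That data file is a regular file rather than a directory, so the existence check needs a file test such as `-f` (or `-e`); `-d` never matches it. A quick demonstration in a throwaway directory, using an empty placeholder in place of real coverage data:

```bash
# Scratch directory so nothing in the current tree is touched.
demo_dir=$(mktemp -d)
mkdir -p "$demo_dir/test"
touch "$demo_dir/.coverage" # empty placeholder, not real coverage data

if [ -d "$demo_dir/.coverage" ]; then dir_test="yes"; else dir_test="no"; fi
if [ -f "$demo_dir/.coverage" ]; then file_test="yes"; else file_test="no"; fi

# Same move as the script, guarded by the file test.
if [ -f "$demo_dir/.coverage" ]; then
  mv "$demo_dir/.coverage" "$demo_dir/test/.coverage"
fi
if [ -f "$demo_dir/test/.coverage" ]; then moved="yes"; else moved="no"; fi

echo "dir test (-d): $dir_test; file test (-f): $file_test; moved: $moved"
rm -rf "$demo_dir" # clean up the scratch directory
```

Run as-is, it prints `dir test (-d): no; file test (-f): yes; moved: yes`.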