# Test op correctness by comparing with PyTorch results using OpInfo

`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
Read more about them at ce4a097bf7/torch/testing/_internal/opinfo/core.py (L362).
## Usage

```bash
# Run all tests
python -m pytest test_ops.py

# Run tests on a specific operator (e.g. torch.ceil):
python -m pytest test_ops.py -k ceil

# Run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
python -m pytest test_ops.py -k nn_functional_scaled_dot_product_attention
```
### Environment variables

- Set `CATCH_ORT_SEGFAULT=1` to catch segmentation faults in onnxruntime by running the inference sessions in a separate process.
- Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. `CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/torchlib/test_ops.py -k div_mode_int`
## How to add a new operator test

See usage in `ops_test_data.py`.
## How to add custom OpInfo tests

Sometimes no existing OpInfo fits the need for testing an operator, and you need to create a custom OpInfo for it. Follow the steps below to create new OpInfo tests:
1. Use the implementation of `ops.aten.slice_scatter` as a reference (e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L2412-L2418)) to declare an `OpInfo` in `extra_opinfo.py`:

   ```python
   opinfo_core.OpInfo(
       "ops.aten.slice_scatter",
       aten_name="slice_scatter",
       dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
       sample_inputs_func=sample_inputs_slice_scatter,
       supports_out=False,
   ),
   ```

   - The first argument should be the operator name under the `torch.ops` namespace. For example, to test the `prims.var` op, put `"ops.prims.var"`. It should almost always start with `ops.`.
   - Follow existing examples to specify the `dtypes` you want to test the op on.
   - Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example (e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L2065-L2068)):

     ```python
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p_deterministic",
         op=torch.ops.aten.bernoulli.p,
     ```

     Here the op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique within a test suite. When `op` is not specified, the OpInfo looks up the op in `torch.` using its name.
2. Implement the `sample_inputs_func` (ref: e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L1242-L1268)).

   - Copy the function and decide what the input shapes should be. Use `make_arg` to generate a `torch.Tensor`, or alternatively use `torch.tensor` to create the tensor yourself. Be sure to double-check the dtype and device.
   - Finally, yield each test case with `yield opinfo_core.SampleInput(input, args=(...), kwargs={...})`. `input` is the first arg; the rest of the args go in `args`.
3. Enable the test case in `ops_test_data.py` by adding a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list (for example, e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L2116)):

   ```python
   TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
   ```

   You can additionally specify dtype tolerances (e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L539)) or conditional skips (e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L586-L590)).
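The yield pattern in step 2 can be sketched without a PyTorch dependency. This is a minimal illustration, not real test data: `SimpleSampleInput` is a hypothetical stand-in for `opinfo_core.SampleInput`, and the shapes and slice bounds below are invented for the example.

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class SimpleSampleInput:
    # Hypothetical stand-in for opinfo_core.SampleInput: `input` is the first
    # positional argument to the op; the rest go in `args`/`kwargs`.
    input: Any
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

def sample_inputs_slice_scatter(make_arg):
    # `make_arg` is assumed to build a tensor of the given shape. Each yield
    # is one test case for slice_scatter(input, src, dim=..., start=..., ...).
    for input_shape, src_shape, dim in [
        ((8, 8), (2, 8), 0),  # scatter 2 rows along dim 0
        ((8, 8), (8, 3), 1),  # scatter 3 columns along dim 1
    ]:
        yield SimpleSampleInput(
            make_arg(input_shape),
            args=(make_arg(src_shape),),
            kwargs={"dim": dim, "start": 0, "end": src_shape[dim], "step": 1},
        )

# With a dummy make_arg that just returns the shape, the generator
# yields two test cases:
cases = list(sample_inputs_slice_scatter(lambda shape: shape))
```

The real `sample_inputs_func` follows the same shape: it receives helpers to construct tensors, loops over the shapes and op arguments you want covered, and yields one `SampleInput` per test case.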
Now that the test is added, you can run it as described above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view the failing input combinations should any test case fail.