Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00.
This PR adds a parametrized test for cond. It tests that cond can be traced with valid inputs. Specifically, valid inputs are combinations of:

- pred (Python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)

What this test doesn't cover:

- pred: symbolic boolean expression as predicate
- true_fn/false_fn: functions that mutate intermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, Python constant closures
- nested_level: 3+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519
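The covered inputs form a Cartesian product of five axes. The following is a minimal sketch, not the PR's actual test code, that enumerates that matrix with `itertools.product`. The axis labels and the `test_matrix` helper are hypothetical names chosen for illustration. It shows how large the matrix is if every combination is exercised: 4 × 3 × 2 × 2 × 2 = 96 cases.

```python
import itertools

# Hypothetical labels for each axis of the test matrix (illustration only).
PRED_KINDS = ["python_bool", "bool_tensor", "int_tensor", "scalar_tensor"]
FN_KINDS = ["func", "obj", "nn_module"]
NUM_OPERANDS = [0, 2]   # tensor operands: tested with 0 and 2
NUM_CLOSURES = [0, 2]   # tensor closures: tested with 0 and 2
NESTED_LEVELS = ["none", "level_2"]  # no nesting, or a cond nested inside a cond


def test_matrix():
    """Hypothetical helper: yield one dict per combination of the five axes."""
    for pred, fn, ops, clos, nest in itertools.product(
        PRED_KINDS, FN_KINDS, NUM_OPERANDS, NUM_CLOSURES, NESTED_LEVELS
    ):
        yield {
            "pred": pred,
            "fn": fn,
            "num_operands": ops,
            "num_closures": clos,
            "nested_level": nest,
        }


cases = list(test_matrix())
print(len(cases))  # 4 * 3 * 2 * 2 * 2 = 96
```

A parametrized test framework would typically generate one test per entry in such a product; enumerating it explicitly makes the coverage (and the excluded cases listed above) easy to audit.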
| File |
|---|
| attn_ft.py |
| attn_positional.py |
| common_utils.py |
| discover_coverage.py |
| functorch_additional_op_db.py |
| test_aotdispatch.py |
| test_control_flow.py |
| test_dims.py |
| test_eager_transforms.py |
| test_logging.py |
| test_memory_efficient_fusion.py |
| test_minifier.py |
| test_ops.py |
| test_parsing.py |
| test_rearrange.py |
| test_vmap_registrations.py |
| test_vmap.py |
| xfail_suggester.py |