Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
This PR adds a parametrized test for cond. It tests that cond can be traced with valid inputs. Specifically, a valid input is a combination of:
- pred (python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)

What this test doesn't cover:
- pred: symbolic boolean expression as predicate
- true_fn/false_fn: functions that mutate intermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, python constant closures
- nested_level: 3+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519
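
To make the size of the covered input space concrete, the parameter axes listed in the commit message can be enumerated as a cross product. This is only a rough sketch using the standard library: the label strings and the `cond_test_cases` helper are hypothetical names chosen for illustration, and it assumes the test covers every combination of the axes, which the commit message suggests but does not spell out. It is not the PR's actual test code.

```python
import itertools

# Hypothetical labels for each axis, taken from the commit message.
# The real test may name, group, or generate these differently.
PRED_KINDS = ["python_bool", "bool_tensor", "int_tensor", "scalar_tensor"]
FN_KINDS = ["func", "obj", "nn_module"]
NUM_OPERANDS = [0, 2]                # tensor operands passed to cond
NUM_CLOSURES = [0, 2]                # tensors captured by closure
NESTED_LEVELS = ["none", "level_2"]  # no nesting, or a level-2 nested cond


def cond_test_cases():
    """Return every combination of the axes above as a list of tuples.

    Assumption: the parametrized test takes the full cross product.
    """
    return list(
        itertools.product(
            PRED_KINDS, FN_KINDS, NUM_OPERANDS, NUM_CLOSURES, NESTED_LEVELS
        )
    )


if __name__ == "__main__":
    cases = cond_test_cases()
    # 4 pred kinds * 3 fn kinds * 2 operand counts * 2 closure counts * 2 nesting levels
    print(len(cases))
```

Under that assumption the grid has 4 * 3 * 2 * 2 * 2 = 96 cases, each of which would be handed to a single test body that builds the inputs and traces cond.
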
||
|---|---|---|
| __init__.py | ||
| base.py | ||
| builder.py | ||
| builtin.py | ||
| constant.py | ||
| ctx_manager.py | ||
| dicts.py | ||
| distributed.py | ||
| functions.py | ||
| higher_order_ops.py | ||
| lists.py | ||
| misc.py | ||
| nn_module.py | ||
| optimizer.py | ||
| tensor.py | ||
| torch.py | ||
| user_defined.py | ||