Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Reland https://github.com/pytorch/pytorch/pull/128709.

When the input predicate is a Python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. Previously, we baked True/False into the cond operator, which can be confusing. In this PR, we instead specialize into one of the branches when the input predicate is a constant.

We additionally let the cond operator keep its default name instead of overriding it. This allows better testing on de-serialized graphs.

Test Plan: The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a Python bool. To fix these tests, we either change the predicate to a data-dependent tensor or change the test to check that cond is specialized into one of the branches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130493
Approved by: https://github.com/BoyuanFeng

| Name | | |
|---|---|---|
| __init__.py | ||
| base.py | ||
| builder.py | ||
| builtin.py | ||
| constant.py | ||
| ctx_manager.py | ||
| dicts.py | ||
| distributed.py | ||
| functions.py | ||
| higher_order_ops.py | ||
| iter.py | ||
| lazy.py | ||
| lists.py | ||
| misc.py | ||
| nn_module.py | ||
| optimizer.py | ||
| script_object.py | ||
| sdpa.py | ||
| tensor.py | ||
| torch_function.py | ||
| torch.py | ||
| user_defined.py | ||
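
The specialization described in the commit message can be pictured with a small, torch-free toy model. This is only an illustrative sketch, not the Dynamo or torch.cond implementation: `toy_cond` and `CondNode` are hypothetical names, and the "deferred node" merely stands in for a cond operator that keeps both branches in the graph.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class CondNode:
    """Hypothetical stand-in for a cond operator that preserves both branches."""
    true_fn: Callable[..., Any]
    false_fn: Callable[..., Any]
    operands: Tuple[Any, ...]

    def run(self, pred_value: bool) -> Any:
        # The branch is chosen only when the predicate value becomes known.
        branch = self.true_fn if pred_value else self.false_fn
        return branch(*self.operands)


def toy_cond(pred, true_fn, false_fn, operands=()):
    """Toy model (an assumption, not PyTorch code) of the behavior in the commit message."""
    if isinstance(pred, bool):
        # Python constant: specialize into one branch and warn that the
        # conditional itself is not preserved.
        warnings.warn(
            "pred is a Python constant; specializing into one branch, "
            "so the conditional is not preserved."
        )
        branch = true_fn if pred else false_fn
        return branch(*operands)
    # Anything else is treated as data-dependent: keep both branches.
    return CondNode(true_fn, false_fn, tuple(operands))


if __name__ == "__main__":
    inc = lambda x: x + 1
    dec = lambda x: x - 1
    print(toy_cond(True, inc, dec, (10,)))  # specialized to inc: prints 11
    node = toy_cond(object(), inc, dec, (10,))  # non-constant predicate
    print(node.run(False))  # both branches kept; prints 9
```

In this sketch, a test that used a shape comparison as the predicate would take the constant path whenever no dynamic shape is involved, which is why the test plan either switches to a data-dependent predicate or asserts that a single branch was taken.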