.. _cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else like control flow
and can logically be seen as implemented as follows.

.. code-block:: python

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor]
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)
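
For example, `torch.cond` can be called directly in eager mode. A minimal sketch, assuming a PyTorch build that exposes `torch.cond` at the top level:

.. code-block:: python

    import torch

    x = torch.randn(3)
    # The predicate is a data-dependent 0-dim boolean tensor; in eager mode
    # only the selected branch actually runs.
    out = torch.cond(x.sum() > 0, lambda t: t.cos(), lambda t: t.sin(), (x,))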

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses cond to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can eagerly run the model and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
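
The model also works under `torch.compile`; a hedged sketch, assuming a PyTorch version where `torch.compile` supports `torch.cond`:

.. code-block:: python

    compiled_mod = torch.compile(dyn_shape_mod)
    # The compiled module should match eager behavior on the same inputs.
    assert torch.allclose(compiled_mod(inp2), true_fn(inp2))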

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
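
The exported program can be run directly. A minimal sketch, using `ExportedProgram.module()` to obtain a callable module:

.. code-block:: python

    t = torch.randn(5, 3)
    # Batch size 5 > 4, so the true branch (cos + sin) is taken.
    assert torch.allclose(ep.module()(t), true_fn(t))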

Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on data dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
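
A hedged sketch of the export call, reusing `dim_batch` and the argument style from the previous example:

.. code-block:: python

    ep = torch.export.export(
        DataDependentCondPredicate(), (torch.randn(4, 3),), {}, dynamic_shapes={"x": {0: dim_batch}}
    )
    print(ep)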

The exported program we get after export:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
                return sin

Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:

  - Dynamism of the predicate is preserved (e.g. `gt` shown in the above example).
  - If the predicate in the user program is constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.

- For the branches:

  - The input and output signature will be a flattened tuple.
  - They are `torch.fx.GraphModule` instances.
  - Closures in the original function become explicit inputs; the branch functions themselves carry no closures (see the sketch after this list).
  - No mutations on inputs or globals are allowed.

- For the operands:

  - They will also be a flattened tuple.

- Nesting of `torch.cond` in the user program becomes nested graph modules.
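
As an illustration of the closure invariant, a hedged sketch (all names here are our own) that threads a tensor through `operands` instead of capturing it as a closure:

.. code-block:: python

    import torch

    scale = torch.ones(3)

    # `scale` is passed explicitly, mirroring how tracing lifts closed-over
    # tensors into inputs of the branch graph modules.
    def scaled_true_fn(x: torch.Tensor, scale: torch.Tensor):
        return x.cos() * scale

    def scaled_false_fn(x: torch.Tensor, scale: torch.Tensor):
        return x.sin() * scale

    x = torch.randn(3)
    out = torch.cond(x.sum() > 0, scaled_true_fn, scaled_false_fn, (x, scale))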

API Reference
-------------

.. autofunction:: torch._higher_order_ops.cond.cond