[Docs] Fix indentations in cond.md (#156147)

This is a follow-up PR to fix indentations mentioned by https://github.com/pytorch/pytorch/pull/155653#issuecomment-2971660356

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156147
Approved by: https://github.com/svekars, https://github.com/cyyever
This commit is contained in:
windsonsea 2025-09-21 05:50:47 +00:00 committed by PyTorch MergeBot
parent f591bb5056
commit 4a96a6fa4a

View File

@ -34,75 +34,75 @@ Read more about feature classification at: https://pytorch.org/blog/pytorch-feat
Below is an example that uses cond to branch based on input shape:
```python
import torch
import torch
def true_fn(x: torch.Tensor):
return x.cos() + x.sin()
def true_fn(x: torch.Tensor):
return x.cos() + x.sin()
def false_fn(x: torch.Tensor):
return x.sin()
def false_fn(x: torch.Tensor):
return x.sin()
class DynamicShapeCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on dynamic shape predicate.
"""
class DynamicShapeCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on dynamic shape predicate.
"""
def __init__(self):
super().__init__()
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor):
return x.cos()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor):
return x.cos()
def false_fn(x: torch.Tensor):
return x.sin()
def false_fn(x: torch.Tensor):
return x.sin()
return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
dyn_shape_mod = DynamicShapeCondPredicate()
dyn_shape_mod = DynamicShapeCondPredicate()
```
We can eagerly run the model and expect the results vary based on input shape:
```python
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
```
We can export the model for further transformations and deployment:
```python
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
```
This gives us an exported program as shown below:
```
class GraphModule(torch.nn.Module):
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
```
Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a Symbolic expression over the shape of input,
@ -111,41 +111,41 @@ and branch functions becomes two sub-graph attributes of the top level graph mod
Here is another example that showcases how to express a data-dependent control flow:
```python
class DataDependentCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on data dependent predicate.
"""
def __init__(self):
super().__init__()
class DataDependentCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on data dependent predicate.
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
```
The exported program we get after export:
```
class GraphModule(torch.nn.Module):
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
```
## Invariants of torch.ops.higher_order.cond