[Docs] Fix indentations in cond.md (#156147)
This is a follow-up PR to fix indentations mentioned by https://github.com/pytorch/pytorch/pull/155653#issuecomment-2971660356

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156147
Approved by: https://github.com/svekars, https://github.com/cyyever
This commit is contained in:
parent f591bb5056
commit 4a96a6fa4a

@@ -34,75 +34,75 @@ Read more about feature classification at: https://pytorch.org/blog/pytorch-feat
Below is an example that uses cond to branch based on input shape:
```python
import torch


def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()


def false_fn(x: torch.Tensor):
    return x.sin()


class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos() + x.sin()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))


dyn_shape_mod = DynamicShapeCondPredicate()
```

We can eagerly run the model and expect the results to vary based on input shape:
```python
inp = torch.randn(3)   # shape[0] == 3, so the predicate x.shape[0] > 4 is False
inp2 = torch.randn(5)  # shape[0] == 5, so the predicate x.shape[0] > 4 is True
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
```

We can export the model for further transformations and deployment:
```python
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
```

This gives us an exported program as shown below:
```
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            return sin
```
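
As a quick sanity check, here is a small sketch (not part of the original document) that assumes the exported program keeps the runtime branch selection shown above and can be invoked through `ep.module()`:

```python
# Hypothetical check: the exported program should still dispatch on the
# batch size at run time (assumes `ep` from the export call above).
m = ep.module()
small = torch.randn(3, 3)  # batch 3: s0 > 4 is False -> sin branch
large = torch.randn(8, 3)  # batch 8: s0 > 4 is True  -> cos + sin branch
assert torch.allclose(m(small), small.sin())
assert torch.allclose(m(large), large.cos() + large.sin())
```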

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`: its predicate becomes a symbolic expression over the shape of the input,

@@ -111,41 +111,41 @@ and branch functions becomes two sub-graph attributes of the top level graph mod
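
To make that concrete, a brief sketch (an illustration, not taken from the document) of how those pieces can be inspected, assuming the `ep` above and the sub-graph attribute names shown in the print-out:

```python
# Sketch: the branch functions live as sub-graph attributes on the top-level
# graph module (attribute names taken from the exported program print-out).
print(ep.graph_module.true_graph_0.code)   # cos + sin branch
print(ep.graph_module.false_graph_0.code)  # sin branch
```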

Here is another example that showcases how to express data-dependent control flow:
```python
class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on data dependent predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
```
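
The export call for this module is not shown in this section; a hedged sketch of what it could look like, reusing the dynamic batch dimension from the previous example, is:

```python
# Assumed export call (not shown in the original section); mirrors the
# shape-predicate example above.
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DataDependentCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
```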

The exported program we get:
```
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None

        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            return sin
```
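
As with the shape-based example, a short sketch (continuing the assumed `ep` above) of how the exported program dispatches on the data-dependent predicate at run time:

```python
# Sketch continuing the assumed `ep` above: runtime dispatch on a
# data-dependent predicate rather than a shape predicate.
m = ep.module()
x_big = torch.ones(6, 3)     # sum() == 18.0 > 4.0 -> cos + sin branch
x_small = torch.zeros(6, 3)  # sum() == 0.0 <= 4.0 -> sin branch
assert torch.allclose(m(x_big), x_big.cos() + x_big.sin())
assert torch.allclose(m(x_small), x_small.sin())
```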

## Invariants of torch.ops.higher_order.cond