Commit Graph

4 Commits

Angela Yi
d4225c55d9 [fx] Prioritize runtime assertions ops (#124213)
Summary:
We want to prioritize operators involved in data-dependent runtime assertions when legalizing the graph. For example, in the following piece of code, the `assert_scalar` and `assert_async` calls need to occur before the `slice_copy` call for the program to run correctly with fake tensors; otherwise, we run into a data-dependent error.

```
        _local_scalar_dense: "Sym(u113)" = torch.ops.aten._local_scalar_dense.default(aten_minimum_default);  aten_minimum_default = None

        ge_1: "Sym(u113 >= 2)" = _local_scalar_dense >= 2
        aten_scalar_tensor_default_3: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(ge_1);  ge_1 = None
        aten__assert_async_msg_2 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_3, '_local_scalar_dense is outside of inline constraint [2, 1000].');  aten_scalar_tensor_default_3 = None
        le_1: "Sym(u113 <= 1000)" = _local_scalar_dense <= 1000
        aten_scalar_tensor_default_4: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(le_1);  le_1 = None
        aten__assert_async_msg_3 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_4, '_local_scalar_dense is outside of inline constraint [2, 1000].');  aten_scalar_tensor_default_4 = None

        mul: "Sym(-u112)" = -1 * sym_size;  sym_size = None
        add: "Sym(-u112 + u113)" = _local_scalar_dense + mul;  mul = None
        lt: "Sym(-u112 + u113 < 0)" = add < 0;  add = None
        aten__assert_scalar_default = executorch_exir_dialects_edge__ops_aten__assert_scalar_default(lt, 'Deferred runtime assertion failed -u0 + u1 < 0');  lt = None

        aten_slice_copy_tensor_3: "f32[u113]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(getitem, 0, 0, _local_scalar_dense);  getitem = None
```
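
The fix can be thought of as a priority-aware topological sort: whenever several nodes are ready to be scheduled, runtime-assertion ops are emitted first, so they land before any op that consumes the asserted value. Below is a minimal sketch of that idea in plain Python (not the actual pass; the op-name set, graph representation, and helper are illustrative):

```python
import heapq
from collections import defaultdict
from itertools import count

# Illustrative set of assertion op names; the real pass keys off the
# actual aten/edge assertion operators.
ASSERTION_OPS = {"aten._assert_async.msg", "aten._assert_scalar.default"}

def prioritized_topo_order(nodes, deps, op_name):
    """Kahn's algorithm with a priority queue: among all ready nodes,
    assertion ops (priority 0) are emitted before everything else (1).

    nodes: list of node ids; deps: node -> set of predecessor ids;
    op_name: node -> operator name string.
    """
    def prio(n):
        return 0 if op_name[n] in ASSERTION_OPS else 1

    indegree = {n: len(deps[n]) for n in nodes}
    users = defaultdict(set)
    for n in nodes:
        for d in deps[n]:
            users[d].add(n)
    tie = count()  # stable tie-breaker so the heap never compares nodes
    ready = [(prio(n), next(tie), n) for n in nodes if indegree[n] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        _, _, n = heapq.heappop(ready)
        order.append(n)
        for u in users[n]:
            indegree[u] -= 1
            if indegree[u] == 0:
                heapq.heappush(ready, (prio(u), next(tie), u))
    return order

# Toy version of the example above: both the assert and the slice become
# ready once the scalar is known, but the assert must be emitted first.
nodes = ["scalar", "slice", "assert"]
deps = {"scalar": set(), "assert": {"scalar"}, "slice": {"scalar"}}
ops = {
    "scalar": "aten._local_scalar_dense.default",
    "assert": "aten._assert_scalar.default",
    "slice": "aten.slice_copy.Tensor",
}
print(prioritized_topo_order(nodes, deps, ops))  # ['scalar', 'assert', 'slice']
```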

Test Plan: test case

Differential Revision: D56201450

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124213
Approved by: https://github.com/SherlockNoMad
2024-05-07 21:31:10 +00:00
Yuanhao Ji
c96bd3de06 Enable UFMT on all of test/fx (#123622)
Partially addresses #123062

Ran lintrunner on:

- `test/fx`

with command:

```bash
lintrunner -a --take UFMT --all-files
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123622
Approved by: https://github.com/ezyang
2024-04-09 15:59:17 +00:00
Edward Z. Yang
7b9d250f06 Change _dynamo.export to be export(f)(*args, **kwargs) (#106109)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
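
A minimal sketch of the calling-convention change (the toy function is illustrative, not from the PR):

```python
import torch
import torch._dynamo as dynamo

def f(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(3)

# Before this change: gm, guards = dynamo.export(f, x)
# After: export(f) returns a callable that is then invoked with the
# example inputs.
gm, guards = dynamo.export(f)(x)
print(gm.graph)
```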

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106109
Approved by: https://github.com/voznesenskym
2023-07-27 21:41:13 +00:00
Angela Yi
3c5ec6af14 Partition modules (#98628)
Added helper functions to match nodes in the graph that were decomposed from their source (leaf modules or functional ops) as a result of dynamo tracing.

`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]`

Args:
* graph: The graph we want to partition
* wanted_sources: List of sources to match against; nodes that were decomposed from one of these sources are grouped together. A source can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)

Returns:
* Dictionary mapping each source (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions, each corresponding to the group of nodes that were flattened from a module of that type.

```
from dataclasses import dataclass, field
from typing import List, Type

from torch.fx.node import Node


@dataclass
class SourcePartition:
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)
```

Example:

Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
.graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]
```
Result of `get_source_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
    SourcePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    SourcePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    SourcePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    SourcePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```

Also added a helper function to check whether two source partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
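
A minimal usage sketch of the two helpers (the model is illustrative; the import path reflects where these utilities live in current PyTorch, `torch.fx.passes.utils.source_matcher_utils`):

```python
import torch
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

# Trace the module; source partitioning relies on the source_fn metadata
# recorded during dynamo/export tracing (symbolic_trace would not work).
ep = torch.export.export(M(), (torch.randn(2, 4),))
gm = ep.graph_module

# Group the decomposed nodes back into per-module partitions.
partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
linears = partitions[torch.nn.Linear]  # two partitions, in graph order
relus = partitions[torch.nn.ReLU]      # one partition

# True: a node in the ReLU partition consumes the first Linear's output.
print(check_subgraphs_connected(linears[0], relus[0]))
```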

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
2023-05-03 23:31:56 +00:00