Commit Graph

133 Commits

Author SHA1 Message Date
wz337
febbc48f43 [DeviceMesh] Make our mesh_dim kwarg naming consistent (#114707)
Changing `def size(self, dim: Optional[int] = None)` to `def size(self, mesh_dim: Optional[int] = None)` so it is consistent with the rest of our APIs.

We also update usages of this API in both PT and internal code (pyper, APS).
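
A minimal usage sketch of the renamed kwarg; the `init_device_mesh` entry point and the 2x4 mesh here are illustrative assumptions (running this needs an initialized 8-rank job):

```python
from torch.distributed.device_mesh import init_device_mesh

# hypothetical 2x4 mesh over 8 ranks, purely for illustration
mesh = init_device_mesh("cuda", (2, 4))

mesh.size()            # total number of ranks: 8
mesh.size(mesh_dim=0)  # size along mesh dim 0: 2 (previously `dim=0`)
```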

Differential Revision: [D51602986](https://our.internmc.facebook.com/intern/diff/D51602986/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114707
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fegin
2023-11-29 19:43:23 +00:00
Aaron Gokaslan
b7b2178204 [BE]: Remove useless lambdas (#113602)
Applies PLW0108, which removes useless lambdas in Python. The rule is in preview, so it is not yet ready to be enabled by default; these are the autofixes from the rule.
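
A minimal illustration of the pattern the rule rewrites (the name is hypothetical):

```python
# Before: a useless lambda that only forwards its argument
to_str = lambda x: str(x)

# After the PLW0108 autofix: reference the callable directly
to_str = str
```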

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
2023-11-14 20:06:48 +00:00
Wanchao Liang
6ed20af10e [dtensor] refactor op dispatch and fix is_same_size/equal (#112927)
torch.equal/is_same_size currently skip sharding prop and directly do
local tensor compute, which is wrong. For these two ops:

- torch.equal: should not skip sharding prop; the two DTensors need to
have the SAME sharding before we compare local shard values
- torch.is_same_size: needs to completely skip both sharding prop and
local compute

This PR refactors the existing op_dispatch into a class instance
so that we can do custom op handling, then fixes both torch.equal and
torch.is_same_size.
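
A hedged sketch of the class-based dispatch described above (heavily simplified; the handler names and structure are assumptions, not the merged code):

```python
import torch

class OpDispatcher:
    def __init__(self):
        # ops that need handling outside the normal sharding-prop path
        self.custom_op_handlers = {
            torch.ops.aten.equal.default: self.equal_handler,
            torch.ops.aten.is_same_size.default: self.is_same_size_handler,
        }

    def dispatch(self, op, args, kwargs):
        if op in self.custom_op_handlers:
            return self.custom_op_handlers[op](args, kwargs)
        ...  # normal path: sharding prop, redistribute, local compute

    def equal_handler(self, args, kwargs):
        # run sharding prop and redistribute so both operands share the
        # SAME sharding before comparing local shards
        ...

    def is_same_size_handler(self, args, kwargs):
        # pure metadata check: compare global shapes; no sharding prop,
        # no local compute
        a, b = args
        return a.shape == b.shape
```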

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112927
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-11-13 22:46:31 +00:00
Adrian Wälchli
866457e746 Fix pydocstyle errors in fully_sharded_data_parallel.py, api.py, graph_utils.py, distribute.py, iter_graph_module.py, comm_tensor.py, experimental_ops.py, batch_dim_utils.py, data_parallel.py, graph_optimization.py (#113216)
Fixes #113191

```
pydocstyle <file> --count
```

Errors on master -> after this PR:

- torch/distributed/fsdp/fully_sharded_data_parallel.py: 80 -> 3
- torch/distributed/_spmd/comm_tensor.py: 5 -> 3
- torch/distributed/_spmd/experimental_ops.py: 3 -> 1
- torch/distributed/_spmd/iter_graph_module.py: 39 -> 27
- torch/distributed/_spmd/graph_utils.py: 16 -> 4
- torch/distributed/_spmd/distribute.py: 19 -> 10
- torch/distributed/_spmd/api.py: 10 -> 3
- torch/distributed/_spmd/batch_dim_utils.py: 14 -> 3
- torch/distributed/_spmd/data_parallel.py: 34 -> 2
- torch/distributed/_spmd/graph_optimization.py: 35 -> 13
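
For reference, a hedged example of the kind of fix involved (the function is hypothetical; the D-codes are pydocstyle's):

```python
# Before: triggers D400 ("First line should end with a period") and
# D205 ("1 blank line required between summary line and description").
def distribute(graph):
    """distribute the graph
    across ranks"""

# After: pydocstyle-clean.
def distribute(graph):
    """Distribute the graph.

    Split the graph and assign each partition to a rank.
    """
```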

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
2023-11-10 03:08:32 +00:00
Peter Bell
04024926f4 Use pytree.tree_map_ everywhere (#112417)
Wherever we discard the output of `tree_map`, it's better to call `tree_map_`,
which doesn't unflatten the mapped results and so is a lot cheaper.
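
A small sketch of the difference, using `torch.utils._pytree` where both helpers live:

```python
import torch
import torch.utils._pytree as pytree

tree = {"a": torch.ones(2), "b": [torch.zeros(3)]}

# builds (and immediately discards) a whole new pytree
pytree.tree_map(lambda t: t.add_(1), tree)

# applies the fn for its side effect, without unflattening a new tree
pytree.tree_map_(lambda t: t.add_(1), tree)
```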
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112417
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393, #112394
2023-10-31 15:57:06 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.
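
For example, a minimal sketch with `torch.utils._pytree`:

```python
import torch.utils._pytree as pytree

tree = {"a": 1, "b": [2, 3]}

leaves, _ = pytree.tree_flatten(tree)  # old pattern: the spec is discarded
leaves = pytree.tree_leaves(tree)      # new pattern: leaves directly
```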

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Kazuaki Ishizaki
b5f9696d81 Fix typo under torch directory (#110824)
This PR fixes typo `the the` of comments and exception messages in files under `torch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110824
Approved by: https://github.com/H-Huang
2023-10-09 19:16:43 +00:00
Wanchao Liang
09f3e08bcc [dtensor][3/n] use dedicated TensorMeta instead of the fx one (#108261)
This PR switches from fx's shape-prop TensorMetadata to
dtensor's own dedicated TensorMeta, because DTensor
only cares about three fields: shape/stride/dtype. All other fields are not
necessary and can be inferred from the local_tensor directly. This
significantly simplifies how we deal with tensor metadata by not
tracking the other fields.
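
A hedged sketch of the dedicated container (the real definition lives with dtensor's placement types and may differ in detail):

```python
from typing import NamedTuple, Tuple

import torch

class TensorMeta(NamedTuple):
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype

t = torch.randn(4, 8)
meta = TensorMeta(t.shape, t.stride(), t.dtype)  # everything DTensor needs
```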
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
2023-09-13 04:08:02 +00:00
Wanchao Liang
fc1dcfb9ab [dtensor][2/n] use op overload instead of function schema (#107306)
The function schema doesn't provide us anything extra, as we can also get the schema from `op._schema`. Including the op directly in op_schema makes it easier for sharding prop to do fake execution, and in principle it should also make the hash comparison faster: we don't need to hash the function schema, we just hash `id(op)`, which is constant.

This PR is just a refactor to include the op in OpSchema instead of the function schema; no other logic changes.
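
A quick illustration of why carrying the OpOverload is enough (using a real aten overload):

```python
import torch

op = torch.ops.aten.addmm.default  # an OpOverload
print(op._schema)  # the FunctionSchema is still recoverable from the op
key = id(op)       # cheap, stable hash key: overload objects are singletons
```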
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107306
Approved by: https://github.com/fduwjj
2023-09-13 04:08:02 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in op dispatch and simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten calls, to cut the
overhead coming from those operations
3. remove the CachedShardingPropagator by using functools' lru_cache
directly; this not only helps TP, general DTensor operations get
faster too (see the sketch after this list)
4. change the view ops' behavior of in-place modifying the op_schema, which
is dangerous for sharding prop caching; model view ops as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute,
so that we don't need an explicit op schema comparison to know it.
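
A hedged sketch of point 3's caching pattern (an illustration, not the actual implementation):

```python
from functools import lru_cache

class ShardingPropagator:
    def __init__(self):
        # cache propagation results per instance, keyed on the (hashable)
        # op schema -- replaces a hand-rolled CachedShardingPropagator
        self.propagate_op_sharding = lru_cache(None)(self._propagate_op_sharding)

    def _propagate_op_sharding(self, op_schema):
        # compute the output sharding for this op and its input shardings
        ...
```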

This should help with further reducing the CPU overhead. Benchmark
results:
before (without this change), aten.addmm latency: 0.476 ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341 ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)

overall, one layer of MLP time reduced from 13.535 ms -> 9.665 ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
fduwjj
4a6ca4cc05 [TP][DTensor Perf] Some perf improvement to reduce DTensor CPU overhead (#106524)
By inspecting a small TP benchmark, we found a couple of things we can optimize:
1. We call deep_copy many times when we initialize DTensor.
2. Some sharding_prop results are not cached successfully.
3. We are still calling redistribute when it is not necessary.

![image](https://github.com/pytorch/pytorch/assets/6937752/b847d110-eea1-45df-9298-066d0ba07dd7)

![image](https://github.com/pytorch/pytorch/assets/6937752/fc08f564-caed-496b-80d7-275c1dba3806)

![image](https://github.com/pytorch/pytorch/assets/6937752/fdc06cc4-a4ba-48e8-a118-c041bbd04f5e)

So we want to:
1. Remove the deep_copy; we now make placements a tuple so we are sure it's immutable.
2. Somehow the op_schema gets changed during sharding op propagation, so we store a hashed version of it before passing it to sharding_prop. Ideally we want to figure out why `op_schema` gets changed, but it looks like it happens for both index and detach/view ops, so it might take more time to debug.
3. When we hash op_schema, we want to hash the entire args_schema, not just the args_spec, which only contains the DTensorSpecs from args that are DTensors.
4. It turns out that sometimes DTensor has mem_format set to None (not contiguous), and this leads to redistribute getting triggered; so we only compare type, shape, and stride in the metadata.

We also need to ensure _Partial and Shard have different hash values in the DTensorSpec.

![image](https://github.com/pytorch/pytorch/assets/6937752/321e6890-1ab6-4975-adc9-524c6ef9a76b)
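
A hedged sketch of the immutability and hashing points above, using a hypothetical, heavily simplified spec (the real DTensorSpec carries more fields):

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class DTensorSpec:
    mesh_id: int
    placements: Tuple[str, ...]  # a tuple, not a list: immutable and hashable

spec = DTensorSpec(mesh_id=0, placements=("Shard(0)", "Replicate()"))
cache = {spec: "cached sharding decision"}  # usable as a dict key, no deep_copy
```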

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106524
Approved by: https://github.com/wanchaol
2023-08-14 20:03:19 +00:00
Matthew Hoffman
0616952d13 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from the conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done)
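
A minimal example of the PEP-484 pattern fixed throughout (the functions are hypothetical):

```python
from typing import Optional

def before(x: int = None):  # PEP-484 violation: implicit Optional
    ...

def after(x: Optional[int] = None):  # the annotated form mypy expects
    ...
```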
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
1646d6f939 Revert "Merge and improve torch optim optimizer type stubs (#102593)"
This reverts commit 3279f06410.

Reverted https://github.com/pytorch/pytorch/pull/102593 on behalf of https://github.com/malfet due to There is nothing wrong with this PR, but it fails some internal builds that depend on outdated typing_extensions, will reland when update is done ([comment](https://github.com/pytorch/pytorch/pull/102593#issuecomment-1636062515))
2023-07-14 16:04:54 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to Its dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Matthew Hoffman
3279f06410 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99
2023-07-11 00:07:30 +00:00
Jane Xu
e25f5732c8 Add meta registrations and distributed decomps: _foreach_div_.Scalar, sqrt_.default (#104779)
This PR unblocks #104780 by resolving spmd tracing test issues and by adding meta registrations for foreach inplace ops (div_ and sqrt_).
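
A hedged sketch of what a Meta registration does, shown on a toy custom op in a scratch namespace (the actual foreach registrations live in PyTorch's internal meta registration tables):

```python
import torch
from torch.library import Library

# define a toy in-place op, purely for illustration
def_lib = Library("demo", "DEF")
def_lib.define("scale_(Tensor(a!) self, Scalar s) -> Tensor(a!)")

def scale__meta(self, s):
    # in-place op: output metadata is just the input's, nothing to allocate
    return self

impl_lib = Library("demo", "IMPL", "Meta")
impl_lib.impl("scale_", scale__meta)

# the op can now run under meta tracing without real compute:
t = torch.empty(4, device="meta")
torch.ops.demo.scale_(t, 2.0)
```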

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104779
Approved by: https://github.com/fegin, https://github.com/albanD
2023-07-10 17:38:46 +00:00
Yeonju Ro
06f656c5d1 [distributed] implemented find_all_descendants (#102138)
Fixes #100397

Implemented a find_all_descendants function that identifies the list of nodes that need to be moved. Added a unit test.
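
A hedged sketch of the idea; this implementation over `torch.fx` nodes is an assumption, not the merged code:

```python
from typing import Iterable, Set

import torch.fx as fx

def find_all_descendants(roots: Iterable[fx.Node]) -> Set[fx.Node]:
    """Collect every node reachable from `roots` via user edges."""
    seen: Set[fx.Node] = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.users)  # downstream consumers must move together
    return seen
```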
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102138
Approved by: https://github.com/fegin
2023-05-24 21:47:59 +00:00
Wanchao Liang
d316a2dd5c [spmd] Enable data parallel to work with non 0 batch dim (#100073)
This PR enables data parallel to work with a non-0 batch dim. The only
thing we need to do is expose the input_batch_dim to DataParallelMode,
and the data parallel expansion automatically works, since we have done
things correctly in the batch dim analysis.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100073
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
d378837039 [spmd] add more decomp and fix a sharding bug (#100938)
This PR adds the native_layernorm_backward op to the decomp table and fixes
a sharding bug so that padding is not done automatically.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100938
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
dd1f295201 [spmd] Improve activation handling, factory ops and batch dim reduction (#100853)
This PR improves the activation handling logic of data parallel to
support cases where tensor factory ops that do not depend on any input
node still produce activations, with either a sharded act (i.e. if the
output shape has the batch size) or a replicated act.

It also significantly simplifies the full reduction logic: we no longer
need full reduction detection, we only need to ensure that, when
computing the batch dim, we detect a full reduction and mark it as sharded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100853
Approved by: https://github.com/mrshenli
2023-05-24 17:55:09 +00:00
Wanchao Liang
4d55ea8548 [spmd] enhance batch dim analysis of data parallel (#100852)
This PR enhances the batch dim analysis of data parallel to understand
more of the cases where the batch dim gets flattened or split; using
dtensor's view ops, we are able to track a batch dim that gets
transformed in non-trivial ways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100852
Approved by: https://github.com/mrshenli
2023-05-24 17:55:07 +00:00
Wanchao Liang
b2eaba6b62 [spmd] by default average gradients for nccl backend (#99964)
This PR averages gradients by default for the NCCL backend; this allows
SPMD's data parallel to match DDP/FSDP results.
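
A minimal sketch of the semantics, assuming an initialized NCCL process group (`ReduceOp.AVG` is NCCL-only):

```python
import torch.distributed as dist

def sync_grad(grad):
    # averaging in a single collective matches DDP/FSDP gradient semantics
    dist.all_reduce(grad, op=dist.ReduceOp.AVG)
    return grad
```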
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99964
Approved by: https://github.com/mrshenli
2023-05-24 17:55:06 +00:00
Wanchao Liang
942cd12d55 [spmd] add option to preserve node types (#100072)
This PR adds an option to preserve node types for the entire graph;
this allows some exploration of using those node types to do
things like activation checkpointing, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100072
Approved by: https://github.com/mrshenli
2023-05-24 17:55:05 +00:00
Edward Z. Yang
3318a832b3 Tighten FakeTensor reentrancy asserts, add debugging (#102091)
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted to assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186, it seems that the existing protections were insufficient.

In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into proxyable operations. However, this is never the right thing for FakeTensor, where no unpacking is possible. Yet today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being on the stack twice.

This PR does a number of things:

* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack; this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active (see the sketch after this list). This prevents us from re-adding the mode to the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.
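
A hedged sketch of the second bullet's guard (heavily simplified; `FakeLikeTensor` is hypothetical, not the actual FakeTensor code):

```python
import torch
from torch.utils._python_dispatch import _get_current_dispatch_mode

class FakeLikeTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if _get_current_dispatch_mode() is not None:
            # a mode is already active on the stack; return NotImplemented
            # instead of re-entering our own mode from the subclass
            return NotImplemented
        ...
```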

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
2023-05-24 05:37:51 +00:00
Shen Li
af841f38bd [SPMD] Allow Override.replacement to have a global view (#101427)
It's easier for users to implement one Override that takes care of
all target submodules of different types, instead of specifying one
mapping pair for each FQN/type. For example, when calculating
sharding for sparse layers, the decision needs to be made globally.
In this case, it's helpful to allow the user Override to get access to
all submodules and make replacement decisions accordingly.

Differential Revision: [D45879732](https://our.internmc.facebook.com/intern/diff/D45879732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101427
Approved by: https://github.com/fegin
2023-05-15 21:27:41 +00:00
Chien-Chin Huang
49c8a0cad0 [SPMD][BE] Remove the legacy tracing code (#100858)
Remove the legacy tracing code, as it caused several test and benchmark issues.

Differential Revision: [D45649123](https://our.internmc.facebook.com/intern/diff/D45649123/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100858
Approved by: https://github.com/wanchaol
2023-05-11 23:08:27 +00:00
lessw2020
ec144b9412 handle new param from torch.compile (Inductor pattern matcher), enable_log (#100814)
This PR adds a placeholder param handler for a new param being passed in from Inductor, `enable_log`.
Fixes the error below, which had made me unable to run torch.compile on NanoGPT:

~~~
File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/fx_passes/fuse_attention.py", line 219, in _sfdp_init
    register_replacement(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/pattern_matcher.py", line 658, in register_replacement
    search_gm = trace_fn(search_fn, example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/pattern_matcher.py", line 828, in training_graph
    aot_function(
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
TypeError: patched_aot_function() got an unexpected keyword argument 'enable_log'
~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100814
Approved by: https://github.com/fegin
2023-05-08 18:34:45 +00:00
Shen Li
2ebb48ff28 [SPMD] add FQN argument to Override.replacement (#100473)
Differential Revision: [D45486089](https://our.internmc.facebook.com/intern/diff/D45486089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100473
Approved by: https://github.com/wanchaol
2023-05-03 14:20:01 +00:00
Chien-Chin Huang
e0a2b49f0b [SPMD] Introduce prerequisites to graph_optimization_pass (#99970)
Some optimizations require prerequisite passes. It is hard to debug an optimization pass that fails because its prerequisite condition does not hold. Adding this check makes it easier to discover the error.

Differential Revision: [D45255377](https://our.internmc.facebook.com/intern/diff/D45255377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99970
Approved by: https://github.com/lessw2020
2023-04-28 18:38:01 +00:00
Chien-Chin Huang
01de8ee845 [SPMD][Easy] Add time counter in graph_optimization_pass (#99969)
This gives an idea of how expensive each pass is.
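
A hedged sketch of such a time counter (a hypothetical decorator, for illustration):

```python
import functools
import time

def timed_pass(pass_fn):
    """Wrap a graph-optimization pass and report its wall-clock cost."""
    @functools.wraps(pass_fn)
    def wrapper(*args, **kwargs):
        begin = time.perf_counter()
        result = pass_fn(*args, **kwargs)
        print(f"{pass_fn.__name__} took {time.perf_counter() - begin:.3f}s")
        return result
    return wrapper
```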

Differential Revision: [D45255366](https://our.internmc.facebook.com/intern/diff/D45255366/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99969
Approved by: https://github.com/lessw2020
2023-04-27 17:56:07 +00:00
Chien-Chin Huang
33fba6ef07 [SPMD] Add arange and zeros to default factory ops (#100037)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100037
Approved by: https://github.com/mrshenli, https://github.com/wanchaol
2023-04-26 16:32:10 +00:00
Wanchao Liang
0901b41a5e [spmd] Add a few more loss ops to the reduction op list (#99900)
This PR adds a few more loss ops to the reduction op list
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99900
Approved by: https://github.com/mrshenli
2023-04-25 19:31:00 +00:00
Wanchao Liang
932ed333f7 [spmd] expose input_batch_dim to DataParallelMode (#99899)
This PR exposes the input batch dim to the DataParallelMode so that
we can have explicit control over which input dim is the batch dim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99899
Approved by: https://github.com/awgu, https://github.com/mrshenli
2023-04-25 19:30:58 +00:00
Wanchao Liang
c6949db481 [spmd] enable fully_shard fused_adam test (#99898)
This PR enables the fully_shard fused adam tests, with some additional
tweaks to how we handle scalar tensors. We now treat a scalar tensor as
if it were just a scalar value; we don't distribute it, as there's no
need to shard a scalar tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99898
Approved by: https://github.com/mrshenli
2023-04-25 19:30:55 +00:00
Wanchao Liang
ad882c5210 [spmd] Use TupleStrategy and enable replicate fused_adam (#99374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99374
Approved by: https://github.com/mrshenli
2023-04-25 19:30:53 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT allow simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
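
An example of the pattern C419 rewrites:

```python
vals = range(10**6)

ok = any([v > 5 for v in vals])  # builds the entire list before checking
ok = any(v > 5 for v in vals)    # generator: short-circuits at the first hit
```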

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Wanchao Liang
855f611baf [spmd] skip gradient copying for fused adam (#99489)
Gradients do not need to be copied back, as they are not useful.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99489
Approved by: https://github.com/mrshenli
2023-04-24 22:50:02 +00:00
Wanchao Liang
9db6920635 [spmd] Add list handling to data parallel and add foreach tests (#99373)
This PR adds list handling logic to the new DataParallel expansion and
adds foreach optimizer tests, currently testing SGD optimizers
in foreach mode, for both replicate and fully shard.

Next step:

Add fused optim tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99373
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Wanchao Liang
c1e2fa8189 [dtensor] add StrategyType and TupleStrategy (#99435)
This PR refactors the current StrategyList. It introduces
StrategyType, which is the base class of strategies and has
two sub-strategies:

1. Refactor the previous StrategyList to OpStrategy
2. Add TupleStrategy, a new strategy to deal with tuple cases, where
an op could return multiple different OpStrategies.

This helps support more complicated ops and unblocks compile-mode
FSDP (a sketch of the hierarchy follows).
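
A hedged sketch of the hierarchy (heavily simplified; the field names are assumptions):

```python
from dataclasses import dataclass
from typing import List

class StrategyType:
    """Base class for op sharding strategies."""

@dataclass
class OpStrategy(StrategyType):
    strategies: List[object]  # candidate placements for a single op output

@dataclass
class TupleStrategy(StrategyType):
    childs: List[StrategyType]  # one child strategy per tuple element
```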
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99435
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Wanchao Liang
e9bf94149e [spmd] Introduce Compile Mode FSDP with DTensor (#99062)
This PR introduces compile-mode Data Parallel (FSDP/DDP) using DTensor sharding.

Along with the algorithm, it also introduces a new DataParallelMode so that the `compile` API can take it
and apply data parallel. This PR tries to preserve the DTensorExpand
approach first to avoid breaking BC; we shall discuss steps to remove
DTensorExpand.

The data parallel mode uses heuristics to determine node types in the
graphs and assign the corresponding sharding. The detailed algorithm is
described in the design doc.

The benefits of this approach:
- Model parameters and optimizer states are all DTensors after `spmd.compile`, which is necessary for FSDP and also makes checkpointing much easier
- As model parameters/optim states are sharded per-parameter, this composes more easily with sophisticated second-order optimizers (i.e. Shampoo).
- We leverage the model parameter/grads information to derive the data parallel pattern. This way we don't need to worry about DTensor op coverage anymore, as data parallel is just a special case of DTensor operation.
- Using dtensor_expand might work for DDP but isn't going to work for FSDP, as dtensor might choose to allgather activations, which could violate the native FSDP algorithm.
- The approach is general enough to support both DDP/FSDP and a mixed mode

Follow ups:
- Add the "default" data parallel mode, which supports mixing
replicate/fully shard
- Test more e2e models with more different types of optimizers, etc.
- Migrate the existing stack from the DTensorExpand mode
- Build optimizations on top of this prototype

Differential Revision: [D45174400](https://our.internmc.facebook.com/intern/diff/D45174400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99062
Approved by: https://github.com/mrshenli
2023-04-22 03:13:05 +00:00
Horace He
547bef11ee tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy.  My new heuristic achieves 94% accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-04-21 23:28:44 +00:00
Wanchao Liang
b96bb2f1a6 [spmd] Introduce ParallelMode and add DTensorExpandMode (#98452)
This PR introduces a ParallelMode interface to define how to do
SPMD expansion and optimize the captured graph. This is
beneficial because it lets different parallelisms expand differently
and apply different optimization passes.

DTensorExpandMode is added as the first parallel mode; it implements the
existing dtensor_expand functionality.
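
A hedged sketch of what such an interface could look like (the method names are assumptions, not the actual definition):

```python
from abc import ABC, abstractmethod

import torch.fx as fx

class ParallelMode(ABC):
    @abstractmethod
    def partition(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Expand the captured single-device graph into its SPMD form."""

    @abstractmethod
    def optimize(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Apply parallelism-specific optimization passes to the graph."""
```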

Differential Revision: [D45174399](https://our.internmc.facebook.com/intern/diff/D45174399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98452
Approved by: https://github.com/mrshenli
2023-04-21 17:24:54 +00:00
Edward Z. Yang
abdd1f4a38 Reuse tracing context and fake tensors from backwards in forwards (#99619)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99619
Approved by: https://github.com/wanchaol
2023-04-20 22:39:48 +00:00
Chien-Chin Huang
88c45a1954 [SPMD] Allow users to dynamically pass the last_iter to IterGraphModule (#99575)
The current design of IterGraphModule requires users to specify the concrete iteration count, which is not always possible and not very precise. This PR introduces `last_iter` to IterGraphModule.forward(), which allows users to dynamically specify the last iteration.

Differential Revision: [D45129585](https://our.internmc.facebook.com/intern/diff/D45129585/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99575
Approved by: https://github.com/lessw2020
2023-04-20 16:49:34 +00:00
Shen Li
ca89e7942a [SPMD][Easy] switch to tree_map_only to simplify code (#99547)
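
For reference, a small sketch of the helper from `torch.utils._pytree`:

```python
import torch
import torch.utils._pytree as pytree

tree = {"t": torch.ones(2), "n": 3}

# map only over Tensor leaves; non-tensor leaves pass through unchanged
out = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, tree)
```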
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99547
Approved by: https://github.com/fegin
2023-04-19 20:40:09 +00:00