Commit Graph

54 Commits

Author SHA1 Message Date
Maggie Moss
eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00
Maggie Moss
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)
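
For illustration, a suppression of the kind added in step 3 looks roughly like this (a hedged sketch: the comment format is pyrefly's, but the error code and surrounding code are made up):

```python
# Hypothetical function in a module newly covered by pyrefly after its
# entry was deleted from project-excludes in pyrefly.toml (step 1).
def get_rank() -> int: ...

# Step 3: suppress a remaining error on the next line; the error code
# shown here is illustrative.
# pyrefly: ignore  # bad-assignment
rank: str = get_rank()
```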

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
Rohit Singh Rathaur
6389658ec6 Fix type hints in PrepareModuleInput and PrepareModuleInputOutput (#164482)
Fixes #161646

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164482
Approved by: https://github.com/Skylion007
2025-10-02 21:40:43 +00:00
Tianyu Liu
d2ad9aa2f2 [dtensor][tp] add a ParallelStyle PrepareModuleInputOutput (#150372)
Needed this class because `parallelize_module` takes a dict, which doesn't allow `PrepareModuleInput` and `PrepareModuleOutput` to be applied to the same module at the same time.

The `PrepareModuleInputOutput` in this PR initializes two variables `prepare_module_input` and `prepare_module_output` and uses them to process module / inputs / outputs.

I had another implementation which put all the code in `PrepareModuleInputOutput` and let `PrepareModuleInput` and `PrepareModuleOutput` inherit from the monolithic `PrepareModuleInputOutput`. But it is
1. less clean
2. conceptually an abuse of inheritance, because `PrepareModuleInput` shouldn't be able to access class methods of `PrepareModuleOutput` and vice versa
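
For illustration, the kind of plan this unblocks might look like the sketch below (module names and layouts are made up; the kwargs are assumed to mirror those of `PrepareModuleInput`/`PrepareModuleOutput`):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInputOutput,
    RowwiseParallel,
)

# parallelize_module takes one plan entry per module name, so a single
# style must carry both the input and the output preparation for "attn".
plan = {
    "attn": PrepareModuleInputOutput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
        output_layouts=(Replicate(),),
        desired_output_layouts=(Shard(1),),
    ),
    "attn.wq": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}
# parallelize_module(model, tp_mesh, plan)  # model/tp_mesh assumed defined
```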

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150372
Approved by: https://github.com/wanchaol
2025-04-01 19:15:43 +00:00
Yuanhao Ji
bf6621d08f [Distributed] Add repr methods for ParallelStyles (#149478)
Fixes #149470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149478
Approved by: https://github.com/wanchaol
2025-03-21 03:59:25 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Aaron Orenstein
c95efc37ba PEP585 update - torch/distributed/tensor (#145141)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145141
Approved by: https://github.com/bobrenjc93
2025-01-18 20:01:59 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
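
The change is mechanical; an illustrative before/after:

```python
from typing import Tuple  # before: typing.Tuple

def swap_old(x: Tuple[int, str]) -> Tuple[str, int]:
    return x[1], x[0]

# after: the builtin generic (PEP 585), no typing import needed
def swap_new(x: tuple[int, str]) -> tuple[str, int]:
    return x[1], x[0]
```
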
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Wanchao Liang
0431d47eaa [tp] propagate src_data_rank kwarg in TP API (#144005)
As titled, this PR propagates the src_data_rank kwarg in the TP API, so that
module-level APIs can leverage the flexibility to choose the
src_data_rank, and avoid the communication when it is not needed.
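
A hedged sketch of what this enables at the module level (the plan contents are made up; the kwarg is forwarded to `distribute_tensor` per this PR):

```python
from torch.distributed.tensor.parallel import ColwiseParallel

# src_data_rank=None skips broadcasting the source data from a single
# rank during sharding -- useful when every rank already holds suitable
# (e.g. throwaway, to-be-loaded) data.
plan = {"w1": ColwiseParallel(src_data_rank=None)}
# parallelize_module(model, tp_mesh, plan)  # model/tp_mesh assumed defined
```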

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144005
Approved by: https://github.com/tianyu-l
ghstack dependencies: #143883
2025-01-02 05:35:52 +00:00
Wanchao Liang
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(

----

Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim module to redirect old import paths to the new module

BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes.
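
The shim pattern is roughly the following (a minimal sketch, not the actual file):

```python
# torch/distributed/_tensor/__init__.py -- illustrative shim only.
# Old imports like `from torch.distributed._tensor import DTensor`
# keep working by forwarding this module to the new public one.
import sys

import torch.distributed.tensor as _public

sys.modules[__name__] = _public  # redirect all old-path lookups
```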

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
PyTorch MergeBot
35f36363ec Revert "[dtensor] move DTensor to public namespace (#133113)"
This reverts commit 2ee6b97464.

Reverted https://github.com/pytorch/pytorch/pull/133113 on behalf of https://github.com/wanchaol due to looks like it break some internal type imports ([comment](https://github.com/pytorch/pytorch/pull/133113#issuecomment-2295670911))
2024-08-19 05:00:19 +00:00
Wanchao Liang
2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
  PRs)
* To preserve BC for users still using `torch.distributed._tensor`,
  I added a shim module to redirect old import paths to the new module

BC preservation is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it's safe to land the
changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00
Wanchao Liang
9f17037e8b [dtensor] move tensor constructors to the api module (#133129)
This is to ensure that __init__.py only contains public APIs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133129
Approved by: https://github.com/awgu, https://github.com/tianyu-l
2024-08-13 06:09:56 +00:00
PyTorch MergeBot
00aa086298 Revert "[dtensor] move tensor constructors to a separate module (#133129)"
This reverts commit e890d888d9.

Reverted https://github.com/pytorch/pytorch/pull/133129 on behalf of https://github.com/fbgheith due to breaking internal tests ([comment](https://github.com/pytorch/pytorch/pull/133129#issuecomment-2285090400))
2024-08-12 23:55:08 +00:00
Wanchao Liang
e890d888d9 [dtensor] move tensor constructors to a separate module (#133129)
This is to ensure that __init__.py only contains public APIs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133129
Approved by: https://github.com/awgu, https://github.com/tianyu-l
2024-08-10 02:51:42 +00:00
Ke Wen
3d7f541597 [BE][TP] Check module has bias before access (#132137)
Some linear modules, such as the ones reconstructed by `torch.export.unflatten()`, may not have the `bias` attribute, if the original linear module has `bias=None`.
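
The fix amounts to guarding the attribute access, along these lines (a sketch; `shard_bias` is a hypothetical helper):

```python
# Modules reconstructed by torch.export.unflatten() may lack the `bias`
# attribute entirely (not merely have bias=None), so check existence first.
if getattr(module, "bias", None) is not None:
    shard_bias(module)  # hypothetical helper that distributes module.bias
```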

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132137
Approved by: https://github.com/wanchaol
2024-07-31 13:45:28 +00:00
Wanchao Liang
35a0e0f018 [tp] improve SequenceParallel and its documentation (#131346)
The SequenceParallel style assumes the input torch.Tensor is ALREADY sharded on
the sequence dimension when a plain tensor (not a DTensor) is passed in. Since
this caused some user confusion about the documentation, this PR:

1. for the case where the input passed in is already a DTensor, checks the
   input placements and redistributes if it's not sharded on the sequence
   dimension
2. updates the doc to be more explicit about the cases where the user
   passes in a torch.Tensor vs. a DTensor

This would fix https://github.com/pytorch/pytorch/issues/129355
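
The clarified contract, sketched:

```python
from torch.distributed.tensor.parallel import SequenceParallel

# Plain torch.Tensor input: assumed ALREADY sharded on the sequence dim
# on each rank. DTensor input: per this PR, placements are checked and
# redistributed to sequence-dim sharding if they differ.
plan = {"norm": SequenceParallel()}
```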

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131346
Approved by: https://github.com/awgu
2024-07-23 03:57:01 +00:00
Xuehai Pan
cec31050b4 [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868)
Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128868
Approved by: https://github.com/fegin
2024-06-18 21:49:02 +00:00
Wanchao Liang
7775fee10f [tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)
As titled, this PR refactors the PrepareModuleInput style to have a common
method prepare_input_arg, allowing both args and kwargs to reuse this logic.

This also fixes https://github.com/pytorch/pytorch/issues/128365
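
Conceptually, the refactor routes every positional and keyword input through one helper; a hedged sketch of the shape of the logic (not the landed code):

```python
from torch.distributed.tensor import DTensor

def prepare_input_arg(inp, device_mesh, layout):
    if layout is None:  # un-annotated input: pass through untouched
        return inp
    if isinstance(inp, DTensor):
        # the fix for #128365: redistribute an existing DTensor instead
        # of re-wrapping it with from_local
        return inp.redistribute(placements=(layout,))
    return DTensor.from_local(inp, device_mesh, (layout,), run_check=False)
```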

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431
Approved by: https://github.com/awgu
2024-06-12 19:16:33 +00:00
PyTorch MergeBot
a421699998 Revert "[tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)"
This reverts commit 089f9a116a.

Reverted https://github.com/pytorch/pytorch/pull/128431 on behalf of https://github.com/DanilBaibak due to Sorry for the revert. Your changes broke the linter. Here you can find more details - 089f9a116a ([comment](https://github.com/pytorch/pytorch/pull/128431#issuecomment-2162197858))
2024-06-12 06:25:53 +00:00
Wanchao Liang
089f9a116a [tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)
As titled, this PR refactors the PrepareModuleInput style to have a common
method prepare_input_arg, allowing both args and kwargs to reuse this logic.

This also fixes https://github.com/pytorch/pytorch/issues/128365

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431
Approved by: https://github.com/awgu
2024-06-12 05:22:24 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
Wanchao Liang
05addd5658 [tp] add kwargs support to prepare_module_input (#124114)
As titled, this PR adds kwargs support to the PrepareModuleInput style;
there might be modules that have only kwarg inputs and no
positional args, so we should support this.
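
Hedged example of the resulting API (kwarg names as landed; the module shape is made up):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# For a module invoked only as mod(x=..., mask=...) -- no positional args.
style = PrepareModuleInput(
    input_kwarg_layouts={"x": Shard(0)},             # how "x" arrives
    desired_input_kwarg_layouts={"x": Replicate()},  # what the module wants
)
```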

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124114
Approved by: https://github.com/XilunWu
2024-04-22 21:46:31 +00:00
Wanchao Liang
a26480a4d1 [dtensor] move early return check into redistribute autograd function (#121653)
This PR fixes a redistribute bug by moving the early-return check into the
redistribute autograd function: even when we redistribute to the
same placement, the grad_placements from the `to_local` call might be
different, so the redistribute backward still needs to happen.
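
The scenario, sketched with present-day import paths (`dt` and `mesh` assumed defined):

```python
from torch.distributed.tensor import Partial, Replicate

# dt is already replicated, so this redistribute is a no-op in forward...
out = dt.redistribute(mesh, [Replicate()])
# ...but grad_placements declares the incoming gradient is a partial sum,
# so the backward of redistribute still has work to do -- hence the
# early-return check must live inside the autograd function.
local = out.to_local(grad_placements=[Partial()])
```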

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121653
Approved by: https://github.com/awgu
2024-03-12 17:37:30 +00:00
Wanchao Liang
242e03ba86 [dtensor] add async_op option to redistribute and some refactor (#121477)
The async output option was only available in the `full_tensor()` call, but
it's generally good to make this option available in the `redistribute` call
directly so that the user can control it.

This PR adds an async_op option to the redistribute call, to let the user
control whether to perform tensor redistribution asynchronously or not.

By default we set this to False, to follow the semantics of the c10d
collectives.
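
Hedged usage sketch (`dt` assumed to be a DTensor):

```python
from torch.distributed.tensor import Replicate

# async_op=True lets the collective overlap with later compute; the wait
# happens lazily when the result is first used. The default False follows
# c10d collective semantics.
out = dt.redistribute(placements=[Replicate()], async_op=True)
```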

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121477
Approved by: https://github.com/wz337
2024-03-09 06:17:23 +00:00
Wanchao Liang
30982ce072 [tp] doc fixes (#121431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121431
Approved by: https://github.com/wz337
2024-03-08 17:46:44 +00:00
Wanchao Liang
1a28ebffb3 [TP] Introduce Sequence Parallel Style for Laynorm/RMSNorm/Dropout (#121295)
As titled, this PR introduces a dedicated `ParallelStyle` to shard the
nn.LayerNorm/nn.Dropout/RMSNorm layers. We were mainly using manual
distribute_module calls before when sharding the RMSNorm layer, but I
think we should have a dedicated TP API to easily shard those layers,
instead of users manually using DTensors.

I call this SequenceParallel, which might cause some confusion, since we
technically "deprecated" a SequenceParallel style months ago. But this
time the SequenceParallel style is significantly different from the
previous one (which used to shard two consecutive Linear layers). I
believe making the name right is the first priority, rather than
worrying about reusing the old name.
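
An illustrative plan using the new style (module names are made up):

```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
)

# Norm/dropout layers get sharded on the sequence dim without any manual
# distribute_module calls.
plan = {
    "attention_norm": SequenceParallel(),
    "ffn_norm": SequenceParallel(),
    "attention.wq": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
}
```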

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121295
Approved by: https://github.com/awgu, https://github.com/tianyu-l
ghstack dependencies: #121294
2024-03-07 02:04:59 +00:00
Wanchao Liang
2e50566722 [dtensor] change distribute_module input/output_fn to accept module (#120895)
This is a BC-breaking change to distribute_module. The underlying rationale
for this change is that sometimes in the input_fn/output_fn, the user would
want access to the current module for some attributes. This might not be
common, but in some cases it's worth having access to the module.

An outstanding use case we want to support is float8: if we want to make
float8 work with the TP API, the input_fn/output_fn of the TP parallel
styles would need access to the module, where the module might
encapsulate a `dynamic_linear.emulate` attribute that is useful for
input/output casting.

Since this is needed for fp8 and DTensor is still under prototype release,
I feel the change is worth it, and it's better we make it as early as
possible.

Right now this is a soft BC break, which means we still maintain BC but
throw deprecation messages.
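
After this change, the hooks take the module as their first argument; a minimal sketch with present-day import paths:

```python
from torch.distributed.tensor import DTensor, Replicate, distribute_module

def input_fn(mod, inputs, device_mesh):
    # `mod` is now available, so per-module attributes (the fp8 use case
    # above) can be consulted here before wrapping/casting.
    return tuple(
        DTensor.from_local(t, device_mesh, (Replicate(),)) for t in inputs
    )

def output_fn(mod, outputs, device_mesh):
    return outputs.to_local()

# sharded = distribute_module(module, mesh, input_fn=input_fn, output_fn=output_fn)
```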

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120895
Approved by: https://github.com/tianyu-l
2024-03-04 07:22:32 +00:00
Wanchao Liang
e696fa1ee7 [tp] enable rowwise embedding sharding in RowwiseParallel (#118242)
As titled, this PR enables rowwise embedding sharding in the
RowwiseParallel style, and adds tests to ensure it's working as expected
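
Hedged example (module name made up): rowwise sharding of an nn.Embedding shards the table on the vocabulary dim, with replicated token ids as input:

```python
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import RowwiseParallel

# Embedding weight sharded on dim 0 (vocab); each rank's partial lookup
# is reduced to form the full embedding output.
plan = {"tok_embeddings": RowwiseParallel(input_layouts=Replicate())}
```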

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118242
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079, #118080
2024-01-26 19:01:24 +00:00
PyTorch MergeBot
bc67f87559 Revert "[tp] enable rowwise embedding sharding in RowwiseParallel (#118242)"
This reverts commit 7a9012d7e8.

Reverted https://github.com/pytorch/pytorch/pull/118242 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/118079#issuecomment-1911681293))
2024-01-26 08:47:14 +00:00
Wanchao Liang
7a9012d7e8 [tp] enable rowwise embedding sharding in RowwiseParallel (#118242)
As titled, this PR enables rowwise embedding sharding in the
RowwiseParallel style, and adds tests to ensure it's working as expected

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118242
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079, #118080
2024-01-26 01:36:24 +00:00
Wanchao Liang
2bb2cc0b71 [tp] add clarification to doc and improve TP examples (#117618)
This PR adds a clarification about the evenly-sharded assumption to the main
TP doc and improves the TP examples by adding device mesh constructions

fixes https://github.com/pytorch/pytorch/issues/100044
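
The examples now construct the mesh explicitly, along these lines:

```python
from torch.distributed.device_mesh import init_device_mesh

# Evenly-sharded assumption: each sharded dim must divide evenly by the
# mesh size (8 here).
tp_mesh = init_device_mesh("cuda", (8,))
# parallelize_module(model, tp_mesh, plan)  # model/plan assumed defined
```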

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117618
Approved by: https://github.com/wconstab, https://github.com/awgu
2024-01-22 18:56:50 +00:00
Wanchao Liang
b10cb168a7 [tp] disable some assertion temporarily for torch.compile (#116573)
Disable some runtime assertions for now, as they do not work properly with
torch.compile; I'll have a follow-up fix in dynamo and then re-enable
this check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116573
Approved by: https://github.com/awgu, https://github.com/XilunWu
ghstack dependencies: #116426, #116559
2024-01-03 23:01:19 +00:00
Carlos Mocholí
9df4ee8d38 Fix ColwiseParallel typo (#116151)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116151
Approved by: https://github.com/wanchaol
2023-12-20 06:40:32 +00:00
Tianyu Liu
2a5659a797 add length assertion to PrepareModuleInput and PrepareModuleOutput (#115957)
## summary

`zip(inputs, self.input_layouts, self.desired_input_layouts)` is used in `_prepare_input_fn`; similar for `_prepare_output_fn`. Without an assertion, unmatched dimensions in inputs/outputs will be silently dropped, potentially causing unexpected behaviors.
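
The hazard is plain `zip` semantics; a minimal illustration:

```python
inputs = ("x", "y", "z")       # three runtime inputs...
input_layouts = ("Shard(0)",)  # ...but only one annotated layout

# zip stops at the shortest iterable, silently dropping "y" and "z",
# so those inputs would never be converted -- hence the new assertion.
assert list(zip(inputs, input_layouts)) == [("x", "Shard(0)")]
```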

## test plan
`python test/distributed/tensor/parallel/test_tp_style.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115957
Approved by: https://github.com/wanchaol
2023-12-18 21:50:18 +00:00
Wanchao Liang
a1a0b290d2 [tp] further fix the docs (#115974)
A typo resulted in the note section not being rendered properly; I couldn't
see this in the last PR directly, as the last PR only showed the first
commit's documentation :(

Also makes the parallelize_module doc example more concrete

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115974
Approved by: https://github.com/wz337
2023-12-18 20:41:53 +00:00
Wanchao Liang
61abacf829 [tp] improve documentation (#115880)
Improve the TP documentation in terms of format and descriptions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115880
Approved by: https://github.com/XilunWu
2023-12-15 18:44:22 +00:00
Wanchao Liang
28925902fa [TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs
are supposed to be a very thin wrapper over the DTensor APIs, but the current
implementation got too messy and buggy. It's really hard to debug what
went wrong when using it. It's crucially important for advanced users or
developers to understand the API and its implementation easily, without
going through all the different types of functions and utils, so that
they can trust what happens under the hood.

In particular this PR:

* Make ParallelStyle a real contract API for parallelize_module to
  take; each concrete ParallelStyle only needs to implement `apply` to
  apply the sharding to an nn.Module, with all unnecessary fields removed.
  This also enables easier ParallelStyle authoring going forward (see the
  sketch after the TODOs below).
* Keep the ColwiseParallel and RowwiseParallel public interface, but
  refactor them so that the parameter sharding and the input/output
  handling live within the style itself, making it easy to
  understand how Linear/Embedding layers are sharded and how the
  input/output transformations are performed.
* Remove all the private _prepare_input/_prepare_output_fn fields for
  both ColwiseParallel/RowwiseParallel. Since we threw deprecation
  messages in nightly for a while, TP is on prototype release, and the
  fields are private anyway, it should be safe to remove them.
* Refactor the recently landed PrepareModuleInput/Output styles: change
  output_layouts to desired_input/output_layouts, group
  the functions inside the style itself, and drop default arguments for
  these two styles so that users have to think about the sharding
  layouts. Fixed bugs about not handling the
  `use_local_output` flag.
* Make default arguments None instead of Placement objects; it is
  standard Python practice not to have a custom object instance as a
  default argument.
* Remove all dead APIs (i.e. the PairwiseParallel and SequenceParallel
  styles, and all prepare input/output functions), as we threw deprecation
  messages for a while and are in the process of removing all of them from
  the tests.
* Throw a deprecation warning for `tp_mesh_dim`, as we recommend using
  device mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrite the documentation for every ParallelStyle to make it
  clearer what each style is doing.

TODOs:
* Rewrite TP tests to adjust for the changes we have in this PR
* add more tests to guard the bug fixes
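
Under this contract, authoring a custom style reduces to one method. A minimal sketch with present-day import paths (the landed hook is the private `_apply(module, device_mesh)`; the body here is illustrative):

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard, distribute_module, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle

class ShardEverythingRowwise(ParallelStyle):
    """Illustrative custom style: shard every parameter on dim 0."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        def partition_fn(name, mod, mesh):
            for pname, param in list(mod.named_parameters(recurse=False)):
                mod.register_parameter(
                    pname, nn.Parameter(distribute_tensor(param, mesh, [Shard(0)]))
                )

        return distribute_module(module, device_mesh, partition_fn)
```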

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-12-02 08:18:12 +00:00
NVS Abhilash
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
Wanchao Liang
033680c9af [tp] fix PrepareModuleInput for multiple inputs (#112204)
Not all inputs need sharding annotations and conversion to DTensors; if the
user annotates only one input and marks the rest as None, we should skip
creating DTensors for those inputs.
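
A hedged sketch of the behavior, using present-day kwarg names (this commit predates the #114732 rename):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# Only the first positional input becomes a DTensor; the None entries
# tell the style to pass the second input through unconverted.
style = PrepareModuleInput(
    input_layouts=(Shard(0), None),
    desired_input_layouts=(Replicate(), None),
)
```
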
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112204
Approved by: https://github.com/fduwjj
2023-10-27 05:08:05 +00:00
fduwjj
fdc29f58c6 [TP] Refactor style to make it work with torch.compile (#111625)
We are refactoring the parallel style to accomplish the following:
1. Further simplify the code logic to make it more readable for users.
2. Remove the tuple check so that we can work with dynamo for now. Ideally dynamo needs to support this case, and we will fix that in parallel.
3. Add tests for the newly added parallel style in UT and torch compile tests so that we can catch regressions due to code changes.
4. Move the placements early-return check into DTensor, since it is bypassed by dynamo.
5. Remove PairwiseParallelStyle from unit tests in favor of the new Col/Rowwise parallel styles.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111625
Approved by: https://github.com/wanchaol
2023-10-20 19:20:43 +00:00
Wanchao Liang
03e28bde2e [tp] fix torch compile regression (#111521)
The most recent refactor of TP
(https://github.com/pytorch/pytorch/pull/111160) breaks the torch compile
path, so this reverts the behavior by:
1. using the old default prepare_input/output
2. adding colwise/rowwise parallel tests instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111521
Approved by: https://github.com/fduwjj
2023-10-19 10:27:10 +00:00
Wanchao Liang
59281d5631 [tp] fix SP style regression (#111353)
Although we want to remove prepare_input/output, we should still keep
the old behavior for SequenceParallel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111353
Approved by: https://github.com/fduwjj
2023-10-16 17:18:17 +00:00
fduwjj
bfcd86955e [TP] Fix TP doc format to show examples correctly (#111346)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111346
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176, #111177
2023-10-16 06:15:10 +00:00
fduwjj
25a2845d78 [TP] Enable embedding sharding in TP API (#111177)
We see use cases where embedding sharding is also needed in the TP API, so we enabled it in the API, since DTensor already supports colwise embedding sharding.
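
Hedged sketch with present-day import paths (module name made up): colwise sharding of an embedding shards the table on the embedding dim:

```python
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel

# Embedding weight sharded on dim 1 (the embedding dim); output_layouts
# gathers each rank's slice back into a replicated output.
plan = {"embedding": ColwiseParallel(output_layouts=Replicate())}
```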

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
2023-10-15 11:49:56 +00:00
fduwjj
8085e08a84 [TP] Add prepareInput and output for input/output DTensor layout annotation in the parent module in TP API (#111166)
In some use cases, we found that users might want to annotate the input/output DTensor layouts for the parent module rather than for the submodule whose parameters are to be distributed, so we want to have these two classes for users to annotate input/output DTensor layouts; they register pre-forward/forward hooks on the TP-ized module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111166
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160
2023-10-14 15:37:52 +00:00
fduwjj
3a8b10e2da [TP] Refactor Parallel Style to make it more usable (#111160)
One thing we find challenging for users is the prepare_input and prepare_output concept: there are so many function names for users to select from that it is quite confusing. On the other hand, colwise and rowwise parallel always need the input and output to be in certain layouts, so we can simplify the logic here and make it more usable.

So we added three public attributes to the ParallelStyle, and the code logic looks like:

```python
from abc import ABC
from typing import Tuple, Union

from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor.parallel import parallelize_module


class ParallelStyle(ABC):
    """
    The parallel style the user wants the module or submodule to be
    parallelized with. We can add more in the future, but this seems
    sufficient for immediate needs. Users can extend this class to build
    their own parallel style with customized input/output preparations.
    """

    input_layouts: Union[Placement, Tuple[Placement, ...]]
    output_layouts: Union[Placement, Tuple[Placement, ...]]
    use_local: bool

    def __init__(self, input_layouts, output_layouts, use_local):
        self.input_layouts = input_layouts
        self.output_layouts = output_layouts
        self.use_local = use_local


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module. We assume the input to be a sharded
    DTensor and the output to be a replicated Tensor.
    """

    def __init__(self, input_layouts=Shard(-1), output_layouts=Replicate()):
        super().__init__(input_layouts, output_layouts, use_local=True)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a module. We assume the input to be a
    replicated DTensor and the output to be a sharded DTensor.
    """

    def __init__(self, input_layouts=Replicate(), output_layouts=Shard(-1)):
        super().__init__(input_layouts, output_layouts, use_local=True)


# For the case of sequence parallel, users just pass a different input
# layout, Shard(0) or Shard(1), instead of Replicate().


class PrepareModuleInput(ParallelStyle):
    """
    Only used to specify the input distribution spec for a module.
    """

    def __init__(self, input_layouts=Shard(0), output_layouts=Replicate()):
        super().__init__(input_layouts, output_layouts, use_local=False)


class PrepareModuleOutput(ParallelStyle):
    """
    Only used to specify the output distribution spec for a module.
    """

    def __init__(self, input_layouts=Replicate(), output_layouts=Shard(0)):
        super().__init__(input_layouts, output_layouts, use_local=True)


parallelize_plan = {
    "embedding": ColwiseParallel(output_layouts=Replicate()),
    "attn": PrepareModuleInput(),
    "attn.w1": ColwiseParallel(),
    "attn.w2": ColwiseParallel(),
    "attn.w3": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}

parallelize_module(
    module=block,  # this can be a submodule or module
    device_mesh=mesh["tp"],
    parallelize_plan=parallelize_plan,
)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol
2023-10-14 15:26:36 +00:00
fduwjj
3828cd4b79 [TP][EZ] Update doc for TP parallel style (#107819)
We need to update the docs for PairwiseParallel and SequenceParallel so that users don't get the wrong impression that these work for ``nn.Transformer``.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107819
Approved by: https://github.com/awgu, https://github.com/wanchaol
2023-08-24 00:13:52 +00:00
fduwjj
953aa6d90e [TP] Enable more generic attn in Tensor Parallelism (#100508)
To make TP more generic for the Attention module, we came up with this new col/rowwise parallel style.

Basically, the idea behind it is:
We only do DTensor ops for the col/rowwise-sharded part. For the rest of the ATen ops, we leave them to plain Tensor ops.

And we set this behavior as the default for the Colwise and Rowwise parallel styles. If people want to customize it, they can always pass in a different prepare_input or prepare_output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
2023-05-07 18:15:49 +00:00
fduwjj
89b1e67d0a [Tensor Parallel] Add a new Colwise Parallel style when Pairwise cannot directly used (#100137)
In some use cases, users cannot directly use `PairwiseParallelStyle` and might need to specify colwise and rowwise separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100137
Approved by: https://github.com/wz337
2023-04-28 03:27:51 +00:00