Commit Graph

20 Commits

Author SHA1 Message Date
Rohan Varma
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380
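For context, a flat (non-recursive) sketch of what a `_cast_forward_inputs`-style helper does, assuming it casts only floating-point tensors; the unified helper in torch/distributed/utils.py may differ in details:
```
from typing import Any, Dict, Optional, Tuple

import torch

def _cast_forward_inputs_sketch(
    dtype: Optional[torch.dtype], *args: Any, **kwargs: Any
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Cast floating-point tensors in args/kwargs to `dtype` (e.g. for mixed precision)."""
    if dtype is None:
        return args, kwargs

    def cast(x: Any) -> Any:
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    return tuple(cast(a) for a in args), {k: cast(v) for k, v in kwargs.items()}
```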

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
Matthew Hoffman
c28f8e314d Add type hints in torch/distributed/utils.py (#102262)
Fixes #77190

Pretty similar to the typing in `torch/nn/parallel`, which was also improved recently: #102194
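As an illustration of the style of hints added (the exact signatures live in torch/distributed/utils.py; this one is paraphrased and should be treated as an assumption):
```
from typing import Any, List, Tuple

def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """Flatten args and kwargs into one positional tuple plus the kwarg keys."""
    kwarg_keys: List[str] = list(kwargs.keys())
    flat_args: List[Any] = list(args) + [kwargs[k] for k in kwarg_keys]
    return tuple(flat_args), tuple(kwarg_keys)
```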

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102262
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze
2023-05-30 19:57:45 +00:00
Matthew Hoffman
0ed22fce97 Merge type stubs torch nn parallel (#102194)
Fixes merge issue for #101528

In the above PR, `torch.nn.parallel.parallel_apply.get_a_var` was marked private to appease the [public interface linter](https://github.com/pytorch/pytorch/actions/runs/4999216467/jobs/8955582204#step:14:21666): ceeb242bc7

This broke CI pipelines running external dependencies that expected `get_a_var`'s name to not change. In this PR, we change the name back to `get_a_var` and include it in the `__all__` instead.
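Roughly, the resolution looks like the following (paraphrased from torch/nn/parallel/parallel_apply.py; module contents abbreviated):
```
import torch

__all__ = ["get_a_var", "parallel_apply"]  # keep the public name exported

def get_a_var(obj):
    # Return the first tensor found in a (possibly nested) structure.
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
```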
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102194
Approved by: https://github.com/ezyang
2023-05-26 20:10:47 +00:00
Aaron Gokaslan
dfe484a3b3 [BE]: Bugfix functorch and some generic typing improvements (#101337)
Fixes some typing bugs found with newer versions of mypy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101337
Approved by: https://github.com/ezyang
2023-05-14 14:20:56 +00:00
Xing Liu
0731420645 [PyTorch/Distributed]Only sync buffers when broadcast_buffers is True (#100729)
Summary: Disable buffer sync in _sync_module_states(...) when broadcast_buffers is False. This change reduces memory usage when a model has huge buffers and does not need them broadcast.
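For reference, opting out of buffer broadcast is controlled by the existing DDP constructor flag (process-group setup elided; model is a placeholder):
```
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already run.
model = nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model, broadcast_buffers=False)  # buffers are no longer synced
```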

Test Plan: .

Differential Revision: D45610709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100729
Approved by: https://github.com/mrshenli
2023-05-08 16:34:29 +00:00
feifan
bd07f8d2e0 DDP forward support custom stream accelerated copy. (#98723)
At present, the DDP forward pass uses `_get_stream` to get a stream, which is a CUDA stream. If a custom device module is already registered with torch, we can use `getattr` to retrieve it and its stream; the custom stream is then used to copy the tensor, as sketched below.
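A minimal sketch of that lookup, assuming the custom backend module registered on `torch` exposes a `Stream` class with the same constructor as `torch.cuda.Stream`:
```
import torch

def _get_device_stream(device: torch.device):
    # Fetch the device module registered on torch (torch.cuda, or a
    # custom backend module) and create a stream on that device.
    device_mod = getattr(torch, device.type, None)
    if device_mod is not None and hasattr(device_mod, "Stream"):
        return device_mod.Stream(device=device.index)
    return None  # fall back to the default stream
```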
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98723
Approved by: https://github.com/ezyang
2023-04-14 20:19:56 +00:00
feifan
d95ee64b58 ddp forward support custom backend. (#98283)
Currently DDP only considers the CUDA backend: the DDP forward pass transfers tensors to CUDA. We want DDP to run on custom backends as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98283
Approved by: https://github.com/ezyang
2023-04-09 01:30:42 +00:00
Aaron Gokaslan
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions, optimizing the creation of dicts from iterables.
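A representative instance of the pattern (illustrative, not a specific hunk from the PR):
```
keys = ["a", "b"]
vals = [1, 2]

d = {k: v for k, v in zip(keys, vals)}  # before: comprehension adds no value
d = dict(zip(keys, vals))               # after: direct construction, faster and clearer
```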

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
Rohan Varma
c43e88665a [Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
2023-02-17 18:24:20 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplifies and speeds up isinstance calls by checking for multiple types in a single call.
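An illustrative (not PR-specific) instance:
```
x = 3.0

# Before: two separate calls
if isinstance(x, int) or isinstance(x, float):
    print("number")

# After: one call with a tuple of types
if isinstance(x, (int, float)):
    print("number")
```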

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Andrew Gu
ce7751188a [DDP] Add PackedSequence support when device_ids is specified (#86614)
Before this PR, if a user runs DDP with `device_ids` specified and with a `PackedSequence` input, then the execution will error with something like:
```
raise ValueError(
  ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by
 functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html...
```
This is because the DDP forward calls `_to_kwargs()`, which calls `_recursive_to()`, which moves the inputs to GPU. However, `_is_namedtuple(packed_sequence)` returns `True`, leading to the branch `return [type(obj)(*args) for args in zip(*map(to_map, obj))]`, which tries to construct a `PackedSequence` directly via `type(obj)(*args)`, leading to the error.

Repro for `_is_namedtuple(packed_sequence)` returning `True`:
```
import random

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel.scatter_gather import _is_namedtuple

def _ordered_sequence(tensor_type):
    seqs = [tensor_type(random.randint(1, 256))
            for _ in range(32)]
    seqs = [s.random_(-128, 128) for s in seqs]
    ordered = sorted(seqs, key=len, reverse=True)
    return ordered

def _padded_sequence(tensor_type):
    ordered = _ordered_sequence(tensor_type)
    lengths = [len(i) for i in ordered]
    padded_tensor = rnn_utils.pad_sequence(ordered)
    return padded_tensor, lengths

padded, lengths = _padded_sequence(torch.Tensor)
packed = rnn_utils.pack_padded_sequence(
    padded, lengths, enforce_sorted=False)
print(type(packed), packed.data.device)
print(_is_namedtuple(packed))
```
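One way to avoid the bad branch is to special-case `PackedSequence` before the namedtuple reconstruction, since its own `.to()` moves `.data` while leaving `batch_sizes` on CPU as the constructor requires. This is a hypothetical sketch, not necessarily the PR's actual fix:
```
import torch
from torch.nn.utils.rnn import PackedSequence

def _move_to_device(obj, device):
    # Hypothetical guard: never rebuild a PackedSequence via type(obj)(*args).
    if isinstance(obj, PackedSequence):
        return obj.to(device)
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj
```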

Test Plan:
```
python test/distributed/test_c10d_nccl.py -k test_ddp_packed_sequence
```
Without the fix, the added unit test fails with the expected error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86614
Approved by: https://github.com/rohan-varma
2022-10-10 21:50:59 +00:00
Rohan Varma
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when the reentrant checkpoint wrapper is used. This is needed to unblock some crucial workloads; the ideal fix would be to check this directly into torch.utils.checkpoint. A sketch of the workaround follows.
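A minimal sketch of the kwargs-packing idea, assuming names that may differ from the wrapper's actual patch:
```
import torch.utils.checkpoint as cp

def _checkpointed_forward(module, *args, **kwargs):
    # Reentrant checkpointing only replays positional args, so close over
    # the kwargs in a small wrapper function and checkpoint that instead.
    def _run(*inputs):
        return module(*inputs, **kwargs)
    return cp.checkpoint(_run, *args)
```
Note that tensors passed via kwargs are hidden from the checkpoint machinery here, which is part of why this is only a temporary patch.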

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
d2f37401b8 Silence namedtuple warning in dist (#84072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84072
Approved by: https://github.com/awgu
2022-08-26 00:24:28 +00:00
Rohan Varma
b29a074882 [BE] Revert distributed change in https://github.com/pytorch/pytorch/pull/68779 (#83181)
https://github.com/pytorch/pytorch/issues/82641 points out a regression in how inputs / outputs are processed by DDP, blocking their HF use case. It was narrowed down to https://github.com/pytorch/pytorch/pull/68779 and reverting the distributed change there fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83181
Approved by: https://github.com/kumpera
2022-08-23 02:38:23 +00:00
Rohan Varma
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state_dict / load_state_dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

This is because a training run can use activation checkpointing and then save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the checkpoint wrapping structure. To support this, we add hooks that remove / add the relevant prefix as needed.

Tests are added to ensure we can load a CheckpointWrapper-wrapped module's state into both a CheckpointWrapper module and a local module. state_dict with FSDP is also verified.
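A minimal sketch of the prefix fixup, assuming the wrapper stores the inner module under an attribute like `_checkpoint_wrapped_module` (hook signatures follow nn.Module's private hook-registration APIs):
```
_PREFIX = "_checkpoint_wrapped_module."

def _state_dict_hook(module, state_dict, prefix, local_metadata):
    # Strip the wrapper prefix so checkpoints load into plain modules.
    for key in sorted(state_dict.keys()):
        if _PREFIX in key:
            state_dict[key.replace(_PREFIX, "")] = state_dict.pop(key)
    return state_dict

def _load_state_dict_pre_hook(state_dict, prefix, *args):
    # Re-insert the wrapper prefix when loading into a wrapped module.
    for key in sorted(state_dict.keys()):
        state_dict[prefix + _PREFIX + key[len(prefix):]] = state_dict.pop(key)
```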
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rohan Varma
bbb1f106c7 Separate input moving to utils file
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77187

Test fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77235

Lint fix

Approved by: https://github.com/awgu
2022-05-11 21:55:38 +00:00
Rohan Varma
ffb0946504 Generalize param verification and broadcast
New PR for https://github.com/pytorch/pytorch/pull/75970 to be compatible with GHF.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76374
Approved by: https://github.com/awgu
2022-04-26 22:25:53 +00:00
Pritam Damania
b8e6144e0a Add a _RemoteDevice structure for ShardedTensor/ShardingSpec. (#62927)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62927

As part of the ShardedTensor work, we realized we do need some sort of
_RemoteDevice structure that deals with our format of "workername/device" so
that users don't have to worry about parsing this string directly.

Right now this structure is just the bare minimum and is mostly a container for
describing a remote device. It is currently only used in ShardedTensor,
ShardingSpec and RemoteModule.

Once we actually have a consolidated remote device proposal, this class can be
extended appropriately if needed.
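Illustrative usage (the class is spelled `_remote_device` in torch/distributed/remote_device.py today; the name and methods here are assumptions about this snapshot):
```
from torch.distributed.remote_device import _remote_device

rdev = _remote_device("worker0/cuda:0")
print(rdev.worker_name())  # "worker0"
print(rdev.device())       # device(type='cuda', index=0)
```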
ghstack-source-id: 135534086

Test Plan:
1) unit tests
2) waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30170689

fbshipit-source-id: 1ac2e81c7a597dc40bf3fbf2c1168c382c66649f
2021-08-11 11:27:32 -07:00
Pritam Damania
0d6fa1adc5 Introduce ChunkShardingSpec as a model sharding specification. (#55728)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728

Full design: https://github.com/pytorch/pytorch/issues/55207

This PR introduces ChunkShardingSpec (SingleShardingSpec in the design). We used the name ChunkShardingSpec since the spec is very similar to `torch.chunk` in terms of how a Tensor is split up, and it feels clearer than SingleShardingSpec.
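Illustrative usage, with an import path that is an assumption (the spec has lived under both torch.distributed._sharding_spec and torch.distributed._shard.sharding_spec):
```
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,  # split along dim 0, torch.chunk-style
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
    ],
)
```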
ghstack-source-id: 129603318

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D27694108

fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49
2021-05-23 16:04:57 -07:00