pytorch/torch/distributed/_pipeline/sync/phony.py
Pritam Damania 06d50b5eb0 Pull in fairscale.nn.Pipe into PyTorch. (#44090)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44090

This is an initial commit pulling in the torchgpipe fork at
https://github.com/facebookresearch/fairscale.

The purpose of this commit is to just pull in the code and ensure all tests and
builds work fine. We will slowly modify this to match our intended API
mentioned in https://fb.quip.com/txurAV3zIFox#RPZACAfAKMq. Follow-up PRs will
address further changes needed on top of the initial commit.

We're pulling the code into the `torch.distributed._pipeline.sync` package. The
package is private on purpose since there is a lot of work (ex: docs, API
changes etc.) that needs to go in before we can actually officially support
this.
ghstack-source-id: 114864254

Test Plan:
1) waitforbuildbot
2) Ran all tests on my devgpu

Reviewed By: mrshenli

Differential Revision: D23493316

fbshipit-source-id: fe3c8b7dadeeb86abdc00e8a8652491b0b16743a
2020-10-22 10:59:02 -07:00

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Provides phony for arbitrary dependency in a autograd graph."""
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from .stream import default_stream, use_stream

__all__: List[str] = []

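# Cache of phony tensors, keyed by (device, requires_grad).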
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}


def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
    """Gets a phony. A phony is a tensor without any data; it is useful for
    expressing an arbitrary dependency in an autograd graph because it does not
    require any gradient accumulation.

    .. note::

        Phonies for each device are cached. If an autograd function gets a phony
        internally, the phony must be detached before it is returned. Otherwise,
        the autograd engine will mutate the cached phony in-place::

            class Phonify(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    phony = get_phony(input.device, requires_grad=False)
                    return phony.detach()  # detach() is necessary.

    """
    key = (device, requires_grad)

    try:
        phony = _phonies[key]
    except KeyError:
        # Allocate the zero-sized phony on the device's default stream and
        # cache it for reuse.
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony
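

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): one way a phony can be
# used to express an ordering dependency between two branches of an autograd
# graph. ``MyFork`` and ``MyJoin`` are hypothetical names for this example;
# the package's real fork/join helpers live in other modules and may differ
# from this sketch.
# ---------------------------------------------------------------------------
class MyFork(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Return the input together with a phony. The phony is detached, as
        # required by the note in ``get_phony``'s docstring.
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()

    @staticmethod
    def backward(ctx, grad_input: Tensor, grad_phony: Tensor) -> Tensor:
        # Gradient flows only through the real input; the phony carries none.
        return grad_input


class MyJoin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor, phony: Tensor) -> Tensor:
        # The phony contributes no data; it only adds an autograd edge so that
        # the branch which produced it is ordered before this one in backward.
        return input.detach()

    @staticmethod
    def backward(ctx, grad_input: Tensor) -> Tuple[Tensor, None]:
        return grad_input, None


# Usage (also illustrative): make ``b``'s autograd branch depend on ``a``'s,
# without copying any data between them::
#
#     a, phony = MyFork.apply(a)
#     b = MyJoin.apply(b, phony)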