pytorch/test/distributed/_pipeline/sync/test_dependency.py
Commit 06d50b5eb0 by Pritam Damania (2020-10-22): Pull in fairscale.nn.Pipe into PyTorch. (#44090)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44090

This is an initial commit pulling in the torchgpipe fork at
https://github.com/facebookresearch/fairscale.

The purpose of this commit is to just pull in the code and ensure all tests and
builds work fine. We will gradually modify this to match our intended API
mentioned in https://fb.quip.com/txurAV3zIFox#RPZACAfAKMq. Follow-up PRs will
address further changes needed on top of the initial commit.

We're pulling the code into the `torch.distributed._pipeline.sync` package. The
package is private on purpose, since there is a lot of work (e.g., docs and API
changes) that needs to land before we can officially support
this.
ghstack-source-id: 114864254

Test Plan:
1) waitforbuildbot
2) Ran all tests on my devgpu

Reviewed By: mrshenli

Differential Revision: D23493316

fbshipit-source-id: fe3c8b7dadeeb86abdc00e8a8652491b0b16743a

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import weakref

import pytest
import torch

from torch.distributed._pipeline.sync.dependency import Fork, Join, fork, join
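

# fork(x) returns (x, phony): phony is an empty tensor attached to x's
# autograd graph. join(y, phony) threads that phony into y's graph, creating
# an explicit backward-pass dependency: work upstream of fork() runs backward
# only after work downstream of join() has finished. The tests below exercise
# this ordering guarantee and its bookkeeping.

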
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_fork_join():
    logs = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, number, tensor):
            ctx.number = number
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            logs.append(ctx.number)
            return None, grad

    a = torch.rand(1, device="cpu", requires_grad=True)
    b = torch.rand(1, device="cuda", requires_grad=True)

    a = Log.apply(1, a)

    a, phony = fork(a)
    b = join(b, phony)

    b = Log.apply(2, b)

    b = b.to("cpu")

    (a + b).backward()

    # Log 2 is downstream of join() and Log 1 is upstream of fork(), so the
    # join()-side backward must run first.
    assert logs == [2, 1]
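

# With grad enabled, fork/join must record autograd history: they return new
# tensors whose grad_fn is the corresponding Fork/Join backward node.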
def test_fork_join_enable_grad():
    x = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        x2, p = fork(x)

    assert p.requires_grad
    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert p.requires_grad
    assert x.grad_fn.__class__ is Fork._backward_cls
    assert p.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        x2 = join(x, p)

    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert x.grad_fn.__class__ is Join._backward_cls
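

# Under no_grad, fork/join must be pass-throughs: they return the input
# tensor itself and never invoke Function.apply (enforced via monkeypatching).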
def test_fork_join_no_grad(monkeypatch):
    def do_not_apply(*args):
        raise AssertionError("Function.apply called")

    monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply)

    x = torch.rand(1, requires_grad=True)

    with torch.no_grad():
        x2, p = fork(x)

    assert not p.requires_grad
    assert x2 is x
    x = x2

    with torch.no_grad():
        x2 = join(x, p)

    assert x2 is x
    x = x2
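

# fork/join must not keep autograd state alive after backward: the weakref to
# F's context should be dead once the graph is freed.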
def test_fork_leak():
    leak = None

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            nonlocal leak
            leak = weakref.ref(ctx)
            return grad

    x = torch.rand(1, requires_grad=True)
    x = F.apply(x)
    x, phony = fork(x)
    x = join(x, phony)

    x.backward()
    del x, phony

    assert leak() is None
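

# If the forked tensor does not require grad, neither does the phony, and
# join() leaves the other branch's requires_grad untouched.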
def test_join_when_fork_not_requires_grad():
    x = torch.rand(2, 1)
    a, b = x.chunk(2)

    assert not a.requires_grad
    a, p = fork(a)
    assert not a.requires_grad
    assert not p.requires_grad

    assert not b.requires_grad
    b = join(b, p)
    assert not b.requires_grad
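

# If the forked tensor requires grad, the phony does too, and join()
# propagates requires_grad to the joined tensor.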
def test_join_when_fork_requires_grad():
    x = torch.rand(2, 1)
    a, b = x.chunk(2)

    a.requires_grad_()
    assert a.requires_grad
    a, p = fork(a)
    assert a.requires_grad
    assert p.requires_grad

    assert not b.requires_grad
    b = join(b, p)
    assert b.requires_grad
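

# A CPU-only sketch of the ordering that test_fork_join checks across
# CPU/CUDA above. This helper is illustrative and not part of the original
# suite; the names _fork_join_cpu_sketch and Mark are ours. Backward through
# the join()-side branch completes before backward through the fork()-side
# branch.
def _fork_join_cpu_sketch():
    order = []

    class Mark(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tag, tensor):
            ctx.tag = tag
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            order.append(ctx.tag)
            return None, grad

    x = torch.rand(1, requires_grad=True)
    y = torch.rand(1, requires_grad=True)

    x = Mark.apply("x", x)  # upstream of fork()
    x, phony = fork(x)
    y = join(y, phony)
    y = Mark.apply("y", y)  # downstream of join()

    (x + y).backward()
    assert order == ["y", "x"]  # join()'s branch backpropagates first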