pytorch/test/cpp/api/parallel.cpp
Shen Li b7b6b612a7 Fix C++ data parallel (#20910)
Summary:
Fixes #19540

CC nmerrill67

C++ data parallel was using Module.clone() to create module replicas on every destination device. However, clone() does not set up gradient edges to point from replicas to the original module. As a result, the gradient will not be aggregated into the original module. This commit fixes the problem by manually setting gradient edges from every parameter X in every replica to the same parameter X in the original module.

## Failed Attempt

Initially I tried implementing what we did in `replicate.py`, which
1. create module replicas
2. use Python `Broadcast` autograd function to broadcast every parameter in the original module to all destination devices.
3. assign the broadcast result params to module replicas' `_parameters` dict.

This works in Python because derived module member field params (e.g., `Linear.weight`) and base module `_parameters` (e.g., `Linear._parameters['weight']`) reference the same parameter instance. Assigning to one of them applies to both. However, in C++, even though I can modify Module's `parameters_` values and gradient edges to point to the broadcast source, I cannot touch the weight and bias member fields in Linear, because replicate cannot (and should not) add special-case handlers to every different module. (See `Linear` [.h](https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/modules/linear.h), [.cpp](https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/nn/modules/linear.cpp)) Although they initially point to the same `TensorImpl` instance, after assigning to `Module.parameters_['weight']`, that entry will be different from `Linear.weight`.

## Solution Options

gchanan and I had several discussions on this issue and came up with two solutions to this problem.

### Option One [implemented in this PR]

Replicate the module in two steps:
1. call `Module.clone()` to create a module replica on every destination device.
2. manually set gradient edges from every parameter in every replica to the same parameter in the original module.

* Pro: Does not need to change any existing module, and relatively easier to implement
* Con: It is a little hackish.

### Option Two

Implement a `Replicatable` class (similar to `Cloneable`), and make it a friend class of `Module`. For more details see `Note [Replicating Modules]` in the code change.

* Pro: Maybe this aligns more with our existing approach implemented in `Cloneable`?
* Con: Require changes to every existing module.

I am inclined to go with option one, because `replicate` will only be used for data parallel. I feel it would be overkill to change all existing module implementations just to satisfy a data parallel requirement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20910

Differential Revision: D15556426

Pulled By: mrshenli

fbshipit-source-id: aa836290ec657b32742e2bea80bd0ac2404ef3b0
2019-06-06 11:57:31 -07:00

294 lines
9.2 KiB
C++

#include <gtest/gtest.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/parallel/data_parallel.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/sgd.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <test/cpp/api/support.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
// Bring the torch::autograd and torch::nn namespaces into scope so the tests
// below can use unqualified names such as Linear, LinearImpl and parallel::
// (Scatter and Gather are presumably the torch::autograd functions declared in
// torch/csrc/autograd/functions/comm.h -- confirm against that header).
using namespace torch::autograd;
using namespace torch::nn;
// Fixture used by the TEST_F cases below. It adds no members of its own; all
// of its behavior is inherited from torch::test::SeedingFixture (presumably
// it seeds the random number generator for each test -- confirm in
// test/cpp/api/support.h).
struct ParallelTest : torch::test::SeedingFixture {};
// Scatters a tensor of ten ones across CUDA devices 0 and 1, verifies the two
// resulting shards, then backpropagates through the scatter and verifies that
// a gradient is accumulated on the original input.
TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  Scatter scatter(
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});

  auto source = torch::ones(10, torch::requires_grad(true));
  auto shards = scatter.apply({source});

  // Expect two shards holding five elements each.
  ASSERT_EQ(shards.size(), 2);
  ASSERT_EQ(shards[0].size(0), 5);
  ASSERT_EQ(shards[1].size(0), 5);

  // Concatenating the shards on the CPU must reproduce the input.
  auto reassembled =
      torch::cat({shards[0].to(torch::kCPU), shards[1].to(torch::kCPU)});
  ASSERT_TRUE(reassembled.allclose(source));

  // Move the first shard next to the second one so they can be added, then
  // run backward through the sum.
  torch::Tensor total = shards[0].to({torch::kCUDA, 1}) + shards[1];
  total.backward();

  // The gradient must exist, live on the CPU and sum to the element count.
  ASSERT_TRUE(source.grad().defined());
  ASSERT_TRUE(source.grad().device().is_cpu());
  ASSERT_EQ(source.grad().sum().item<int32_t>(), 10);
}
// Gathers two five-element tensors (one per CUDA device) onto CUDA device 1,
// verifies the gathered result, then backpropagates and verifies that each
// input received a gradient on its own device.
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  const torch::Device first_device(torch::kCUDA, 0);
  const torch::Device second_device(torch::kCUDA, 1);

  Gather gather(second_device);

  auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto gathered = gather.apply({a, b});
  ASSERT_EQ(gathered.size(), 1);

  // The single result holds both inputs back to back on the target device.
  torch::Tensor merged = gathered.front();
  ASSERT_EQ(merged.size(0), 10);
  ASSERT_EQ(merged.device(), second_device);

  // Each half of the result must match the corresponding input.
  auto halves = merged.chunk(2);
  ASSERT_TRUE(halves[0].to({torch::kCUDA, 0}).allclose(a));
  ASSERT_TRUE(halves[1].allclose(b));

  merged.backward();

  // Gradients flow back to both inputs and stay on the inputs' devices.
  ASSERT_TRUE(a.grad().defined());
  ASSERT_EQ(a.grad().device(), first_device);
  ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(b.grad().defined());
  ASSERT_EQ(b.grad().device(), second_device);
  ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}
// Replicates a Linear module onto CUDA devices 0 and 1 and verifies, for each
// replica in turn, that its parameters sit on the expected device, that they
// equal the original's values, and that they do not share the original's
// storage.
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  const std::vector<torch::Device> devices = {
      torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};

  auto replicas = parallel::replicate(linear, devices);
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  for (size_t r = 0; r < devices.size(); ++r) {
    auto replica_parameters = replicas[r]->parameters();

    // Every parameter of replica r must live on device r.
    for (auto& parameter : replica_parameters) {
      ASSERT_EQ(parameter.device(), devices[r]);
    }

    // Move the replica to the CPU so it can be compared with the original.
    replicas[r]->to(torch::kCPU);
    ASSERT_EQ(replica_parameters.size(), original_parameters.size());

    for (size_t i = 0; i < original_parameters.size(); ++i) {
      // Same values ...
      ASSERT_TRUE(replica_parameters[i].allclose(original_parameters[i]));
      // ... but a different underlying buffer.
      ASSERT_TRUE(
          replica_parameters[i].data<float>() !=
          original_parameters[i].data<float>());
    }
  }
}
// Runs parallel_apply over three copies of the same Linear module (one left
// as constructed, one moved to CUDA:0, one moved to CUDA:1) and verifies that
// each output stays on its module's device and that all outputs agree.
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear base_linear(3, 4);

  Linear cuda0_linear(
      std::dynamic_pointer_cast<LinearImpl>(base_linear->clone()));
  cuda0_linear->to({torch::kCUDA, 0});

  Linear cuda1_linear(
      std::dynamic_pointer_cast<LinearImpl>(base_linear->clone()));
  cuda1_linear->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {base_linear, cuda0_linear, cuda1_linear};

  // One all-ones input per module, placed on the matching device.
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto results = parallel::parallel_apply(modules, inputs);
  ASSERT_EQ(results.size(), 3);

  // The first output is the CPU reference the CUDA outputs are compared to.
  ASSERT_TRUE(results[0].device().is_cpu());

  ASSERT_EQ(results[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(results[1].to(torch::kCPU).allclose(results[0]));

  ASSERT_EQ(results[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(results[2].to(torch::kCPU).allclose(results[0]));
}
// Calls parallel_apply with an explicit device list and verifies that each
// output ends up on the device listed at the same position, even though the
// module itself builds its result without specifying a device.
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  // Module that ignores its input and returns a fixed int32 tensor of ones.
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
  std::vector<torch::Tensor> inputs = {
      torch::empty({}), torch::empty({}), torch::empty({})};
  std::vector<torch::Device> devices = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto results = parallel::parallel_apply(modules, inputs, devices);
  ASSERT_EQ(results.size(), 3);

  // Position 0 was assigned CUDA:1.
  ASSERT_TRUE(results[0].device().is_cuda());
  ASSERT_EQ(results[0].device(), torch::Device(torch::kCUDA, 1));

  // Position 1 was assigned CUDA:0.
  ASSERT_TRUE(results[1].device().is_cuda());
  ASSERT_EQ(results[1].device(), torch::Device(torch::kCUDA, 0));

  // Position 2 was assigned the CPU.
  ASSERT_TRUE(results[2].device().is_cpu());
}
// Verifies that an exception thrown from a module's forward() surfaces from
// data_parallel() with its original message.
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  // Module whose forward() always throws.
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      throw std::runtime_error("Badness!");
    }
  };

  auto failing_module = std::make_shared<M>();
  auto batch = torch::ones({10, 3});

  ASSERT_THROWS_WITH(
      parallel::data_parallel(failing_module, batch), "Badness!");
}
// Verifies that data_parallel() returns its result on the requested output
// device, both when the device list is left unspecified and when a single
// device is given explicitly.
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      // The returned tensor should be on the output device.
      return torch::ones(3);
    }
  };

  auto module = std::make_shared<M>();
  auto batch = torch::ones({10, 3});
  const torch::Device requested_device(torch::kCUDA, 1);

  {
    // No explicit device list.
    auto result = parallel::data_parallel(
        module,
        batch,
        /*devices=*/torch::nullopt,
        /*output_device=*/requested_device);
    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }

  {
    // Verify for the single-device case (where we don't scatter/gather).
    const std::vector<torch::Device> single_device = {
        torch::Device(torch::kCUDA, 0)};
    auto result = parallel::data_parallel(
        module,
        batch,
        /*devices=*/single_device,
        /*output_device=*/requested_device);
    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }
}
// Verifies that data_parallel(), called without a device list, produces one
// output element per CUDA device and that element i carries device index i.
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  // Module that reports the index of the device its input arrived on.
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor(input.device().index());
    }
  };

  auto module = std::make_shared<M>();
  auto batch = torch::ones({10, 3});

  auto device_indices = parallel::data_parallel(module, batch);

  const auto device_count = torch::cuda::device_count();
  ASSERT_EQ(device_indices.numel(), device_count);

  for (size_t i = 0; i < device_count; ++i) {
    ASSERT_EQ(device_indices[i].item<int32_t>(), i);
  }
}
// Trains two identical copies of a small conv + linear model for three
// iterations -- one with a plain forward pass and one through
// data_parallel() -- and verifies after every iteration that the parameters
// of both copies are still numerically close.
TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    M() {
      reset();
    }

    // (Re)creates and registers the submodules.
    void reset() override {
      conv = register_module("conv",
          torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2)));
      fc = register_module("fc", torch::nn::Linear(8, 2));
    }

    // conv -> relu -> flatten to rows of 8 -> fc -> log_softmax over dim 1.
    torch::Tensor forward(torch::Tensor x) {
      x = conv->forward(x);
      x = torch::relu(x);
      x = x.view({-1, 8});
      x = fc->forward(x);
      return torch::log_softmax(x, /*dim=*/1);
    }

    torch::nn::Conv2d conv{nullptr};
    torch::nn::Linear fc{nullptr};
  };

  // prepare modules and inputs
  auto input = torch::ones({16, 2, 3, 3});
  auto input_dp = torch::ones({16, 2, 3, 3});
  auto model = std::make_shared<M>();
  auto model_dp = std::dynamic_pointer_cast<M>(model->clone());

  // run 3 training iterations
  for (int i = 0; i < 3; ++i) {
    input += i;
    input_dp += i;

    // non-parallel training
    torch::optim::SGD optim(
        model->parameters(), torch::optim::SGDOptions(0.1));
    auto output = model->forward(input);
    auto loss = torch::mse_loss(output, torch::zeros_like(output));
    loss.backward();
    optim.step();

    // data-parallel training
    torch::optim::SGD optim_dp(
        model_dp->parameters(), torch::optim::SGDOptions(0.1));
    auto output_dp = parallel::data_parallel(model_dp, input_dp);
    auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp));
    loss_dp.backward();
    optim_dp.step();

    // make sure that weights are the same
    model->to(torch::kCPU);
    model_dp->to(torch::kCPU);
    auto params = model->parameters();
    auto params_dp = model_dp->parameters();
    ASSERT_EQ(params.size(), params_dp.size());
    // Each iterator is bounded by the end of its OWN container; comparing
    // it_dp against params.end() would mix iterators from two different
    // vectors, which is undefined behavior.
    for (auto it = params.begin(), it_dp = params_dp.begin();
         it != params.end() && it_dp != params_dp.end();
         ++it, ++it_dp) {
      ASSERT_TRUE(torch::allclose(*it, *it_dp));
    }
  }
}