pytorch/test/cpp/api/parallel.cpp
Will Feng 595209bddc Fix bugs in torch::tensor constructor (#28523)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28523

New features:
1. Previously, `torch::tensor({true, false, true})` threw `"tensor_cpu" not implemented for 'Bool'`. After this PR, it produces the correct bool tensor, matching the Python API behavior.
2. Tensors with zero-size dimensions are now supported, e.g. `torch::tensor({{}, {}})` produces a tensor with sizes `{2, 0}`, matching the Python API behavior.
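
The shape rules this PR adopts can be illustrated with a small, self-contained sketch. An empty inner list contributes a zero-size dimension, each level of nesting adds one dimension, and a bare scalar is 0-dim. `ShapeOnlyContainer` is a hypothetical, simplified stand-in for `TensorDataContainer`, not the real class. It handles only `int` scalars, ignores values and dtypes, and records only the sizes implied by a nested braced-init-list. It relies on a standard C++ rule: an element written as `{}` selects the default constructor.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <vector>

// Hypothetical, shape-only stand-in for TensorDataContainer (illustration only).
class ShapeOnlyContainer {
 public:
  // An element written as `{}` value-initializes through this constructor:
  // an empty 1-dim list, i.e. sizes {0}.
  ShapeOnlyContainer() : sizes_{0} {}

  // A bare scalar is 0-dim, so it has no sizes at all.
  // Implicit on purpose, so that list elements like `1` convert automatically.
  ShapeOnlyContainer(int /*value*/) {}

  // A braced list adds one leading dimension on top of the (identical)
  // shape shared by all of its children.
  ShapeOnlyContainer(std::initializer_list<ShapeOnlyContainer> children) {
    sizes_.push_back(static_cast<int64_t>(children.size()));
    if (children.size() == 0) {
      return;
    }
    const std::vector<int64_t>& inner = children.begin()->sizes_;
    for (const ShapeOnlyContainer& child : children) {
      if (child.sizes_ != inner) {
        throw std::invalid_argument("nested lists must all have the same shape");
      }
    }
    sizes_.insert(sizes_.end(), inner.begin(), inner.end());
  }

  const std::vector<int64_t>& sizes() const { return sizes_; }

 private:
  std::vector<int64_t> sizes_;
};
```

With this sketch, `ShapeOnlyContainer c = {{}, {}};` yields sizes `{2, 0}`. `{{1}, {2}}` yields `{2, 1}`, and `ShapeOnlyContainer s = 1;` yields an empty size list (0-dim). A ragged list such as `{{1, 2}, {3}}` throws.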

BC-breaking bug fixes:
1. Previously, `torch::tensor({{1}, {2}})` produced a tensor of sizes `{2}`. After this PR, it produces a tensor of sizes `{2, 1}`, matching the Python API behavior.
2. Fixed semantics of `torch::tensor(1.1)`: it now returns a 0-dim tensor instead of a 1-dim tensor, matching the Python API behavior.
3. Previously, when a non-dtype `TensorOptions` was passed to the `torch::tensor` constructor, it always produced a tensor of dtype `float`. After this PR, the dtype of the result is deduced from the element type of the braced-init-list, matching the behavior of the no-options case.
```cpp
// Previously:
torch::tensor({1, 2, 3}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> float
torch::tensor({{1, 2, 3}}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> float
torch::tensor({1., 2., 3.}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> float
torch::tensor({{1., 2., 3.}}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> float

// Now:
torch::tensor({1, 2, 3}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> int
torch::tensor({{1, 2, 3}}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> int
torch::tensor({1., 2., 3.}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> double
torch::tensor({{1., 2., 3.}}, torch::TensorOptions(/*non-dtype-options*/)).dtype() -> double

// As comparison, currently:
torch::tensor({1, 2, 3}).dtype() -> int
torch::tensor({{1, 2, 3}}).dtype() -> int
torch::tensor({1., 2., 3.}).dtype() -> double
torch::tensor({{1., 2., 3.}}).dtype() -> double
```
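
The decision rule behind fix 3, and the new `Bool` support, can be modelled in isolation. Everything below is hypothetical. `Kind` stands in for the real scalar types, and `Options` models only whether a `TensorOptions` carries a dtype. `resolve_old`/`resolve_new` are illustrative names, and only flat lists are handled. This is a sketch of the rule described above, not the `torch::tensor` implementation.

```cpp
#include <cassert>
#include <initializer_list>
#include <optional>

// Hypothetical scalar kinds (stand-ins for the real dtypes).
enum class Kind { Bool, Int, Double, Float };

// Kind deduced from the element type of a braced-init-list.
template <typename T>
constexpr Kind kind_of();
template <>
constexpr Kind kind_of<bool>() { return Kind::Bool; }
template <>
constexpr Kind kind_of<int>() { return Kind::Int; }
template <>
constexpr Kind kind_of<double>() { return Kind::Double; }

// Hypothetical options: an empty `dtype` means "non-dtype options".
struct Options {
  std::optional<Kind> dtype;
};

// Old (buggy) rule: options without a dtype forced Float.
template <typename T>
Kind resolve_old(std::initializer_list<T> /*values*/, const Options& opts) {
  return opts.dtype.value_or(Kind::Float);
}

// New rule: an explicit dtype wins; otherwise use the deduced kind,
// exactly as in the no-options case.
template <typename T>
Kind resolve_new(std::initializer_list<T> /*values*/, const Options& opts) {
  return opts.dtype.value_or(kind_of<T>());
}
```

For example, `resolve_new({1, 2, 3}, Options{})` gives `Kind::Int`, where `resolve_old` gave `Kind::Float`. `resolve_new({true, false, true}, Options{})` gives `Kind::Bool`.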

Notes:
1. From now on, `at::tensor(scalar_value)` (which produces a 1-dim tensor) behaves differently from `torch::tensor(scalar_value)` (which produces a 0-dim tensor). I will fix the behavior of `at::tensor(scalar_value)` in a follow-up PR.
2. From now on, `at::tensor({1, 2, 3}, torch::TensorOptions(/*non-dtype-options*/))` (which produces a `float` tensor) behaves differently from `torch::tensor({1, 2, 3}, torch::TensorOptions(/*non-dtype-options*/))` (which produces an `int` tensor). I will fix this behavior of the `at::tensor` constructor in a follow-up PR.

Context for the changes in this PR:

The motivation comes from fixing the "`torch::tensor({{1}, {2}})` gives tensor of wrong sizes" bug. To fix it, I had to move the handling of `at::ArrayRef` and `std::vector` into `InitListTensor` (see below for why this is necessary) and rename `InitListTensor` to `TensorDataContainer`. After these changes, support for bool values comes out of the box without extra effort, and support for tensors with zero-size dimensions only requires adding a default constructor for `TensorDataContainer`, so I added both in this PR.

As for the semantic change of `torch::tensor(1.1)`, preserving the original wrong behavior would actually take more effort: we would need to check the sizes of the tensor converted from `TensorDataContainer` and reshape any scalar tensor to a 1-D tensor. I think preserving the original wrong behavior doesn't give us much value, and since the above changes naturally fix the problem, we should just start using the right behavior instead.

For the "constructor with non-dtype options behavior" fix, the code looks simpler and easier to reason about with the fix, so I included it in this PR.

--------

Why we need to move the handling of `at::ArrayRef` and `std::vector` into `TensorDataContainer`:

`torch::tensor({{1}, {2}})` can match this function overload:
`torch::tensor(at::ArrayRef<int> values)`, because `{1}` and `{2}` can be treated as
a list-initialization of an `int` value. However, this will produce a Tensor with sizes `{2}`,
but we actually want a Tensor with sizes `{2, 1}`. In order to avoid matching this function overload,
we removed the function overload and moved the ability to convert `at::ArrayRef<T>`
(and similarly `std::vector<T>`) into `TensorDataContainer`, and since for braced-init-list the
`TensorDataContainer(std::initializer_list<TensorDataContainer>)` constructor is always preferred over all other constructors, it will take the `std::initializer_list` path, and all is good.
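
The overload pitfall can be reproduced in plain C++. `sizes_via_flat_overload` is a hypothetical stand-in for the removed `torch::tensor(at::ArrayRef<int>)` overload, not PyTorch code. Its `std::vector<int>` parameter is assumed to accept braced-init-lists in the same way, because both types have an initializer-list constructor. A flat parameter can only describe one dimension, so the nesting of `{{1}, {2}}` is silently lost.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the removed flat overload: whatever it receives,
// it can only describe a 1-dim list, so it reports a single size.
std::vector<std::size_t> sizes_via_flat_overload(const std::vector<int>& values) {
  return {values.size()};
}
```

Calling `sizes_via_flat_overload({{1}, {2}})` compiles, because each `{1}` list-initializes one `int`. It returns `{2}` instead of the intended `{2, 1}`. Some compilers may only warn about braces around a scalar initializer here. This is why the overload had to be removed rather than left to compete with the `std::initializer_list<TensorDataContainer>` path.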

Test Plan: Imported from OSS

Differential Revision: D18234625

Pulled By: yf225

fbshipit-source-id: 0f3f6912e82e2117d2103e31b74e7e97baaa8693
2019-10-31 12:53:06 -07:00


#include <gtest/gtest.h>

#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/parallel/data_parallel.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/sgd.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

using namespace torch::autograd;
using namespace torch::nn;

struct ParallelTest : torch::test::SeedingFixture {};

TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  Scatter scatter(
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto output = scatter.apply({input});

  ASSERT_EQ(output.size(), 2);
  ASSERT_EQ(output[0].size(0), 5);
  ASSERT_EQ(output[1].size(0), 5);

  ASSERT_TRUE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
                  .allclose(input));

  torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1];
  sum.backward(torch::ones_like(sum));

  ASSERT_TRUE(input.grad().defined());
  ASSERT_TRUE(input.grad().device().is_cpu());
  ASSERT_EQ(input.grad().sum().item<int32_t>(), 10);
}

TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  Gather gather(torch::Device(torch::kCUDA, 1));

  auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto outputs = gather.apply({a, b});
  ASSERT_EQ(outputs.size(), 1);
  torch::Tensor output = outputs.front();

  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.device(), torch::Device(torch::kCUDA, 1));

  auto chunks = output.chunk(2);
  ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
  ASSERT_TRUE(chunks[1].allclose(b));

  output.backward(torch::ones_like(output));

  ASSERT_TRUE(a.grad().defined());
  ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0));
  ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(b.grad().defined());
  ASSERT_EQ(b.grad().device(), torch::Device(torch::kCUDA, 1));
  ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}

TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  auto replica1_parameters = replicas[0]->parameters();
  for (auto& parameter : replica1_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 0));
  }
  replicas[0]->to(torch::kCPU);
  ASSERT_EQ(replica1_parameters.size(), original_parameters.size());
  for (size_t i = 0; i < original_parameters.size(); ++i) {
    ASSERT_TRUE(replica1_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica1_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }

  auto replica2_parameters = replicas[1]->parameters();
  for (auto& parameter : replica2_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 1));
  }
  replicas[1]->to(torch::kCPU);
  ASSERT_EQ(replica2_parameters.size(), original_parameters.size());
  for (size_t i = 0; i < original_parameters.size(); ++i) {
    ASSERT_TRUE(replica2_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica2_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }
}

TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear a(3, 4);

  Linear b(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  b->to({torch::kCUDA, 0});

  Linear c(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  c->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {a, b, c};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto outputs = parallel::parallel_apply(modules, inputs);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cpu());

  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(outputs[1].to(torch::kCPU).allclose(outputs[0]));

  ASSERT_EQ(outputs[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}

TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
  std::vector<torch::Tensor> inputs = {
      torch::empty({}), torch::empty({}), torch::empty({})};
  std::vector<torch::Device> devices = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto outputs = parallel::parallel_apply(modules, inputs, devices);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cuda());
  ASSERT_EQ(outputs[0].device(), torch::Device(torch::kCUDA, 1));

  ASSERT_TRUE(outputs[1].device().is_cuda());
  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(outputs[2].device().is_cpu());
}

TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      throw std::runtime_error("Badness!");
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  ASSERT_THROWS_WITH(parallel::data_parallel(m, input), "Badness!");
}

TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      // The returned tensor should be on the output device.
      return torch::ones(3);
    }
  };
  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  {
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/torch::nullopt,
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
  {
    // Verify for the single-device case (where we don't scatter/gather).
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
}

TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor({input.device().index()});
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  auto output = parallel::data_parallel(m, input);

  const auto device_count = torch::cuda::device_count();
  ASSERT_EQ(output.numel(), device_count);
  for (size_t i = 0; i < device_count; ++i) {
    ASSERT_EQ(output[i].item<int32_t>(), i);
  }
}

TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    M() {
      reset();
    }

    void reset() override {
      conv = register_module(
          "conv",
          torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2)));
      fc = register_module("fc", torch::nn::Linear(8, 2));
    }

    torch::Tensor forward(torch::Tensor x) {
      x = conv->forward(x);
      x = torch::relu(x);
      x = x.view({-1, 8});
      x = fc->forward(x);
      return torch::log_softmax(x, /*dim=*/1);
    }

    torch::nn::Conv2d conv{nullptr};
    torch::nn::Linear fc{nullptr};
  };

  // prepare modules and inputs
  auto input = torch::ones({16, 2, 3, 3});
  auto input_dp = torch::ones({16, 2, 3, 3});
  auto model = std::make_shared<M>();
  auto model_dp = std::dynamic_pointer_cast<M>(model->clone());

  // run 3 training iterations
  for (int i = 0; i < 3; ++i) {
    input += i;
    input_dp += i;

    // non-parallel training
    torch::optim::SGD optim(
        model->parameters(), torch::optim::SGDOptions(0.1));
    auto output = model->forward(input);
    auto loss = torch::mse_loss(output, torch::zeros_like(output));
    loss.backward();
    optim.step();

    // data-parallel training
    torch::optim::SGD optim_dp(
        model_dp->parameters(), torch::optim::SGDOptions(0.1));
    auto output_dp = parallel::data_parallel(model_dp, input_dp);
    auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp));
    loss_dp.backward();
    optim_dp.step();

    // make sure that weights are the same
    model->to(torch::kCPU);
    model_dp->to(torch::kCPU);
    auto params = model->parameters();
    auto params_dp = model_dp->parameters();
    ASSERT_EQ(params.size(), params_dp.size());
    // Each iterator must be compared against the end of its own container.
    for (auto it = params.begin(), it_dp = params_dp.begin();
         it != params.end() && it_dp != params_dp.end();
         ++it, ++it_dp) {
      ASSERT_TRUE(torch::allclose(*it, *it_dp));
    }
  }
}