#include #include #include #include #include #include #include #include #include #include #include using Catch::StartsWith; using namespace torch::autograd; using namespace torch::nn; #ifdef USE_CUDA TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") { Scatter scatter( {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)}); auto input = torch::ones(10, torch::requires_grad(true)); auto output = scatter.apply({input}); REQUIRE(output.size() == 2); REQUIRE(output[0].size(0) == 5); REQUIRE(output[1].size(0) == 5); REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)}) .allclose(input)); auto sum = output[0].to({torch::kCUDA, 1}) + output[1]; sum.backward(); REQUIRE(input.grad().defined()); REQUIRE(input.grad().device().is_cpu()); REQUIRE(input.grad().sum().toCInt() == 10); } TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") { Gather gather(torch::Device(torch::kCUDA, 1)); auto a = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 0})); auto b = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 1})); auto outputs = gather.apply({a, b}); REQUIRE(outputs.size() == 1); auto& output = outputs.front(); REQUIRE(output.size(0) == 10); REQUIRE(output.device() == torch::Device(torch::kCUDA, 1)); auto chunks = output.chunk(2); REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a)); REQUIRE(chunks[1].allclose(b)); output.backward(); REQUIRE(a.grad().defined()); REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0)); REQUIRE(a.grad().sum().toCInt() == 5); REQUIRE(b.grad().defined()); REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1)); REQUIRE(b.grad().sum().toCInt() == 5); } TEST_CASE("Parallel/Replicate", "[multi-cuda]") { Linear linear(3, 4); auto replicas = parallel::replicate( linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)}); REQUIRE(replicas.size() == 2); auto original_parameters = linear->parameters(); auto replica1_parameters = replicas[0]->parameters(); for (auto& parameter : replica1_parameters) { REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0)); } replicas[0]->to(torch::kCPU); REQUIRE(replica1_parameters.size() == original_parameters.size()); for (size_t i = 0; i < original_parameters.size(); ++i) { REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i])); REQUIRE( replica1_parameters[i].data() != original_parameters[i].data()); } auto replica2_parameters = replicas[1]->parameters(); for (auto& parameter : replica2_parameters) { REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1)); } replicas[1]->to(torch::kCPU); REQUIRE(replica2_parameters.size() == original_parameters.size()); for (size_t i = 0; i < original_parameters.size(); ++i) { REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i])); REQUIRE( replica2_parameters[i].data() != original_parameters[i].data()); } } TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") { Linear a(3, 4); Linear b(std::static_pointer_cast(a->clone())); b->to({torch::kCUDA, 0}); Linear c(std::static_pointer_cast(a->clone())); c->to({torch::kCUDA, 1}); std::vector modules = {a, b, c}; std::vector inputs = { torch::ones({2, 3}), torch::ones({2, 3}, torch::device({torch::kCUDA, 0})), torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))}; auto outputs = parallel::parallel_apply(modules, inputs); REQUIRE(outputs.size() == 3); REQUIRE(outputs[0].device().is_cpu()); REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0)); REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0])); REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1)); REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0])); } TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") { struct M : torch::nn::Module { torch::Tensor forward(torch::Tensor input) { return torch::ones({5}, torch::dtype(torch::kInt32)); } }; std::vector> modules = { std::make_shared(), std::make_shared(), std::make_shared()}; std::vector inputs = { torch::empty({}), torch::empty({}), torch::empty({})}; std::vector devices = { {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}}; auto outputs = parallel::parallel_apply(modules, inputs, devices); REQUIRE(outputs.size() == 3); REQUIRE(outputs[0].device().is_cuda()); REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1)); REQUIRE(outputs[1].device().is_cuda()); REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0)); REQUIRE(outputs[2].device().is_cpu()); } TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") { struct M : torch::nn::Cloneable { void reset() override {} torch::Tensor forward(torch::Tensor input) { throw std::runtime_error("Badness!"); } }; auto m = std::make_shared(); auto input = torch::ones({10, 3}); REQUIRE_THROWS_WITH( parallel::data_parallel(m, input), StartsWith("Badness!")); } TEST_CASE( "Parallel/DataParallelPlacesTheOutputOnTheRequestedDevice", "[multi-cuda]") { struct M : torch::nn::Cloneable { void reset() override {} torch::Tensor forward(torch::Tensor input) { // Intermediate tensors should be on the replica's current device. intermediate_tensor = torch::rand(5); // The returned tensor should be on the output device. return torch::ones(3); } torch::Tensor intermediate_tensor; }; auto m = std::make_shared(); auto input = torch::ones({10, 3}); { auto output = parallel::data_parallel( m, input, /*devices=*/at::nullopt, /*output_device=*/torch::Device(torch::kCUDA, 1)); REQUIRE(output.defined()); REQUIRE(output.device().is_cuda()); REQUIRE(output.device().index() == 1); } { // Verify for the single-device case (where we don't scatter/gather). auto output = parallel::data_parallel( m, input, /*devices=*/std::vector{torch::Device(torch::kCUDA, 0)}, /*output_device=*/torch::Device(torch::kCUDA, 1)); REQUIRE(m->intermediate_tensor.defined()); REQUIRE(m->intermediate_tensor.device().is_cuda()); REQUIRE(m->intermediate_tensor.device().index() == 0); REQUIRE(output.defined()); REQUIRE(output.device().is_cuda()); REQUIRE(output.device().index() == 1); } } TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") { struct M : torch::nn::Cloneable { void reset() override {} torch::Tensor forward(torch::Tensor input) { return torch::tensor(torch::DefaultTensorOptions::get().device().index()); } }; auto m = std::make_shared(); auto input = torch::ones({10, 3}); auto output = parallel::data_parallel(m, input); const auto device_count = torch::cuda::device_count(); REQUIRE(output.numel() == device_count); for (size_t i = 0; i < device_count; ++i) { REQUIRE(output[i].toCInt() == i); } } #endif