mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
12 lines
301 B
C++
12 lines
301 B
C++
#include <catch.hpp>
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <cmath>
|
|
|
|
// Checks that a tensor built with explicit device options reports that same
// device back: the device type must be CUDA and the device index must be 1.
// NOTE(review): index 1 presumably means this needs at least two visible CUDA
// devices to pass -- confirm against how "[cuda]" tests are scheduled.
TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[cuda]") {
  auto values = at::tensor({1, 2, 3}, at::device({at::kCUDA, 1}));

  // Query the device once and check both of its components.
  const at::Device placement = values.device();
  REQUIRE(placement.type() == at::Device::Type::CUDA);
  REQUIRE(placement.index() == 1);
}
|