Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45018

Now that https://github.com/pytorch/pytorch/pull/44795 has landed, we can convert the bulk of our cpp tests to use gtest APIs. Eventually we'll want to get rid of our weird harness for cpp tests entirely in favor of using regular gtest everywhere. This PR demonstrates some of the benefits of this approach:
1. You don't need to register your test twice (once to define it, once in tests.h).
2. Consequently, it's easier to have many individual test cases. Failures can be reported independently (rather than having huge functions that test entire modules).
3. Some nicer testing APIs, notably test fixtures.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D23802297

Pulled By: suo

fbshipit-source-id: 774255da7716294ac573747dcd5e106e5fe3ac8f
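As a rough sketch of what points 1 and 3 buy (illustrative only, not the code landed in this PR; the TypeCheckTest fixture and the SizeMismatch case name are invented for the example), a converted check could look like:

#include <gtest/gtest.h>
#include <ATen/ATen.h>

// Fixture: shared setup (e.g. parsing the IR and building the interpreter)
// would run here before each TEST_F case.
class TypeCheckTest : public ::testing::Test {
 protected:
  void SetUp() override {}
};

// TEST_F both defines and registers the case, so there is no second
// registration in tests.h, and this scenario is reported independently of
// the other TypeCheck scenarios.
TEST_F(TypeCheckTest, SizeMismatch) {
  auto a = at::zeros({2, 2}, at::kFloat);
  auto b = at::ones({3, 3}, at::kFloat);
  EXPECT_NE(a.sizes(), b.sizes());
}

Under the old harness the same check lives as one block inside a large test function that also has to be listed in tests.h.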
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
|
|
#include <stdexcept>
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testTypeCheck() {
|
|
{
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%a.1 : Tensor,
|
|
%b.1 : Tensor):
|
|
%t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1)
|
|
return (%t0, %t1, %type_matched)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
Code function(graph, "");
|
|
InterpreterState interp(function);
|
|
{
|
|
// TypeCheck yields to true! Shape, grad and device matches.
|
|
auto a = at::zeros({2, 2}, at::kFloat);
|
|
auto b = at::ones({3, 3}, at::kFloat);
|
|
a.set_requires_grad(true);
|
|
a = a.to(at::kCPU);
|
|
std::vector<IValue> stack({a, b});
|
|
interp.run(stack);
|
|
ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
|
|
ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
|
|
ASSERT_TRUE(stack[2].toBool());
|
|
}
|
|
{
|
|
auto a = at::zeros({2, 2}, at::kFloat);
|
|
auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
|
|
a.set_requires_grad(true);
|
|
a = a.to(at::kCPU);
|
|
std::vector<IValue> stack({a, b});
|
|
interp.run(stack);
|
|
ASSERT_FALSE(stack[2].toBool());
|
|
}
|
|
{
|
|
auto a = at::zeros({2, 2}, at::kFloat);
|
|
auto b = at::ones({3, 3}, at::kFloat);
|
|
a = a.to(at::kCPU);
|
|
a.set_requires_grad(false); // Gradient mismatch
|
|
std::vector<IValue> stack({a, b});
|
|
interp.run(stack);
|
|
ASSERT_FALSE(stack[2].toBool());
|
|
}
|
|
{
|
|
auto a = at::zeros({2, 2}, at::kFloat);
|
|
auto b = at::ones({3, 3}, at::kFloat);
|
|
a = a.to(at::kCPU);
|
|
a.set_requires_grad(true);
|
|
a = a.to(at::kInt); // Scalar type mismatch
|
|
std::vector<IValue> stack({a, b});
|
|
interp.run(stack);
|
|
ASSERT_FALSE(stack[2].toBool());
|
|
}
|
|
{
|
|
auto a = at::zeros({2, 2}, at::kFloat);
|
|
auto b = at::ones({3, 3}, at::kFloat);
|
|
a.set_requires_grad(true);
|
|
a = a.to(at::kCUDA); // Device mismatch
|
|
std::vector<IValue> stack({a, b});
|
|
interp.run(stack);
|
|
ASSERT_FALSE(stack[2].toBool());
|
|
}
|
|
}
|
|
|
|
try { // Test empty Typecheck raises an internal assertion
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%a.1 : Tensor,
|
|
%b.1 : Tensor):
|
|
%type_matched : bool = prim::TypeCheck()
|
|
return (%type_matched)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
} catch (const std::exception& e) {
|
|
}
|
|
try { // Test for assertion if num_inputs + 1 != num_outputs
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%a.1 : Tensor,
|
|
%b.1 : Tensor):
|
|
%type_matched : bool = prim::TypeCheck(%a.1)
|
|
return (%type_matched)
|
|
)IR",
|
|
&*graph,
|
|
vmap);
|
|
} catch (const std::exception& e) {
|
|
}
|
|
}
|
|
void testInterp() {
|
|
constexpr int batch_size = 4;
|
|
constexpr int input_size = 256;
|
|
constexpr int seq_len = 32;
|
|
|
|
int hidden_size = 2 * input_size;
|
|
|
|
auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
|
|
|
|
auto lstm_g = build_lstm();
|
|
Code lstm_function(lstm_g, "");
|
|
InterpreterState lstm_interp(lstm_function);
|
|
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
|
|
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
|
|
|
|
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
|
|
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|