mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
port all JIT tests to gtest (#45264)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45264 Context for why we are porting to gtest in: https://github.com/pytorch/pytorch/pull/45018. This PR completes the process of porting and removes unused files/macros. Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D23901392 Pulled By: suo fbshipit-source-id: 89526890e1a49462f3f77718f4ee273c5bc578ba
This commit is contained in:
parent
5a0514e3e6
commit
22401b850b
|
|
@ -1,6 +1,6 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/tensorexpr/test_base.h>
|
||||
#include <thread>
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,12 +19,9 @@ endif()
|
|||
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
set(JIT_TEST_SRCS
|
||||
${JIT_TEST_ROOT}/gtest.cpp
|
||||
${JIT_TEST_ROOT}/test_alias_analysis.cpp
|
||||
${JIT_TEST_ROOT}/test_argument_spec.cpp
|
||||
${JIT_TEST_ROOT}/test_autodiff.cpp
|
||||
${JIT_TEST_ROOT}/test_base.cpp
|
||||
${JIT_TEST_ROOT}/test_base.h
|
||||
${JIT_TEST_ROOT}/test_class_import.cpp
|
||||
${JIT_TEST_ROOT}/test_class_parser.cpp
|
||||
${JIT_TEST_ROOT}/test_class_type.cpp
|
||||
|
|
|
|||
|
|
@ -1,69 +1,44 @@
|
|||
# JIT C++ Tests
|
||||
|
||||
## How to add a new test
|
||||
## Adding a new test
|
||||
First, create a new test file. Test files should have be placed in this
|
||||
directory, with a name that starts with `test_`, like `test_foo.cpp`.
|
||||
|
||||
Here is an example test file you can copy-paste.
|
||||
In general a single test suite
|
||||
|
||||
Add your test file to the `JIT_TEST_SRCS` list in `test/cpp/jit/CMakeLists.txt`.
|
||||
|
||||
A test file may look like:
|
||||
```cpp
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Tests go in torch::jit
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using namespace ::torch::jit
|
||||
|
||||
// 1. Test cases are void() functions.
|
||||
// 2. They start with the prefix `test`
|
||||
void testCaseOne() {
|
||||
// ...
|
||||
TEST(FooTest, BarBaz) {
|
||||
// ...
|
||||
}
|
||||
|
||||
void testCaseTwo() {
|
||||
// ...
|
||||
}
|
||||
// Append '_CUDA' to the test case name will automatically filter it out if CUDA
|
||||
// is not compiled.
|
||||
TEST(FooTest, NeedsAGpu_CUDA) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// Similarly, if only one GPU is detected, tests with `_MultiCUDA` at the end
|
||||
// will not be run.
|
||||
TEST(FooTest, NeedsMultipleGpus_MultiCUDA) {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
Then, register your test in `tests.h`:
|
||||
```cpp
|
||||
// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests
|
||||
#define TH_FORALL_TESTS(_) \
|
||||
_(ADFormulas) \
|
||||
_(Attributes) \
|
||||
...
|
||||
_(CaseOne) // note that the `test` prefix is omitted.
|
||||
_(CaseTwo)
|
||||
```
|
||||
|
||||
We glob all the test files together in `CMakeLists.txt` so that you don't
|
||||
have to edit it every time you add a test. Unfortunately, this means that in
|
||||
order to get the build to pick up your new test file, you need to re-run
|
||||
cmake:
|
||||
```
|
||||
python setup.py build --cmake
|
||||
```
|
||||
|
||||
## Why do we have two different test runners?
|
||||
We have two different ways of running our cpp tests:
|
||||
1. With `gtest`, from a standalone binary.
|
||||
2. With Python, from `TestJit.test_cpp` and `TestJit.test_cpp_cuda` (in
|
||||
`test/test_jit.py`)
|
||||
|
||||
We want both because we need to test things from a pure-C++ environment and
|
||||
with all our various Python patch-points enabled.
|
||||
|
||||
## How do I run the tests?
|
||||
## Building and running the tests
|
||||
The following commands assume you are in PyTorch root.
|
||||
|
||||
1. With `gtest`:
|
||||
```bash
|
||||
# (re)build the test binary
|
||||
ninja build/bin/test_jit
|
||||
# run
|
||||
build/bin/test_jit --gtest_filter='glob_style_filter*'
|
||||
```
|
||||
2. With Python:
|
||||
```
|
||||
python test/test_jit.py TestJit.test_cpp TestJit.test_cpp_cuda
|
||||
```
|
||||
```bash
|
||||
# ... Build PyTorch from source, e.g.
|
||||
python setup.py develop
|
||||
# (re)build just the binary
|
||||
ninja -C build bin/test_jit
|
||||
# run tests
|
||||
build/bin/test_jit --gtest_filter='glob_style_filter*'
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
#include <test/cpp/jit/tests.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#define JIT_GTEST(name) \
|
||||
TEST(JitTest, name) { \
|
||||
test##name(); \
|
||||
}
|
||||
TH_FORALL_TESTS(JIT_GTEST)
|
||||
#undef JIT_TEST
|
||||
|
||||
#define JIT_GTEST_CUDA(name) \
|
||||
TEST(JitTest, name##_CUDA) { \
|
||||
test##name(); \
|
||||
}
|
||||
TH_FORALL_TESTS_CUDA(JIT_GTEST_CUDA)
|
||||
#undef JIT_TEST_CUDA
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include "torch/csrc/jit/runtime/custom_operator.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
|
||||
return c10::AliasAnalysisKind::FROM_SCHEMA;
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegisterOperators reg({
|
||||
// This operator is intended to be used in JIT analysis and transformation
|
||||
// pass unit tests in which Values with type Tensor are often required. It
|
||||
// should not be used in situations in which the graph is actually executed
|
||||
// because it always produces empty Tensors.
|
||||
Operator(
|
||||
"prim::MakeTestTensor() -> Tensor",
|
||||
[](Stack* stack) { push(stack, at::Tensor()); },
|
||||
aliasAnalysisFromSchema()),
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
// This file defines assertion macros that work in both gtest and non-gtest
|
||||
// builds, and has some common includes.
|
||||
#include "torch/csrc/jit/ir/ir.h"
|
||||
#include "torch/csrc/jit/runtime/operator.h"
|
||||
|
||||
#if defined(USE_GTEST)
|
||||
#include <gtest/gtest.h>
|
||||
#include <test/cpp/common/support.h>
|
||||
#else
|
||||
#include "c10/util/Exception.h"
|
||||
// Temporary: we are going to remove these polyfills entirely.
|
||||
// But for now avoid redefining them if they are already defined in gtest.
|
||||
// (ASSERT_EQ is a proxy for whether gtest is already present)
|
||||
#ifndef ASSERT_EQ
|
||||
#define ASSERT_EQ(x, y) TORCH_INTERNAL_ASSERT((x) == (y))
|
||||
#define ASSERT_NE(x, y) TORCH_INTERNAL_ASSERT((x) != (y))
|
||||
#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
|
||||
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
|
||||
#define ASSERT_THROWS_WITH(statement, substring) \
|
||||
try { \
|
||||
(void)statement; \
|
||||
ASSERT_TRUE(false); \
|
||||
} catch (const std::exception& e) { \
|
||||
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
|
||||
}
|
||||
#define ASSERT_ANY_THROW(statement) \
|
||||
{ \
|
||||
bool threw = false; \
|
||||
try { \
|
||||
(void)statement; \
|
||||
} catch (const std::exception& e) { \
|
||||
threw = true; \
|
||||
} \
|
||||
ASSERT_TRUE(threw); \
|
||||
}
|
||||
#endif // ndef(ASSERT_EQ)
|
||||
|
||||
#endif // defined(USE_GTEST)
|
||||
|
||||
static inline bool isSandcastle() {
|
||||
return (
|
||||
(std::getenv("SANDCASTLE")) ||
|
||||
(std::getenv("TW_JOB_USER") &&
|
||||
std::string(std::getenv("TW_JOB_USER")) == "sandcastle"));
|
||||
}
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <torch/csrc/jit/frontend/parser.h>
|
||||
#include <torch/csrc/jit/frontend/resolver.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testClassTypeAddRemoveAttr() {
|
||||
TEST(ClassTypeTest, AddRemoveAttr) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
cls->addAttribute("attr1", TensorType::get(), true);
|
||||
|
|
@ -32,12 +33,12 @@ void testClassTypeAddRemoveAttr() {
|
|||
cls->addAttribute("attr1", IntType::get());
|
||||
}
|
||||
|
||||
void testClassTypeAddRemoveConstant() {
|
||||
TEST(ClassTypeTest, AddRemoveConstant) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu);
|
||||
cls->addConstant("const1", IValue(1));
|
||||
cls->addConstant("const2", IValue(2));
|
||||
cls->addConstant("const3", IValue(2));
|
||||
cls->addConstant("const3", IValue(3));
|
||||
ASSERT_EQ(cls->numConstants(), 3);
|
||||
ASSERT_TRUE(cls->hasConstant("const1"));
|
||||
ASSERT_TRUE(cls->hasConstant("const2"));
|
||||
|
|
@ -46,7 +47,7 @@ void testClassTypeAddRemoveConstant() {
|
|||
|
||||
ASSERT_EQ(cls->getConstant("const1").toInt(), 1);
|
||||
ASSERT_EQ(cls->getConstant("const2").toInt(), 2);
|
||||
ASSERT_EQ(cls->getConstant("const2").toInt(), 3);
|
||||
ASSERT_EQ(cls->getConstant("const3").toInt(), 3);
|
||||
|
||||
cls->unsafeRemoveConstant("const2");
|
||||
ASSERT_TRUE(cls->hasConstant("const1"));
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#if defined(USE_CUDA)
|
||||
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/arith.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/codegen.h>
|
||||
|
|
@ -93,7 +92,7 @@ void checkIntValue(
|
|||
// (These tests exercise IrGraphGenerator through a non-trivial IR,
|
||||
// to make sure that it runs w/o crashing. The actual output is not
|
||||
// validated)
|
||||
void testGPU_IrGraphGenerator() {
|
||||
TEST(NVFuserTest, IrGraphGenerator_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -145,7 +144,7 @@ void testGPU_IrGraphGenerator() {
|
|||
.empty());
|
||||
}
|
||||
|
||||
void testGPU_FusionDispatch() {
|
||||
TEST(NVFuserTest, FusionDispatch_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -160,7 +159,7 @@ void testGPU_FusionDispatch() {
|
|||
}
|
||||
|
||||
// Evaluate basic scalar operations with constant values
|
||||
void testGPU_FusionExprEvalConstants() {
|
||||
TEST(NVFuserTest, FusionExprEvalConstants_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -177,7 +176,7 @@ void testGPU_FusionExprEvalConstants() {
|
|||
}
|
||||
|
||||
// Evaluate basic scalar operations with bound values
|
||||
void testGPU_FusionExprEvalBindings() {
|
||||
TEST(NVFuserTest, FusionExprEvalBindings_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -222,7 +221,7 @@ void testGPU_FusionExprEvalBindings() {
|
|||
}
|
||||
|
||||
// Evaluate expressions in a simple IR
|
||||
void testGPU_FusionExprEvalBasic() {
|
||||
TEST(NVFuserTest, FusionExprEvalBasic_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -278,7 +277,7 @@ void testGPU_FusionExprEvalBasic() {
|
|||
}
|
||||
|
||||
// Evaluate expressions in a more complex IR
|
||||
void testGPU_FusionExprEvalComplex() {
|
||||
TEST(NVFuserTest, FusionExprEvalComplex_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -330,7 +329,7 @@ void testGPU_FusionExprEvalComplex() {
|
|||
}
|
||||
|
||||
// Evaluate expressions post lowering
|
||||
void testGPU_FusionExprEvalPostLower() {
|
||||
TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -387,7 +386,7 @@ void testGPU_FusionExprEvalPostLower() {
|
|||
checkIntValue(evaluator, tid_x, 128);
|
||||
}
|
||||
|
||||
void testGPU_FusionClear() {
|
||||
TEST(NVFuserTest, FusionClear_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -467,7 +466,7 @@ void testGPU_FusionClear() {
|
|||
TORCH_CHECK(output_ref.equal(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionCopy() {
|
||||
TEST(NVFuserTest, FusionCopy_CUDA) {
|
||||
Fusion original_fusion;
|
||||
|
||||
// Create the test IR
|
||||
|
|
@ -541,7 +540,7 @@ void testGPU_FusionCopy() {
|
|||
ASSERT_EQ(original_kernel, clone_kernel);
|
||||
}
|
||||
|
||||
void testGPU_FusionMove() {
|
||||
TEST(NVFuserTest, FusionMove_CUDA) {
|
||||
Fusion fusion;
|
||||
|
||||
// Create the test IR
|
||||
|
|
@ -611,7 +610,7 @@ void testGPU_FusionMove() {
|
|||
ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str());
|
||||
}
|
||||
|
||||
void testGPU_FusionSimpleArith() {
|
||||
TEST(NVFuserTest, FusionSimpleArith_CUDA) {
|
||||
std::stringstream ss1, ss2;
|
||||
|
||||
Fusion fusion;
|
||||
|
|
@ -640,7 +639,7 @@ void testGPU_FusionSimpleArith() {
|
|||
"Error where explicit add nodes don't match implicit add nodes.");
|
||||
}
|
||||
|
||||
void testGPU_FusionSimpleTypePromote() {
|
||||
TEST(NVFuserTest, FusionSimpleTypePromote_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -663,7 +662,7 @@ class ZeroMutator : public OptOutMutator {
|
|||
}
|
||||
};
|
||||
|
||||
void testGPU_FusionMutator() {
|
||||
TEST(NVFuserTest, FusionMutator_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -681,7 +680,7 @@ void testGPU_FusionMutator() {
|
|||
TORCH_CHECK(flhs->value().value() == 0.f);
|
||||
}
|
||||
|
||||
void testGPU_FusionRegister() {
|
||||
TEST(NVFuserTest, FusionRegister_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
Float* v1 = new Float{1.f};
|
||||
|
|
@ -712,7 +711,7 @@ struct DummyExpr : public Expr {
|
|||
DummyExpr& operator=(DummyExpr&& other) = delete;
|
||||
};
|
||||
|
||||
void testGPU_FusionTopoSort() {
|
||||
TEST(NVFuserTest, FusionTopoSort_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -779,7 +778,7 @@ void testGPU_FusionTopoSort() {
|
|||
TORCH_CHECK(fusion.origin(v6)->name() == 3);
|
||||
}
|
||||
|
||||
void testGPU_FusionTensor() {
|
||||
TEST(NVFuserTest, FusionTensor_CUDA) {
|
||||
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
|
||||
|
||||
Fusion fusion;
|
||||
|
|
@ -843,7 +842,7 @@ void testGPU_FusionTensor() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionFilterVals() {
|
||||
TEST(NVFuserTest, FusionFilterVals_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -881,7 +880,7 @@ void testGPU_FusionFilterVals() {
|
|||
"Not expecting any results");
|
||||
}
|
||||
|
||||
void testGPU_FusionTVSplit() {
|
||||
TEST(NVFuserTest, FusionTVSplit_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -907,7 +906,7 @@ void testGPU_FusionTVSplit() {
|
|||
static_cast<Int*>(inner->extent())->value().value() == 2);
|
||||
}
|
||||
|
||||
void testGPU_FusionTVMerge() {
|
||||
TEST(NVFuserTest, FusionTVMerge_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -925,7 +924,7 @@ void testGPU_FusionTVMerge() {
|
|||
tv->getRootDomain()[2]->extent());
|
||||
}
|
||||
|
||||
void testGPU_FusionTVReorder() {
|
||||
TEST(NVFuserTest, FusionTVReorder_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -972,7 +971,7 @@ void testGPU_FusionTVReorder() {
|
|||
TORCH_CHECK(ref[1]->sameAs(tv->axis(1)));
|
||||
}
|
||||
|
||||
void testGPU_FusionEquality() {
|
||||
TEST(NVFuserTest, FusionEquality_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -1013,7 +1012,7 @@ void testGPU_FusionEquality() {
|
|||
TORCH_CHECK(!neg1->sameAs(neg2));
|
||||
}
|
||||
|
||||
void testGPU_FusionDependency() {
|
||||
TEST(NVFuserTest, FusionDependency_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -1083,7 +1082,7 @@ void testGPU_FusionDependency() {
|
|||
TORCH_CHECK(dep_chain.empty());
|
||||
}
|
||||
|
||||
void testGPU_FusionParser() {
|
||||
TEST(NVFuserTest, FusionParser_CUDA) {
|
||||
auto g = std::make_shared<Graph>();
|
||||
const auto graph0_string = R"IR(
|
||||
graph(%0 : Float(2:1),
|
||||
|
|
@ -1163,7 +1162,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Te
|
|||
TORCH_CHECK(output_ref.equal(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionForLoop() {
|
||||
TEST(NVFuserTest, FusionForLoop_CUDA) {
|
||||
// TODO(kir): re-enable this test
|
||||
// due to the current "GpuLower guard" approach, we can only create
|
||||
// kernel IR during GpuLower::lower()
|
||||
|
|
@ -1204,7 +1203,7 @@ void testGPU_FusionForLoop() {
|
|||
#endif
|
||||
}
|
||||
|
||||
void testGPU_FusionCodeGen() {
|
||||
TEST(NVFuserTest, FusionCodeGen_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -1241,7 +1240,7 @@ void testGPU_FusionCodeGen() {
|
|||
TORCH_CHECK(output_ref.equal(output));
|
||||
}
|
||||
|
||||
void testGPU_FusionCodeGen2() {
|
||||
TEST(NVFuserTest, FusionCodeGen2_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -1283,7 +1282,7 @@ void testGPU_FusionCodeGen2() {
|
|||
TORCH_CHECK(output_ref.equal(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionSimplePWise() {
|
||||
TEST(NVFuserTest, FusionSimplePWise_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
// dimensionality of the problem
|
||||
|
|
@ -1340,7 +1339,7 @@ void testGPU_FusionSimplePWise() {
|
|||
TORCH_CHECK(output_ref.equal(output));
|
||||
}
|
||||
|
||||
void testGPU_FusionExecKernel() {
|
||||
TEST(NVFuserTest, FusionExecKernel_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -1394,7 +1393,7 @@ int ceilDiv_(int a, int b) {
|
|||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
void testGPU_FusionAdvancedComputeAt() {
|
||||
TEST(NVFuserTest, FusionAdvancedComputeAt_CUDA) {
|
||||
// Case 1
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
|
|
@ -1693,7 +1692,7 @@ void testGPU_FusionAdvancedComputeAt() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionComputeAtMultiConsumers() {
|
||||
TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
// tv3 = tv2 * -2
|
||||
|
|
@ -1754,7 +1753,7 @@ void testGPU_FusionComputeAtMultiConsumers() {
|
|||
}
|
||||
|
||||
// Similar to ComputeAtMultiConsumers, but with a common consumer.
|
||||
void testGPU_FusionComputeAtCommonConsumer1() {
|
||||
TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
// tv3 = tv2 * -2
|
||||
|
|
@ -1825,7 +1824,7 @@ void testGPU_FusionComputeAtCommonConsumer1() {
|
|||
TORCH_CHECK(at::allclose(kernel_tv5, t5));
|
||||
}
|
||||
|
||||
void testGPU_FusionComputeAtCommonConsumer2() {
|
||||
TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
// tv3 = tv2 * -1
|
||||
|
|
@ -1912,7 +1911,7 @@ void testGPU_FusionComputeAtCommonConsumer2() {
|
|||
|
||||
// Similar to the above common consumer test but adds an additional
|
||||
// tensor that has no common consumer with the other tensors.
|
||||
void testGPU_FusionComputeAtCommonConsumer3() {
|
||||
TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
// tv3 = tv2 * -1
|
||||
|
|
@ -2010,7 +2009,7 @@ void testGPU_FusionComputeAtCommonConsumer3() {
|
|||
|
||||
// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor
|
||||
// that does not have data dependency with the consumer.
|
||||
void testGPU_FusionComputeAtNoCommonConsumer() {
|
||||
TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
|
||||
// tv1 = tv0 * 0.5
|
||||
// tv2 = tv1 * -1
|
||||
// tv3 = tv1 * -2
|
||||
|
|
@ -2102,7 +2101,7 @@ void checkConcretized(
|
|||
|
||||
} // namespace
|
||||
|
||||
void testGPU_FusionBCastConcretizeBasic() {
|
||||
TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2132,7 +2131,7 @@ void testGPU_FusionBCastConcretizeBasic() {
|
|||
checkConcretized(tv2_0, 0, tv1, 1, false);
|
||||
}
|
||||
|
||||
void testGPU_FusionBCastConcretizeRfactor() {
|
||||
TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2181,7 +2180,7 @@ void checkIdProvedEquivalent(
|
|||
|
||||
} // namespace
|
||||
|
||||
void testGPU_FusionProveIdEqBasic() {
|
||||
TEST(NVFuserTest, FusionProveIdEqBasic_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2206,7 +2205,7 @@ void testGPU_FusionProveIdEqBasic() {
|
|||
checkIdProvedEquivalent(tv0, 0, tv1, 1, false);
|
||||
}
|
||||
|
||||
void testGPU_FusionProveIdEqRfactor() {
|
||||
TEST(NVFuserTest, FusionProveIdEqRfactor_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2236,7 +2235,7 @@ void testGPU_FusionProveIdEqRfactor() {
|
|||
checkIdProvedEquivalent(tv3, 0, tv0, 0, true);
|
||||
}
|
||||
|
||||
void testGPU_FusionScalarInputs() {
|
||||
TEST(NVFuserTest, FusionScalarInputs_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2323,7 +2322,7 @@ void testGPU_FusionScalarInputs() {
|
|||
TORCH_CHECK(at::allclose(kernel_tv4, t4));
|
||||
}
|
||||
|
||||
void testGPU_FusionLoopUnroll() {
|
||||
TEST(NVFuserTest, FusionLoopUnroll_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2564,7 +2563,7 @@ void test_op(
|
|||
std::make_index_sequence<size>{});
|
||||
}
|
||||
|
||||
void testGPU_FusionUnaryOps() {
|
||||
TEST(NVFuserTest, FusionUnaryOps_CUDA) {
|
||||
using OpTuple =
|
||||
std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
|
||||
|
||||
|
|
@ -2638,7 +2637,7 @@ void testGPU_FusionUnaryOps() {
|
|||
std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
|
||||
}
|
||||
|
||||
void testGPU_FusionBinaryOps() {
|
||||
TEST(NVFuserTest, FusionBinaryOps_CUDA) {
|
||||
using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
|
||||
using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;
|
||||
|
||||
|
|
@ -2738,7 +2737,7 @@ void testGPU_FusionBinaryOps() {
|
|||
std::make_pair(ValType::Scalar, DataType::Float)));
|
||||
}
|
||||
|
||||
void testGPU_FusionTernaryOps() {
|
||||
TEST(NVFuserTest, FusionTernaryOps_CUDA) {
|
||||
test_op(
|
||||
/*blocks*/ 640,
|
||||
/*threads*/ 64,
|
||||
|
|
@ -2787,7 +2786,7 @@ void testGPU_FusionTernaryOps() {
|
|||
std::make_pair(ValType::TensorView, DataType::Float)));
|
||||
}
|
||||
|
||||
void testGPU_FusionCompoundOps() {
|
||||
TEST(NVFuserTest, FusionCompoundOps_CUDA) {
|
||||
test_op(
|
||||
/*blocks*/ 640,
|
||||
/*threads*/ 64,
|
||||
|
|
@ -2826,7 +2825,7 @@ void testGPU_FusionCompoundOps() {
|
|||
std::make_pair(ValType::Scalar, DataType::Float)));
|
||||
}
|
||||
|
||||
void testGPU_FusionCastOps() {
|
||||
TEST(NVFuserTest, FusionCastOps_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2868,7 +2867,7 @@ void testGPU_FusionCastOps() {
|
|||
|
||||
// We want split/merge/reorder all tested both on and off rfactor domains, also
|
||||
// want compute at into the rfactor domain, and into its consumer
|
||||
void testGPU_FusionRFactorReplay() {
|
||||
TEST(NVFuserTest, FusionRFactorReplay_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -2961,7 +2960,7 @@ void testGPU_FusionRFactorReplay() {
|
|||
|
||||
// Start off simple, block on the outer dim
|
||||
// block stride + thread all reduce + unrolling on inner dim
|
||||
void testGPU_FusionReduction() {
|
||||
TEST(NVFuserTest, FusionReduction_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -3019,7 +3018,7 @@ void testGPU_FusionReduction() {
|
|||
TORCH_CHECK(aten_output.allclose(cg_output));
|
||||
}
|
||||
|
||||
void testGPU_FusionReduction2() {
|
||||
TEST(NVFuserTest, FusionReduction2_CUDA) {
|
||||
{
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -3146,7 +3145,7 @@ void testGPU_FusionReduction2() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionReduction3() {
|
||||
TEST(NVFuserTest, FusionReduction3_CUDA) {
|
||||
{
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -3217,7 +3216,7 @@ void testGPU_FusionReduction3() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionReduction4() {
|
||||
TEST(NVFuserTest, FusionReduction4_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -3269,7 +3268,7 @@ void testGPU_FusionReduction4() {
|
|||
aten_output.sub(cg_output).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionReduction5() {
|
||||
TEST(NVFuserTest, FusionReduction5_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -3330,7 +3329,7 @@ void testGPU_FusionReduction5() {
|
|||
TORCH_CHECK(aten_output.allclose(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionTFT() {
|
||||
TEST(NVFuserTest, FusionReductionTFT_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -3387,7 +3386,7 @@ void testGPU_FusionReductionTFT() {
|
|||
TORCH_CHECK(aten_output.allclose(cg_output));
|
||||
}
|
||||
|
||||
void testGPU_FusionBranches() {
|
||||
TEST(NVFuserTest, FusionBranches_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -3444,7 +3443,7 @@ void testGPU_FusionBranches() {
|
|||
TORCH_CHECK(t6.allclose(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionSimpleBCast() {
|
||||
TEST(NVFuserTest, FusionSimpleBCast_CUDA) {
|
||||
{
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -3708,7 +3707,7 @@ void testGPU_FusionSimpleBCast() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionComplexBCast() {
|
||||
TEST(NVFuserTest, FusionComplexBCast_CUDA) {
|
||||
{
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -3811,7 +3810,7 @@ void testGPU_FusionComplexBCast() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionAdvancedIndexing() {
|
||||
TEST(NVFuserTest, FusionAdvancedIndexing_CUDA) {
|
||||
// Merging left to right is still broken in some instances. Indexing can't
|
||||
// complete because we assume we can simply traverse consumer->producer in the
|
||||
// index/extent map, but this case breaks this assumption.
|
||||
|
|
@ -3980,7 +3979,7 @@ void testGPU_FusionAdvancedIndexing() {
|
|||
}
|
||||
|
||||
// Test a simple Gemm but also play around with fusion executor features
|
||||
void testGPU_FusionSimpleGemm() {
|
||||
TEST(NVFuserTest, FusionSimpleGemm_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4067,7 +4066,7 @@ void testGPU_FusionSimpleGemm() {
|
|||
}
|
||||
|
||||
// Softmax with a 1D tensor. Parallelized only with a single thread block.
|
||||
void testGPU_FusionSoftmax1D() {
|
||||
TEST(NVFuserTest, FusionSoftmax1D_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4124,7 +4123,7 @@ void testGPU_FusionSoftmax1D() {
|
|||
}
|
||||
|
||||
// Softmax with a 1D tensor with input normalization.
|
||||
void testGPU_FusionSoftmax1DNormalized() {
|
||||
TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4195,7 +4194,7 @@ void testGPU_FusionSoftmax1DNormalized() {
|
|||
|
||||
// Softmax with a 3D tensor, where the inner-most 3rd dimension is
|
||||
// normalized. Pallelized with multiple thread blocks.
|
||||
void testGPU_FusionSoftmax3D() {
|
||||
TEST(NVFuserTest, FusionSoftmax3D_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4255,7 +4254,7 @@ void testGPU_FusionSoftmax3D() {
|
|||
}
|
||||
|
||||
// Softmax with a 3D tensor with input normalization.
|
||||
void testGPU_FusionSoftmax3DNormalized() {
|
||||
TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4329,7 +4328,7 @@ void testGPU_FusionSoftmax3DNormalized() {
|
|||
t2.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionSoftmaxComputeAt() {
|
||||
TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4355,7 +4354,7 @@ void testGPU_FusionSoftmaxComputeAt() {
|
|||
}
|
||||
|
||||
// Similar to FusionReduction but uses grid reduction
|
||||
void testGPU_FusionGridReduction1() {
|
||||
TEST(NVFuserTest, FusionGridReduction1_CUDA) {
|
||||
const int gdimx = 32;
|
||||
const int bdimx = 128;
|
||||
|
||||
|
|
@ -4413,7 +4412,7 @@ void testGPU_FusionGridReduction1() {
|
|||
}
|
||||
|
||||
// Same test as the above but uses BIDy and TIDx for reduction
|
||||
void testGPU_FusionGridReduction2() {
|
||||
TEST(NVFuserTest, FusionGridReduction2_CUDA) {
|
||||
const int gdimy = 32;
|
||||
const int bdimx = 128;
|
||||
|
||||
|
|
@ -4468,7 +4467,7 @@ void testGPU_FusionGridReduction2() {
|
|||
}
|
||||
|
||||
// Same test but uses BIDy and BIDz for reduction. No TID used.
|
||||
void testGPU_FusionGridReduction3dim1() {
|
||||
TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) {
|
||||
const int gdimz = 32;
|
||||
const int gdimy = 128;
|
||||
|
||||
|
|
@ -4524,7 +4523,7 @@ void testGPU_FusionGridReduction3dim1() {
|
|||
}
|
||||
|
||||
// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
|
||||
void testGPU_FusionGridReduction3dim0() {
|
||||
TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) {
|
||||
const int rdim = 0;
|
||||
const int gdimy = 128;
|
||||
const int gdimz = 32;
|
||||
|
|
@ -4577,7 +4576,7 @@ void testGPU_FusionGridReduction3dim0() {
|
|||
}
|
||||
|
||||
// This is similar to the FusionReduction, but swaps BIDx and TIDx
|
||||
void testGPU_FusionGridReduction4() {
|
||||
TEST(NVFuserTest, FusionGridReduction4_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4640,7 +4639,7 @@ void testGPU_FusionGridReduction4() {
|
|||
|
||||
// Grid reduction with 2D thread blocks but only TIDx and BIDx are
|
||||
// mapped to a reduction dim
|
||||
void testGPU_FusionGridReduction5() {
|
||||
TEST(NVFuserTest, FusionGridReduction5_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4692,7 +4691,7 @@ void testGPU_FusionGridReduction5() {
|
|||
}
|
||||
|
||||
// Similar to FusionGridReduction1 but with 3D tensors
|
||||
void testGPU_FusionGridReduction6() {
|
||||
TEST(NVFuserTest, FusionGridReduction6_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4753,7 +4752,7 @@ void testGPU_FusionGridReduction6() {
|
|||
TORCH_CHECK(aten_output.allclose(cg_output));
|
||||
}
|
||||
|
||||
void testGPU_FusionNonRedAxisBind() {
|
||||
TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
|
||||
int bid_x = 3;
|
||||
int tid_x = 2;
|
||||
int red_dim = 0;
|
||||
|
|
@ -4788,7 +4787,7 @@ void testGPU_FusionNonRedAxisBind() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionSplitBCast() {
|
||||
TEST(NVFuserTest, FusionSplitBCast_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4836,7 +4835,7 @@ void testGPU_FusionSplitBCast() {
|
|||
fe.runFusion({t0, t1}, {cg_output});
|
||||
}
|
||||
|
||||
void testGPU_FusionBCastInnerDim() {
|
||||
TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4850,7 +4849,7 @@ void testGPU_FusionBCastInnerDim() {
|
|||
TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
|
||||
}
|
||||
|
||||
void testGPU_FusionBCastReduce() {
|
||||
TEST(NVFuserTest, FusionBCastReduce_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4866,7 +4865,7 @@ void testGPU_FusionBCastReduce() {
|
|||
|
||||
// Multiple consumer reduction with computeAt
|
||||
// https://github.com/csarofeen/pytorch/issues/110
|
||||
void testGPU_FusionReductionMultiConsumer() {
|
||||
TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
TensorView* tv0 = makeDummyTensor(2);
|
||||
|
|
@ -4883,7 +4882,7 @@ void testGPU_FusionReductionMultiConsumer() {
|
|||
tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2);
|
||||
}
|
||||
|
||||
void testGPU_FusionComputeAtExprOrder() {
|
||||
TEST(NVFuserTest, FusionComputeAtExprOrder_CUDA) {
|
||||
{
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Fusion fusion;
|
||||
|
|
@ -4953,7 +4952,7 @@ void testGPU_FusionComputeAtExprOrder() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionZeroDimComputeAt() {
|
||||
TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -4980,7 +4979,7 @@ void testGPU_FusionZeroDimComputeAt() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionZeroDimBroadcast() {
|
||||
TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5016,7 +5015,7 @@ void testGPU_FusionZeroDimBroadcast() {
|
|||
aten_output.sub(output).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionZeroDimReduction() {
|
||||
TEST(NVFuserTest, FusionZeroDimReduction_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5053,7 +5052,7 @@ void testGPU_FusionZeroDimReduction() {
|
|||
aten_output.sub(output).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionBCastAfterReduce() {
|
||||
TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
const int tidx = 128;
|
||||
|
|
@ -5104,7 +5103,7 @@ void testGPU_FusionBCastAfterReduce() {
|
|||
TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5));
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionScheduler() {
|
||||
TEST(NVFuserTest, FusionReductionScheduler_CUDA) {
|
||||
constexpr int bid_x = 80;
|
||||
constexpr int tid_x = 4096;
|
||||
constexpr int red_dim = 1;
|
||||
|
|
@ -5142,7 +5141,7 @@ void testGPU_FusionReductionScheduler() {
|
|||
}
|
||||
|
||||
// Simple reduction parallelized on a symbolic size.
|
||||
void testGPU_FusionSymbolicReduction() {
|
||||
TEST(NVFuserTest, FusionSymbolicReduction_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5192,7 +5191,7 @@ void testGPU_FusionSymbolicReduction() {
|
|||
TORCH_CHECK(aten_output.allclose(outputs[0]));
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionSchedulerMultiDimNonFastest() {
|
||||
TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
|
||||
const std::vector<int> red_dims = {0, 2};
|
||||
// Copy is because CodeGen requires int and Pytorch requires int64_t
|
||||
// for a vector of reduction dimensions
|
||||
|
|
@ -5232,7 +5231,7 @@ void testGPU_FusionReductionSchedulerMultiDimNonFastest() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionSchedulerMultiDimFastest() {
|
||||
TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
|
||||
const std::vector<int> red_dims = {1, 3};
|
||||
// Copy is because CodeGen requires int and Pytorch requires int64_t
|
||||
// for a vector of reduction dimensions
|
||||
|
|
@ -5270,7 +5269,7 @@ void testGPU_FusionReductionSchedulerMultiDimFastest() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionSchedulerDimShmoo() {
|
||||
TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
|
||||
std::vector<bool> fp16_usage = {true, false};
|
||||
std::vector<int> red_axis = {1, 0};
|
||||
std::vector<int> output_dims = {320, 640};
|
||||
|
|
@ -5346,7 +5345,7 @@ void testGPU_FusionReductionSchedulerDimShmoo() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheBefore() {
|
||||
TEST(NVFuserTest, FusionCacheBefore_CUDA) {
|
||||
// TVM Cache Write
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -5387,7 +5386,7 @@ void testGPU_FusionCacheBefore() {
|
|||
aten_output.sub(outputs[0]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheAfter() {
|
||||
TEST(NVFuserTest, FusionCacheAfter_CUDA) {
|
||||
// TVM Cache Read
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -5428,7 +5427,7 @@ void testGPU_FusionCacheAfter() {
|
|||
aten_output.sub(outputs[0]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheIndirect() {
|
||||
TEST(NVFuserTest, FusionCacheIndirect_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5477,7 +5476,7 @@ void testGPU_FusionCacheIndirect() {
|
|||
aten_output.sub(outputs[0]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheBcast() {
|
||||
TEST(NVFuserTest, FusionCacheBcast_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5536,7 +5535,7 @@ void testGPU_FusionCacheBcast() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheComplex() {
|
||||
TEST(NVFuserTest, FusionCacheComplex_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5591,7 +5590,7 @@ void testGPU_FusionCacheComplex() {
|
|||
aten_output.sub(outputs[0]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionCacheMultiConsumer() {
|
||||
TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5636,7 +5635,7 @@ void testGPU_FusionCacheMultiConsumer() {
|
|||
aten_output.sub(outputs[1]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionSmem() {
|
||||
TEST(NVFuserTest, FusionSmem_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5691,7 +5690,7 @@ void testGPU_FusionSmem() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemReduce() {
|
||||
TEST(NVFuserTest, FusionSmemReduce_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5742,7 +5741,7 @@ void testGPU_FusionSmemReduce() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemBlockGemm() {
|
||||
TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5805,7 +5804,7 @@ void testGPU_FusionSmemBlockGemm() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemBlockGemmCache() {
|
||||
TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5891,7 +5890,7 @@ void testGPU_FusionSmemBlockGemmCache() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemDynamicReductionSymbolic() {
|
||||
TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -5940,7 +5939,7 @@ void testGPU_FusionSmemDynamicReductionSymbolic() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemDynamicReductionSymbolicArg() {
|
||||
TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6000,7 +5999,7 @@ void testGPU_FusionSmemDynamicReductionSymbolicArg() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemDynamicPwiseMulSymbolicArgWAR() {
|
||||
TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6059,7 +6058,7 @@ void testGPU_FusionSmemDynamicPwiseMulSymbolicArgWAR() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(22) == 1);
|
||||
}
|
||||
|
||||
void testGPU_FusionSmemDynamicTiledGemm() {
|
||||
TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6185,7 +6184,7 @@ void testGPU_FusionSmemDynamicTiledGemm() {
|
|||
TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(41) == 1);
|
||||
}
|
||||
|
||||
void testGPU_FusionGlobalIntermediate() {
|
||||
TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6233,7 +6232,7 @@ void testGPU_FusionGlobalIntermediate() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionGlobalIntermediateDefaultSchedule() {
|
||||
TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6273,7 +6272,7 @@ void testGPU_FusionGlobalIntermediateDefaultSchedule() {
|
|||
aten_output.sub(outputs[0]).abs().sum());
|
||||
}
|
||||
|
||||
void testGPU_FusionConstCheck() {
|
||||
TEST(NVFuserTest, FusionConstCheck_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6290,7 +6289,7 @@ void testGPU_FusionConstCheck() {
|
|||
TORCH_CHECK(one_x4->isConstScalar());
|
||||
}
|
||||
|
||||
void testGPU_FusionUnrollWithAlloc() {
|
||||
TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) {
|
||||
const std::vector<int64_t> tensor_dims_in = {128, 128};
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -6338,7 +6337,7 @@ void testGPU_FusionUnrollWithAlloc() {
|
|||
}
|
||||
|
||||
// Test isZeroInt
|
||||
void testGPU_FusionIsZeroInt() {
|
||||
TEST(NVFuserTest, FusionIsZeroInt_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6351,7 +6350,7 @@ void testGPU_FusionIsZeroInt() {
|
|||
}
|
||||
|
||||
// Test isOneInt
|
||||
void testGPU_FusionIsOneInt() {
|
||||
TEST(NVFuserTest, FusionIsOneInt_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6366,7 +6365,7 @@ void testGPU_FusionIsOneInt() {
|
|||
// This is to verify no cycle of computeAt is created. A more complex
|
||||
// variation of this pattern appears in one of the Python tests
|
||||
// (test_random_topo).
|
||||
void testGPU_FusionComputeAtNonterminatingOutput() {
|
||||
TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6430,7 +6429,7 @@ void testGPU_FusionComputeAtNonterminatingOutput() {
|
|||
return;
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder1() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder1_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6478,7 +6477,7 @@ void testGPU_FusionTraversalOrder1() {
|
|||
t4.sub(cg_output_tv4).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder2() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder2_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6531,7 +6530,7 @@ void testGPU_FusionTraversalOrder2() {
|
|||
t5.sub(cg_output_tv5).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder3() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder3_CUDA) {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
|
@ -6599,7 +6598,7 @@ void testGPU_FusionTraversalOrder3() {
|
|||
}
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder4() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder4_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6663,7 +6662,7 @@ void testGPU_FusionTraversalOrder4() {
|
|||
t7.sub(cg_output_tv7).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder5() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder5_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6713,7 +6712,7 @@ void testGPU_FusionTraversalOrder5() {
|
|||
t5.sub(cg_output_tv5).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder6() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder6_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6755,7 +6754,7 @@ void testGPU_FusionTraversalOrder6() {
|
|||
t4.sub(cg_output_tv4).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionTraversalOrder7() {
|
||||
TEST(NVFuserTest, FusionTraversalOrder7_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6804,7 +6803,7 @@ void testGPU_FusionTraversalOrder7() {
|
|||
}
|
||||
|
||||
// Test predication of grid reduction
|
||||
void testGPU_FusionThreadPredicate() {
|
||||
TEST(NVFuserTest, FusionThreadPredicate_CUDA) {
|
||||
const int gdimx = 4;
|
||||
const int bdimx = 128;
|
||||
|
||||
|
|
@ -6860,7 +6859,7 @@ void testGPU_FusionThreadPredicate() {
|
|||
TORCH_CHECK(aten_output_tv3.allclose(cg_output_tv3));
|
||||
}
|
||||
|
||||
void testGPU_FusionLSTMCell() {
|
||||
TEST(NVFuserTest, FusionLSTMCell_CUDA) {
|
||||
const int hidden_features = 512;
|
||||
const int batch_size = 64;
|
||||
|
||||
|
|
@ -6940,7 +6939,7 @@ void testGPU_FusionLSTMCell() {
|
|||
TORCH_CHECK(at_hy.allclose(outputs[1], 1e-4, 1e-7));
|
||||
}
|
||||
|
||||
void testGPU_FusionComputeAtMultiBCast() {
|
||||
TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -6958,7 +6957,7 @@ void testGPU_FusionComputeAtMultiBCast() {
|
|||
ASSERT_ANY_THROW(tv1->computeAt(tv3, -1));
|
||||
}
|
||||
|
||||
void testGPU_FusionReductionHalf() {
|
||||
TEST(NVFuserTest, FusionReductionHalf_CUDA) {
|
||||
Fusion fusion;
|
||||
FusionGuard fg(&fusion);
|
||||
|
||||
|
|
@ -7011,7 +7010,7 @@ void testGPU_FusionReductionHalf() {
|
|||
aten_output.sub(outputs[0]).abs().max());
|
||||
}
|
||||
|
||||
void testGPU_FusionInputsIdLookup() {
|
||||
TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
|
||||
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
|
||||
at::Tensor t0 = at::randn({16, 8, 8}, options);
|
||||
at::Tensor t1 = at::randn({8, 8}, options);
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
#include "torch/csrc/jit/runtime/graph_executor.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testGraphExecutor() {
|
||||
TEST(GraphExecutorTest, Basic_CUDA) {
|
||||
constexpr int batch_size = 4;
|
||||
constexpr int input_size = 256;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
|
|
@ -36,18 +36,16 @@ struct InlinerGuard {
|
|||
bool oldState_;
|
||||
};
|
||||
|
||||
void testInliner() {
|
||||
{
|
||||
// disable automatic inlining so we can test it manually
|
||||
InlinerGuard guard(/*shouldInline=*/false);
|
||||
TEST(InlinerTest, Basic) {
|
||||
// disable automatic inlining so we can test it manually
|
||||
InlinerGuard guard(/*shouldInline=*/false);
|
||||
|
||||
CompilationUnit cu(testSource);
|
||||
auto& fn = cu.get_function("foo3");
|
||||
CompilationUnit cu(testSource);
|
||||
auto& fn = cu.get_function("foo3");
|
||||
|
||||
auto g = fn.graph();
|
||||
Inline(*g);
|
||||
FileCheck().check_count("prim::Print", 3)->run(*g);
|
||||
}
|
||||
auto g = fn.graph();
|
||||
Inline(*g);
|
||||
FileCheck().check_count("prim::Print", 3)->run(*g);
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
|
|
@ -44,7 +44,7 @@ static void import_libs(
|
|||
si.loadType(QualifiedName(class_name));
|
||||
}
|
||||
|
||||
void testModuleInterfaceSerialization() {
|
||||
TEST(InterfaceTest, ModuleInterfaceSerialization) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
Module parentMod("parentMod", cu);
|
||||
Module subMod("subMod", cu);
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
|
||||
#include <stdexcept>
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testTypeCheck() {
|
||||
{
|
||||
class TypeCheckTest : public ::testing::Test {
|
||||
protected:
|
||||
TypeCheckTest() : interp(makeInterp()) {}
|
||||
|
||||
InterpreterState interp;
|
||||
|
||||
private:
|
||||
static InterpreterState makeInterp() {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
|
|
@ -20,88 +26,97 @@ graph(%a.1 : Tensor,
|
|||
vmap);
|
||||
|
||||
Code function(graph, "");
|
||||
InterpreterState interp(function);
|
||||
{
|
||||
// TypeCheck yields to true! Shape, grad and device matches.
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCPU);
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
|
||||
ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
|
||||
ASSERT_TRUE(stack[2].toBool());
|
||||
}
|
||||
{
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCPU);
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
{
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a = a.to(at::kCPU);
|
||||
a.set_requires_grad(false); // Gradient mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
{
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a = a.to(at::kCPU);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kInt); // Scalar type mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
{
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCUDA); // Device mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
return InterpreterState(function);
|
||||
}
|
||||
};
|
||||
|
||||
try { // Test empty Typecheck raises an internal assertion
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%a.1 : Tensor,
|
||||
%b.1 : Tensor):
|
||||
%type_matched : bool = prim::TypeCheck()
|
||||
return (%type_matched)
|
||||
)IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
} catch (const std::exception& e) {
|
||||
}
|
||||
try { // Test for assertion if num_inputs + 1 != num_outputs
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%a.1 : Tensor,
|
||||
%b.1 : Tensor):
|
||||
%type_matched : bool = prim::TypeCheck(%a.1)
|
||||
return (%type_matched)
|
||||
)IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
} catch (const std::exception& e) {
|
||||
}
|
||||
TEST_F(TypeCheckTest, MatchingType) {
|
||||
// TypeCheck yields to true! Shape, grad and device matches.
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCPU);
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
|
||||
ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
|
||||
ASSERT_TRUE(stack[2].toBool());
|
||||
}
|
||||
void testInterp() {
|
||||
|
||||
TEST_F(TypeCheckTest, SizeMismatch) {
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCPU);
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
|
||||
TEST_F(TypeCheckTest, GradientMismatch) {
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a = a.to(at::kCPU);
|
||||
a.set_requires_grad(false); // Gradient mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
|
||||
TEST_F(TypeCheckTest, ScalarTypeMismatch) {
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a = a.to(at::kCPU);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kInt); // Scalar type mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
|
||||
TEST_F(TypeCheckTest, DeviceMismatch_CUDA) {
|
||||
auto a = at::zeros({2, 2}, at::kFloat);
|
||||
auto b = at::ones({3, 3}, at::kFloat);
|
||||
a.set_requires_grad(true);
|
||||
a = a.to(at::kCUDA); // Device mismatch
|
||||
std::vector<IValue> stack({a, b});
|
||||
interp.run(stack);
|
||||
ASSERT_FALSE(stack[2].toBool());
|
||||
}
|
||||
|
||||
// TODO: These tests weren't doing anything.
|
||||
// TEST(TypeCheckErrorTest, EmptyCheckRaises) {
|
||||
// // Test empty Typecheck raises an internal assertion
|
||||
// auto graph = std::make_shared<Graph>();
|
||||
// std::unordered_map<std::string, Value*> vmap;
|
||||
// EXPECT_ANY_THROW(parseIR(
|
||||
// R"IR(
|
||||
// graph(%a.1 : Tensor,
|
||||
// %b.1 : Tensor):
|
||||
// %type_matched : bool = prim::TypeCheck()
|
||||
// return (%type_matched)
|
||||
// )IR",
|
||||
// &*graph,
|
||||
// vmap));
|
||||
// }
|
||||
|
||||
// TODO: These tests weren't doing anything.
|
||||
// TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) {
|
||||
// // Test for assertion if num_inputs + 1 != num_outputs
|
||||
// auto graph = std::make_shared<Graph>();
|
||||
// std::unordered_map<std::string, Value*> vmap;
|
||||
// EXPECT_ANY_THROW(parseIR(
|
||||
// R"IR(
|
||||
// graph(%a.1 : Tensor,
|
||||
// %b.1 : Tensor):
|
||||
// %type_matched : bool = prim::TypeCheck(%a.1)
|
||||
// return (%type_matched)
|
||||
// )IR",
|
||||
// &*graph,
|
||||
// vmap));
|
||||
// }
|
||||
|
||||
TEST(InterpreterTest, Basic_CUDA) {
|
||||
constexpr int batch_size = 4;
|
||||
constexpr int input_size = 256;
|
||||
constexpr int seq_len = 32;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
#include "torch/csrc/jit/ir/irparser.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testAttributes() {
|
||||
TEST(IRTest, Attributes) {
|
||||
Graph g;
|
||||
auto one = attr::alpha;
|
||||
auto two = attr::device;
|
||||
|
|
@ -33,7 +34,7 @@ void testAttributes() {
|
|||
ASSERT_EQ(attr2.f(one), 5);
|
||||
}
|
||||
|
||||
void testBlocks() {
|
||||
TEST(IRTest, Blocks) {
|
||||
auto g = std::make_shared<Graph>();
|
||||
const auto graph_string = R"IR(
|
||||
graph(%a : Tensor,
|
||||
|
|
@ -92,7 +93,7 @@ void testBlocks() {
|
|||
->run(*g2);
|
||||
}
|
||||
|
||||
void testCommonAncestor() {
|
||||
TEST(IRTest, CommonAncestor) {
|
||||
std::string input_str = R"(
|
||||
graph(%x : Tensor,
|
||||
%a.1 : bool,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
|
@ -38,52 +39,52 @@ static void checkRoundtrip(const std::string& s) {
|
|||
AT_ASSERT(original == parsed);
|
||||
}
|
||||
|
||||
void testIRParser() {
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(IRParserTest, Basic) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0 : Tensor, %1 : Tensor):
|
||||
%2 : Tensor = foo::add(%0, %1)
|
||||
%res, %3 = foo::mul(%0, %2)
|
||||
%x, %y = foo::combine(%res, %2, %3)
|
||||
return (%x, %y, %res))IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
&*graph,
|
||||
vmap);
|
||||
|
||||
AT_ASSERT(graph->inputs().size() == 2);
|
||||
AT_ASSERT(graph->outputs().size() == 3);
|
||||
Value* x = graph->outputs()[0];
|
||||
Value* y = graph->outputs()[1];
|
||||
Value* res = graph->outputs()[2];
|
||||
Value* t0 = graph->inputs()[0];
|
||||
Value* t1 = graph->inputs()[1];
|
||||
AT_ASSERT(vmap["x"] == x);
|
||||
AT_ASSERT(vmap["y"] == y);
|
||||
AT_ASSERT(vmap["res"] == res);
|
||||
AT_ASSERT(vmap["0"] == t0);
|
||||
AT_ASSERT(vmap["1"] == t1);
|
||||
AT_ASSERT(x->node() == y->node());
|
||||
Node* comb = x->node();
|
||||
Value* t2 = comb->inputs()[1];
|
||||
Value* t3 = comb->inputs()[2];
|
||||
AT_ASSERT(vmap["2"] == t2);
|
||||
AT_ASSERT(vmap["3"] == t3);
|
||||
AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
|
||||
AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
|
||||
AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
|
||||
Node* mul = res->node();
|
||||
AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
|
||||
AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
|
||||
AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
|
||||
Node* add = t2->node();
|
||||
AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
|
||||
AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
|
||||
AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
|
||||
}
|
||||
{
|
||||
checkRoundtrip(R"IR(
|
||||
AT_ASSERT(graph->inputs().size() == 2);
|
||||
AT_ASSERT(graph->outputs().size() == 3);
|
||||
Value* x = graph->outputs()[0];
|
||||
Value* y = graph->outputs()[1];
|
||||
Value* res = graph->outputs()[2];
|
||||
Value* t0 = graph->inputs()[0];
|
||||
Value* t1 = graph->inputs()[1];
|
||||
AT_ASSERT(vmap["x"] == x);
|
||||
AT_ASSERT(vmap["y"] == y);
|
||||
AT_ASSERT(vmap["res"] == res);
|
||||
AT_ASSERT(vmap["0"] == t0);
|
||||
AT_ASSERT(vmap["1"] == t1);
|
||||
AT_ASSERT(x->node() == y->node());
|
||||
Node* comb = x->node();
|
||||
Value* t2 = comb->inputs()[1];
|
||||
Value* t3 = comb->inputs()[2];
|
||||
AT_ASSERT(vmap["2"] == t2);
|
||||
AT_ASSERT(vmap["3"] == t3);
|
||||
AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
|
||||
AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
|
||||
AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
|
||||
Node* mul = res->node();
|
||||
AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
|
||||
AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
|
||||
AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
|
||||
Node* add = t2->node();
|
||||
AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
|
||||
AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
|
||||
AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
|
||||
}
|
||||
|
||||
TEST(IRParserTest, NestedBlock) {
|
||||
checkRoundtrip(R"IR(
|
||||
graph():
|
||||
%0 : Tensor = a::a()
|
||||
block0():
|
||||
|
|
@ -95,9 +96,10 @@ graph():
|
|||
%3 : Tensor = d::d()
|
||||
return (%3)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
checkRoundtrip(R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, If) {
|
||||
checkRoundtrip(R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
|
|
@ -114,9 +116,10 @@ graph(%0 : Tensor,
|
|||
%11 : Tensor = aten::add(%5, %3, %10)
|
||||
return (%11)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
checkRoundtrip(R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, If2) {
|
||||
checkRoundtrip(R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
|
|
@ -133,40 +136,43 @@ graph(%0 : Tensor,
|
|||
%11 : Tensor = aten::add(%5, %3, %10)
|
||||
return (%11)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, InferredTypeIsTensor) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%a):
|
||||
return (%a))IR",
|
||||
&*graph);
|
||||
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
||||
}
|
||||
{
|
||||
// Check that parser correctly handles values reusing the same name.
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
&*graph);
|
||||
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
||||
}
|
||||
|
||||
TEST(IRParserTest, ValueReuse) {
|
||||
// Check that parser correctly handles values reusing the same name.
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%x):
|
||||
%x = a::a(%x)
|
||||
%x = b::b(%x)
|
||||
return (%x))IR",
|
||||
&*graph);
|
||||
Value* x0 = graph->inputs()[0];
|
||||
Value* x2 = graph->outputs()[0];
|
||||
Node* b = x2->node();
|
||||
Value* x1 = b->inputs()[0];
|
||||
Node* a = x1->node();
|
||||
AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
|
||||
AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
|
||||
AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
|
||||
AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
|
||||
}
|
||||
{
|
||||
// Check that parser handles attributes and types.
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
&*graph);
|
||||
Value* x0 = graph->inputs()[0];
|
||||
Value* x2 = graph->outputs()[0];
|
||||
Node* b = x2->node();
|
||||
Value* x1 = b->inputs()[0];
|
||||
Node* a = x1->node();
|
||||
AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
|
||||
AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
|
||||
AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
|
||||
AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
|
||||
}
|
||||
|
||||
TEST(IRParserTest, Attributes) {
|
||||
// Check that parser handles attributes and types.
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
|
|
@ -176,155 +182,147 @@ graph(%0 : Tensor,
|
|||
%8 : string = z::z()
|
||||
return (%7)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, OptionalTypes) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : int? = prim::Constant()
|
||||
return (%3)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, StarTensor) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : Float(*, *, *) = prim::Constant()
|
||||
return (%3)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, UnshapedTensor) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : Long() = prim::Constant()
|
||||
return (%3)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, ShapedTensor) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : Double(4, 4, 5) = prim::Constant()
|
||||
return (%3)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, NestedContrainer) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph():
|
||||
%0 : float[] = prim::Constant[value=[1., 2., 3.]]()
|
||||
%1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
|
||||
%2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
|
||||
return (%2)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
bool error_thrown = false;
|
||||
try {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
TEST(IRParserTest, MalformedShapeAnnotation) {
|
||||
EXPECT_ANY_THROW(checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%0 : Tensor,
|
||||
%1 : Tensor,
|
||||
%2 : Tensor):
|
||||
%3 : Double(4!, 4, 5) = prim::Constant()
|
||||
return (%3)
|
||||
)IR");
|
||||
} catch (const std::exception& error) {
|
||||
error_thrown = true;
|
||||
}
|
||||
AT_ASSERT(error_thrown);
|
||||
}
|
||||
)IR"));
|
||||
}
|
||||
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
const std::string& text =
|
||||
R"IR(
|
||||
TEST(IRParserTest, FileCheck) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
const std::string& text =
|
||||
R"IR(
|
||||
graph(%a):
|
||||
# CHECK: return
|
||||
return (%a))IR";
|
||||
|
||||
parseIR(text, &*graph);
|
||||
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
||||
torch::jit::testing::FileCheck().run(text, *graph);
|
||||
}
|
||||
parseIR(text, &*graph);
|
||||
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
||||
torch::jit::testing::FileCheck().run(text, *graph);
|
||||
}
|
||||
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(IRParserTest, Strides) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%a : Float(4, 5),
|
||||
%b : Float(4:5, 5:1),
|
||||
%c : Double(*, *)):
|
||||
return (%a)
|
||||
)IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
Value* a = graph->inputs()[0];
|
||||
Value* b = graph->inputs()[1];
|
||||
Value* c = graph->inputs()[2];
|
||||
&*graph,
|
||||
vmap);
|
||||
Value* a = graph->inputs()[0];
|
||||
Value* b = graph->inputs()[1];
|
||||
Value* c = graph->inputs()[2];
|
||||
|
||||
auto a_type = a->type()->cast<TensorType>();
|
||||
auto a_sizes = *a_type->sizes().concrete_sizes();
|
||||
auto a_strides = a_type->strides().concrete_sizes();
|
||||
AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5);
|
||||
AT_ASSERT(a_strides == c10::nullopt);
|
||||
auto a_type = a->type()->cast<TensorType>();
|
||||
auto a_sizes = *a_type->sizes().concrete_sizes();
|
||||
auto a_strides = a_type->strides().concrete_sizes();
|
||||
AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5);
|
||||
AT_ASSERT(a_strides == c10::nullopt);
|
||||
|
||||
auto b_type = b->type()->cast<TensorType>();
|
||||
auto b_sizes = *b_type->sizes().concrete_sizes();
|
||||
auto b_strides = *(b_type->strides().sizes());
|
||||
AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5);
|
||||
AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1);
|
||||
auto b_type = b->type()->cast<TensorType>();
|
||||
auto b_sizes = *b_type->sizes().concrete_sizes();
|
||||
auto b_strides = *(b_type->strides().sizes());
|
||||
AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5);
|
||||
AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1);
|
||||
|
||||
auto c_type = c->type()->cast<TensorType>();
|
||||
AT_ASSERT(*c_type->sizes().size() == 2);
|
||||
AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt);
|
||||
AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
bool error_thrown = false;
|
||||
try {
|
||||
parseIR(
|
||||
R"IR(
|
||||
auto c_type = c->type()->cast<TensorType>();
|
||||
AT_ASSERT(*c_type->sizes().size() == 2);
|
||||
AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt);
|
||||
AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt);
|
||||
}
|
||||
|
||||
TEST(IRParserTest, MalformedStrides) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
bool error_thrown = false;
|
||||
EXPECT_ANY_THROW(parseIR(
|
||||
R"IR(
|
||||
graph(%a : Float(4:5, 5)):
|
||||
return (%a)
|
||||
)IR",
|
||||
&*graph,
|
||||
vmap);
|
||||
} catch (const std::exception& error) {
|
||||
error_thrown = true;
|
||||
}
|
||||
AT_ASSERT(error_thrown);
|
||||
}
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
&*graph,
|
||||
vmap));
|
||||
}
|
||||
|
||||
TEST(IRParserTest, TensorShapes) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%a : Float(4, 5),
|
||||
%b : Float(4:5, 5:1),
|
||||
%c : Double(*, *)):
|
||||
return (%a)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, DeviceAndRequiresGradTensors) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%a : Float(*, *, device=cpu),
|
||||
%b : Float(*, *, requires_grad=1),
|
||||
%c : Long(5, 10, requires_grad=1, device=cpu),
|
||||
|
|
@ -337,41 +335,45 @@ graph(%a : Float(*, *, device=cpu),
|
|||
%j : Double(*, *, requires_grad=0)):
|
||||
return (%a)
|
||||
)IR");
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, ListConstant) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%d : int[] = prim::Constant[value=[1,2,3]]()
|
||||
return (%d)
|
||||
)IR",
|
||||
&*graph);
|
||||
Node* n = graph->outputs()[0]->node();
|
||||
AT_ASSERT(n->kind() == prim::Constant);
|
||||
AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival);
|
||||
const auto& genericList = n->ival(attr::value).toList();
|
||||
std::vector<int> int_vals;
|
||||
for (const IValue& ival : genericList) {
|
||||
int_vals.push_back(ival.toInt());
|
||||
}
|
||||
AT_ASSERT(int_vals.size() == 3);
|
||||
AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3);
|
||||
&*graph);
|
||||
Node* n = graph->outputs()[0]->node();
|
||||
AT_ASSERT(n->kind() == prim::Constant);
|
||||
AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival);
|
||||
const auto& genericList = n->ival(attr::value).toList();
|
||||
std::vector<int> int_vals;
|
||||
for (const IValue& ival : genericList) {
|
||||
int_vals.push_back(ival.toInt());
|
||||
}
|
||||
{
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
AT_ASSERT(int_vals.size() == 3);
|
||||
AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3);
|
||||
}
|
||||
|
||||
TEST(IRParserTest, PartialStarTensor) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%x : Float(10, *, 10)):
|
||||
return (%x)
|
||||
)IR");
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
}
|
||||
|
||||
TEST(IRParserTest, ComplexTensorAttributes) {
|
||||
checkRoundtrip(
|
||||
R"IR(
|
||||
graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1),
|
||||
%b : Float(5, *, requires_grad=1),
|
||||
%c : Long(*, 10, device=cpu)):
|
||||
return (%x)
|
||||
)IR");
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include "torch/csrc/jit/ir/ir.h"
|
||||
|
|
@ -7,7 +8,7 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testUnifyTypes() {
|
||||
TEST(JitTypeTest, UnifyTypes) {
|
||||
auto bool_tensor = TensorType::get()->withScalarType(at::kBool);
|
||||
auto opt_bool_tensor = OptionalType::create(bool_tensor);
|
||||
auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
|
|
@ -10,11 +11,19 @@
|
|||
|
||||
#include <unordered_set>
|
||||
|
||||
#define ASSERT_THROWS_WITH(statement, substring) \
|
||||
try { \
|
||||
(void)statement; \
|
||||
ASSERT_TRUE(false); \
|
||||
} catch (const std::exception& e) { \
|
||||
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
|
||||
}
|
||||
|
||||
// Tests go in torch::jit
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testLiteInterpreterUpsampleNearest2d() {
|
||||
TEST(LiteInterpreterTest, UpsampleNearest2d) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def forward(self, input: Tensor, scale:float):
|
||||
|
|
@ -37,7 +46,7 @@ void testLiteInterpreterUpsampleNearest2d() {
|
|||
ASSERT_TRUE(resd.equal(refd));
|
||||
}
|
||||
|
||||
void testLiteInterpreterAdd() {
|
||||
TEST(LiteInterpreterTest, Add) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
// TODO: support default param val, which was pushed in
|
||||
|
|
@ -71,7 +80,7 @@ void testLiteInterpreterAdd() {
|
|||
AT_ASSERT(resd == refd);
|
||||
}
|
||||
|
||||
void testLiteInterpreterConv() {
|
||||
TEST(LiteInterpreterTest, Conv) {
|
||||
auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
|
||||
if (s && strcmp(s, "1") == 0)
|
||||
return;
|
||||
|
|
@ -103,7 +112,7 @@ void testLiteInterpreterConv() {
|
|||
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||
}
|
||||
|
||||
void testLiteInterpreterInline() {
|
||||
TEST(LiteInterpreterTest, Inline) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def foo1(self, x):
|
||||
|
|
@ -123,7 +132,7 @@ void testLiteInterpreterInline() {
|
|||
AT_ASSERT(output.toTensor().item<float>() == 7.0);
|
||||
}
|
||||
|
||||
void testLiteInterpreterTuple() {
|
||||
TEST(LiteInterpreterTest, Tuple) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def foo(self, x):
|
||||
|
|
@ -141,7 +150,7 @@ void testLiteInterpreterTuple() {
|
|||
AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
|
||||
}
|
||||
|
||||
void testLiteInterpreterDict() {
|
||||
TEST(LiteInterpreterTest, Dict) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def foo(self, x):
|
||||
|
|
@ -159,7 +168,7 @@ void testLiteInterpreterDict() {
|
|||
AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
|
||||
}
|
||||
|
||||
void testLiteInterpreterPrimOverload() {
|
||||
TEST(LiteInterpreterTest, PrimOverload) {
|
||||
/*
|
||||
// temporarily disabled
|
||||
script::Module m("m");
|
||||
|
|
@ -178,7 +187,7 @@ void testLiteInterpreterPrimOverload() {
|
|||
*/
|
||||
}
|
||||
|
||||
void testLiteInterpreterPrim() {
|
||||
TEST(LiteInterpreterTest, Prim) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -204,7 +213,7 @@ void testLiteInterpreterPrim() {
|
|||
AT_ASSERT(resi == refi);
|
||||
}
|
||||
|
||||
void testLiteInterpreterPrimScalar() {
|
||||
TEST(LiteInterpreterTest, PrimScalar) {
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -230,7 +239,7 @@ void testLiteInterpreterPrimScalar() {
|
|||
AT_ASSERT(resi == refi);
|
||||
}
|
||||
|
||||
void testLiteInterpreterLoadOrigJit() {
|
||||
TEST(LiteInterpreterTest, LoadOrigJit) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -243,7 +252,7 @@ void testLiteInterpreterLoadOrigJit() {
|
|||
ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found");
|
||||
}
|
||||
|
||||
void testLiteInterpreterWrongMethodName() {
|
||||
TEST(LiteInterpreterTest, WrongMethodName) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -260,7 +269,7 @@ void testLiteInterpreterWrongMethodName() {
|
|||
ASSERT_THROWS_WITH(bc.get_method("forward")(inputs), "is not defined");
|
||||
}
|
||||
|
||||
void testLiteInterpreterSetState() {
|
||||
TEST(LiteInterpreterTest, SetState) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -308,7 +317,7 @@ class TorchBindLiteInterpreterTestStruct
|
|||
}
|
||||
};
|
||||
|
||||
void testLiteInterpreterBuiltinFunction() {
|
||||
TEST(LiteInterpreterTest, BuiltinFunction) {
|
||||
script::Module m("m");
|
||||
auto custom_class_obj =
|
||||
make_custom_class<TorchBindLiteInterpreterTestStruct>();
|
||||
|
|
@ -328,7 +337,7 @@ void testLiteInterpreterBuiltinFunction() {
|
|||
AT_ASSERT(str == expected);
|
||||
}
|
||||
|
||||
void testLiteInterpreterModuleInfoBasic() {
|
||||
TEST(LiteInterpreterTest, ModuleInfoBasic) {
|
||||
Module m("M");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -357,7 +366,7 @@ void testLiteInterpreterModuleInfoBasic() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterNotSavingModuleInfo() {
|
||||
TEST(LiteInterpreterTest, NotSaveModuleInfo) {
|
||||
Module m("M");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -380,7 +389,7 @@ void testLiteInterpreterNotSavingModuleInfo() {
|
|||
}
|
||||
}
|
||||
|
||||
void testLiteInterpreterOneSubmoduleModuleInfo() {
|
||||
TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -416,7 +425,7 @@ void testLiteInterpreterOneSubmoduleModuleInfo() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterTwoSubmodulesModuleInfo() {
|
||||
TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -458,7 +467,7 @@ void testLiteInterpreterTwoSubmodulesModuleInfo() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterSequentialModuleInfo() {
|
||||
TEST(LiteInterpreterTest, SequentialModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -500,7 +509,7 @@ void testLiteInterpreterSequentialModuleInfo() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterHierarchyModuleInfo() {
|
||||
TEST(LiteInterpreterTest, HierarchyModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -546,7 +555,7 @@ void testLiteInterpreterHierarchyModuleInfo() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterDuplicatedClassTypeModuleInfo() {
|
||||
TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) {
|
||||
Module a("A");
|
||||
a.define(R"JIT(
|
||||
def forward(self, x):
|
||||
|
|
@ -586,7 +595,7 @@ void testLiteInterpreterDuplicatedClassTypeModuleInfo() {
|
|||
AT_ASSERT(module_debug_info_set == expected_result);
|
||||
}
|
||||
|
||||
void testLiteInterpreterEval() {
|
||||
TEST(LiteInterpreterTest, Eval) {
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
|
||||
Module m("m");
|
||||
|
|
@ -619,7 +628,7 @@ void testLiteInterpreterEval() {
|
|||
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||
}
|
||||
|
||||
void testLiteInterpreterFindWrongMethodName() {
|
||||
TEST(LiteInterpreterTest, FindWrongMethodName) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -633,7 +642,7 @@ void testLiteInterpreterFindWrongMethodName() {
|
|||
ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
|
||||
}
|
||||
|
||||
void testLiteInterpreterFindAndRunMethod() {
|
||||
TEST(LiteInterpreterTest, FindAndRunMethod) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -663,7 +672,7 @@ void testLiteInterpreterFindAndRunMethod() {
|
|||
AT_ASSERT(resd == refd);
|
||||
}
|
||||
|
||||
void testLiteInterpreterRunMethodVariadic() {
|
||||
TEST(LiteInterpreterTest, RunMethodVariadic) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/mobile/export_data.h>
|
||||
|
|
@ -16,7 +17,7 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testLiteInterpreterParams() {
|
||||
TEST(LiteTrainerTest, Params) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
|
||||
m.define(R"(
|
||||
|
|
@ -74,7 +75,7 @@ void testLiteInterpreterParams() {
|
|||
AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
|
||||
}
|
||||
|
||||
void testMobileNamedParameters() {
|
||||
TEST(MobileTest, NamedParameters) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -99,7 +100,7 @@ void testMobileNamedParameters() {
|
|||
}
|
||||
}
|
||||
|
||||
void testMobileSaveLoadData() {
|
||||
TEST(MobileTest, SaveLoadData) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -127,7 +128,7 @@ void testMobileSaveLoadData() {
|
|||
}
|
||||
}
|
||||
|
||||
void testMobileSaveLoadParameters() {
|
||||
TEST(MobileTest, SaveLoadParameters) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
|
|
@ -157,7 +158,7 @@ void testMobileSaveLoadParameters() {
|
|||
}
|
||||
}
|
||||
|
||||
void testMobileSaveLoadParametersEmpty() {
|
||||
TEST(MobileTest, SaveLoadParametersEmpty) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def add_it(self, x):
|
||||
|
|
@ -180,7 +181,7 @@ void testMobileSaveLoadParametersEmpty() {
|
|||
AT_ASSERT(mobile_params.size() == 0);
|
||||
}
|
||||
|
||||
void testLiteSGD() {
|
||||
TEST(LiteTrainerTest, SGD) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
|
||||
m.define(R"(
|
||||
|
|
@ -253,7 +254,7 @@ struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void testLiteSequentialSampler() {
|
||||
TEST(LiteTrainerTest, SequentialSampler) {
|
||||
// test that sampler can be used with dataloader
|
||||
const int kBatchSize = 10;
|
||||
auto data_loader =
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,6 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
//#include <gtest.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
namespace c10 {
|
||||
// std::string serializeType(const Type &t);
|
||||
|
|
@ -8,50 +9,74 @@ TypePtr parseType(const std::string& pythonStr);
|
|||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
void testMobileTypeParser() {
|
||||
TEST(MobileTypeParserTest, Empty) {
|
||||
std::string empty_ps("");
|
||||
ASSERT_ANY_THROW(c10::parseType(empty_ps));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, RoundTripAnnotationStr) {
|
||||
std::string int_ps("int");
|
||||
auto int_tp = c10::parseType(int_ps);
|
||||
std::string int_tps = int_tp->annotation_str();
|
||||
ASSERT_EQ(int_ps, int_tps);
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, NestedContainersAnnotationStr) {
|
||||
std::string tuple_ps(
|
||||
"Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
|
||||
auto tuple_tp = c10::parseType(tuple_ps);
|
||||
std::string tuple_tps = tuple_tp->annotation_str();
|
||||
ASSERT_EQ(tuple_ps, tuple_tps);
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) {
|
||||
std::string tuple_ps(
|
||||
"Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
|
||||
std::string tuple_space_ps(
|
||||
"Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]");
|
||||
auto tuple_space_tp = c10::parseType(tuple_space_ps);
|
||||
// tuple_space_tps should not have weird white spaces
|
||||
std::string tuple_space_tps = tuple_space_tp->annotation_str();
|
||||
ASSERT_EQ(tuple_ps, tuple_space_tps);
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, TypoRaises) {
|
||||
std::string typo_token("List[tensor]");
|
||||
ASSERT_ANY_THROW(c10::parseType(typo_token));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, MismatchBracketRaises) {
|
||||
std::string mismatch1("List[Tensor");
|
||||
ASSERT_ANY_THROW(c10::parseType(mismatch1));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, MismatchBracketRaises2) {
|
||||
std::string mismatch2("List[[Tensor]");
|
||||
ASSERT_ANY_THROW(c10::parseType(mismatch2));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, DictWithoutValueRaises) {
|
||||
std::string mismatch3("Dict[Tensor]");
|
||||
ASSERT_ANY_THROW(c10::parseType(mismatch3));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, ListArgCountMismatchRaises) {
|
||||
// arg count mismatch
|
||||
std::string mismatch4("List[int, str]");
|
||||
ASSERT_ANY_THROW(c10::parseType(mismatch4));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, DictArgCountMismatchRaises) {
|
||||
std::string trailing_commm("Dict[str,]");
|
||||
ASSERT_ANY_THROW(c10::parseType(trailing_commm));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, ValidTypeWithExtraStuffRaises) {
|
||||
std::string extra_stuff("int int");
|
||||
ASSERT_ANY_THROW(c10::parseType(extra_stuff));
|
||||
}
|
||||
|
||||
TEST(MobileTypeParserTest, NonIdentifierRaises) {
|
||||
std::string non_id("(int)");
|
||||
ASSERT_ANY_THROW(c10::parseType(non_id));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
|
|
@ -42,7 +43,7 @@ static void import_libs(
|
|||
si.loadType(QualifiedName(class_name));
|
||||
}
|
||||
|
||||
void testModuleClone() {
|
||||
TEST(ModuleAPITest, Clone) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
// creating child module
|
||||
auto child = ClassType::create("child", cu, true);
|
||||
|
|
@ -71,7 +72,7 @@ void testModuleClone() {
|
|||
ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
|
||||
}
|
||||
|
||||
void testModuleCloneWithModuleInterface() {
|
||||
TEST(ModuleAPITest, CloneWithModuleInterface) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
|
||||
// define a initial module with two submods share same interface
|
||||
|
|
@ -115,7 +116,7 @@ void testModuleCloneWithModuleInterface() {
|
|||
ASSERT_NE(clonedMod.type(), parentMod.type());
|
||||
}
|
||||
|
||||
void testModuleCopy() {
|
||||
TEST(ModuleAPITest, Copy) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
auto attr_name = "attr";
|
||||
|
|
@ -144,7 +145,7 @@ void testModuleCopy() {
|
|||
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
||||
}
|
||||
|
||||
void testModuleDeepcopy() {
|
||||
TEST(ModuleAPITest, DeepCopy) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
auto str_attr = "str_attr";
|
||||
|
|
@ -203,7 +204,7 @@ void testModuleDeepcopy() {
|
|||
ASSERT_TRUE(t1.equal(t3));
|
||||
}
|
||||
|
||||
void testModuleDeepcopyString() {
|
||||
TEST(ModuleAPITest, DeepCopyString) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
auto attr1 = "attr1";
|
||||
|
|
@ -219,7 +220,7 @@ void testModuleDeepcopyString() {
|
|||
ASSERT_EQ(copied.attr(attr1).toString()->string(), original_str);
|
||||
}
|
||||
|
||||
void testModuleDeepcopyAliasing() {
|
||||
TEST(ModuleAPITest, DeepCopyPreservesAliasing) {
|
||||
// check deepcopy preserves aliasing
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
|
|
@ -256,7 +257,7 @@ void testModuleDeepcopyAliasing() {
|
|||
ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
|
||||
}
|
||||
|
||||
void testModuleConstant() {
|
||||
TEST(ModuleAPITest, Constants) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
auto attr_name = "attr";
|
||||
|
|
@ -272,7 +273,7 @@ void testModuleConstant() {
|
|||
ASSERT_EQ(m.attr(const_name).toInt(), 3);
|
||||
}
|
||||
|
||||
void testModuleParameter() {
|
||||
TEST(ModuleAPITest, Parameters) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
Module m(cu, cls);
|
||||
|
|
@ -291,5 +292,39 @@ void testModuleParameter() {
|
|||
ASSERT_TRUE(m.hasattr("none_param2"));
|
||||
}
|
||||
|
||||
TEST(ModuleAPITest, Define) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(R"(
|
||||
def add_it(self, x, b : int = 4):
|
||||
return self.foo + x + b
|
||||
)");
|
||||
auto result = m.run_method("add_it", torch::ones({}));
|
||||
AT_ASSERT(result.toTensor().item<float>() == 6);
|
||||
}
|
||||
|
||||
TEST(ModuleAPITest, To_CUDA) {
|
||||
Module m("test");
|
||||
{
|
||||
// test cuda to cpu for params and buffers
|
||||
m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
|
||||
m.register_buffer("bar", torch::ones({}, at::kCUDA));
|
||||
|
||||
m.to(at::kCUDA);
|
||||
m.to(at::kCPU);
|
||||
AT_ASSERT(m.attr("foo").toTensor().device().is_cpu());
|
||||
AT_ASSERT(m.attr("bar").toTensor().device().is_cpu());
|
||||
}
|
||||
{
|
||||
// test cpu to cuda for params and buffers
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.register_buffer("bar", torch::ones({}));
|
||||
|
||||
m.to(at::kCUDA);
|
||||
AT_ASSERT(m.attr("foo").toTensor().device().is_cuda());
|
||||
AT_ASSERT(m.attr("bar").toTensor().device().is_cuda());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
|
@ -8,47 +9,48 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testPeepholeOptimize() {
|
||||
// test is / is not none optimization
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(PeepholeOptimizeTest, IsAndIsNot)
|
||||
// test is / is not none optimization
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0 : int):
|
||||
%1 : None = prim::Constant()
|
||||
%2 : bool = aten::__is__(%0, %1)
|
||||
%3 : bool = aten::__isnot__(%0, %1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check_not("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check_not("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
TEST(PeepholeOptimizeTest, IsAndIsNot2) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0: int?):
|
||||
%1 : None = prim::Constant()
|
||||
%2 : bool = aten::__is__(%0, %1)
|
||||
%3 : bool = aten::__isnot__(%0, %1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(PeepholeOptimizeTest, IsAndIsNot3) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%0: int?):
|
||||
%1 : Tensor = prim::AutogradZero()
|
||||
%2 : None = prim::Constant()
|
||||
|
|
@ -56,48 +58,49 @@ graph(%0: int?):
|
|||
%5 : bool = aten::__isnot__(%1, %2)
|
||||
return (%4, %5)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck()
|
||||
.check("aten::__is__")
|
||||
->check_not("aten::__isnot__")
|
||||
->run(*graph);
|
||||
}
|
||||
|
||||
// test unwrap optional
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(PeepholeOptimizeTest, UnwrapOptional)
|
||||
// test unwrap optional
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%1 : Float(*, *, *) = prim::Constant()
|
||||
%2 : bool = aten::_unwrap_optional(%1)
|
||||
%3 : bool = prim::unchecked_unwrap_optional(%1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_not("unwrap")->run(*graph);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_not("unwrap")->run(*graph);
|
||||
}
|
||||
|
||||
TEST(PeepholeOptimizeTest, UnwrapOptional2) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%1 : Float(*, *, *)?):
|
||||
%2 : bool = aten::_unwrap_optional(%1)
|
||||
%3 : bool = prim::unchecked_unwrap_optional(%1)
|
||||
return (%2, %3)
|
||||
)IR",
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_count("unwrap", 2)->run(*graph);
|
||||
}
|
||||
graph.get());
|
||||
PeepholeOptimize(graph);
|
||||
testing::FileCheck().check_count("unwrap", 2)->run(*graph);
|
||||
}
|
||||
|
||||
// tests addmm fusion
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
TEST(PeepholeOptimizeTest, AddMMFusion) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(
|
||||
%0 : Float(2, 3, 4),
|
||||
%1 : Float(2, 3, 4),
|
||||
|
|
@ -108,10 +111,9 @@ graph(%1 : Float(*, *, *)?):
|
|||
%6 : Tensor = aten::add(%5, %2, %3)
|
||||
return (%6)
|
||||
)IR",
|
||||
graph.get());
|
||||
FuseAddMM(graph);
|
||||
testing::FileCheck().check("addmm")->run(*graph);
|
||||
}
|
||||
graph.get());
|
||||
FuseAddMM(graph);
|
||||
testing::FileCheck().check("addmm")->run(*graph);
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,68 +1,70 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
|
||||
using c10::QualifiedName;
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
void testQualifiedName() {
|
||||
{
|
||||
// Test prefix construction
|
||||
auto foo = QualifiedName("foo");
|
||||
auto bar = QualifiedName(foo, "bar");
|
||||
auto baz = QualifiedName(bar, "baz");
|
||||
ASSERT_EQ(baz.qualifiedName(), "foo.bar.baz");
|
||||
ASSERT_EQ(baz.prefix(), "foo.bar");
|
||||
ASSERT_EQ(baz.name(), "baz");
|
||||
auto nullstate = QualifiedName();
|
||||
ASSERT_EQ(nullstate.qualifiedName(), "");
|
||||
ASSERT_EQ(nullstate.prefix(), "");
|
||||
ASSERT_EQ(nullstate.name(), "");
|
||||
}
|
||||
{
|
||||
// Test dotted construction
|
||||
auto foo = QualifiedName("foo.bar.baz");
|
||||
ASSERT_EQ(foo.qualifiedName(), "foo.bar.baz");
|
||||
ASSERT_EQ(foo.prefix(), "foo.bar");
|
||||
ASSERT_EQ(foo.name(), "baz");
|
||||
TEST(QualifiedNameTest, PrefixConstruction) {
|
||||
// Test prefix construction
|
||||
auto foo = QualifiedName("foo");
|
||||
auto bar = QualifiedName(foo, "bar");
|
||||
auto baz = QualifiedName(bar, "baz");
|
||||
ASSERT_EQ(baz.qualifiedName(), "foo.bar.baz");
|
||||
ASSERT_EQ(baz.prefix(), "foo.bar");
|
||||
ASSERT_EQ(baz.name(), "baz");
|
||||
auto nullstate = QualifiedName();
|
||||
ASSERT_EQ(nullstate.qualifiedName(), "");
|
||||
ASSERT_EQ(nullstate.prefix(), "");
|
||||
ASSERT_EQ(nullstate.name(), "");
|
||||
}
|
||||
|
||||
auto bar = QualifiedName("bar");
|
||||
ASSERT_EQ(bar.qualifiedName(), "bar");
|
||||
ASSERT_EQ(bar.prefix(), "");
|
||||
ASSERT_EQ(bar.name(), "bar");
|
||||
}
|
||||
{
|
||||
// throw some bad inputs at it
|
||||
ASSERT_ANY_THROW(QualifiedName("foo..bar"));
|
||||
ASSERT_ANY_THROW(QualifiedName(".foo.bar"));
|
||||
ASSERT_ANY_THROW(QualifiedName("foo.bar."));
|
||||
ASSERT_ANY_THROW(QualifiedName(""));
|
||||
}
|
||||
{
|
||||
// test equality api
|
||||
auto foo1 = QualifiedName("foo.bar.baz");
|
||||
auto foo2 = QualifiedName("foo.bar.baz");
|
||||
auto foo3 = QualifiedName("bar.bar.baz");
|
||||
ASSERT_EQ(foo1, foo2);
|
||||
ASSERT_NE(foo1, foo3);
|
||||
auto bar1 = QualifiedName("sup");
|
||||
auto bar2 = QualifiedName("sup");
|
||||
ASSERT_EQ(foo1, foo2);
|
||||
}
|
||||
{
|
||||
// test prefix api
|
||||
auto foo1 = QualifiedName("foo.bar.baz");
|
||||
auto foo2 = QualifiedName("foo.bar");
|
||||
auto foo3 = QualifiedName("bar.bar.baz");
|
||||
auto foo4 = QualifiedName("foo.bar");
|
||||
ASSERT_TRUE(foo2.isPrefixOf(foo1));
|
||||
ASSERT_TRUE(foo2.isPrefixOf(foo4));
|
||||
ASSERT_TRUE(foo4.isPrefixOf(foo2));
|
||||
ASSERT_FALSE(foo1.isPrefixOf(foo2));
|
||||
ASSERT_FALSE(foo2.isPrefixOf(foo3));
|
||||
}
|
||||
TEST(QualifiedNameTest, DottedConstruction) {
|
||||
// Test dotted construction
|
||||
auto foo = QualifiedName("foo.bar.baz");
|
||||
ASSERT_EQ(foo.qualifiedName(), "foo.bar.baz");
|
||||
ASSERT_EQ(foo.prefix(), "foo.bar");
|
||||
ASSERT_EQ(foo.name(), "baz");
|
||||
|
||||
auto bar = QualifiedName("bar");
|
||||
ASSERT_EQ(bar.qualifiedName(), "bar");
|
||||
ASSERT_EQ(bar.prefix(), "");
|
||||
ASSERT_EQ(bar.name(), "bar");
|
||||
}
|
||||
|
||||
TEST(QualifiedNameTest, BadInputRaises) {
|
||||
// throw some bad inputs at it
|
||||
ASSERT_ANY_THROW(QualifiedName("foo..bar"));
|
||||
ASSERT_ANY_THROW(QualifiedName(".foo.bar"));
|
||||
ASSERT_ANY_THROW(QualifiedName("foo.bar."));
|
||||
ASSERT_ANY_THROW(QualifiedName(""));
|
||||
}
|
||||
|
||||
TEST(QualifiedNameTest, Equality) {
|
||||
// test equality api
|
||||
auto foo1 = QualifiedName("foo.bar.baz");
|
||||
auto foo2 = QualifiedName("foo.bar.baz");
|
||||
auto foo3 = QualifiedName("bar.bar.baz");
|
||||
ASSERT_EQ(foo1, foo2);
|
||||
ASSERT_NE(foo1, foo3);
|
||||
auto bar1 = QualifiedName("sup");
|
||||
auto bar2 = QualifiedName("sup");
|
||||
ASSERT_EQ(foo1, foo2);
|
||||
}
|
||||
|
||||
TEST(QualifiedNameTest, IsPrefixOf) {
|
||||
// test prefix api
|
||||
auto foo1 = QualifiedName("foo.bar.baz");
|
||||
auto foo2 = QualifiedName("foo.bar");
|
||||
auto foo3 = QualifiedName("bar.bar.baz");
|
||||
auto foo4 = QualifiedName("foo.bar");
|
||||
ASSERT_TRUE(foo2.isPrefixOf(foo1));
|
||||
ASSERT_TRUE(foo2.isPrefixOf(foo4));
|
||||
ASSERT_TRUE(foo4.isPrefixOf(foo2));
|
||||
ASSERT_FALSE(foo1.isPrefixOf(foo2));
|
||||
ASSERT_FALSE(foo2.isPrefixOf(foo3));
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <sstream>
|
||||
|
||||
|
|
@ -12,10 +13,10 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Tests that an extra file written explicitly has precedence over
|
||||
// extra files written by a hook
|
||||
// TODO: test for the warning, too
|
||||
void testExtraFilesHookPreference() {
|
||||
TEST(SerializationTest, ExtraFilesHookPreference) {
|
||||
// Tests that an extra file written explicitly has precedence over
|
||||
// extra files written by a hook
|
||||
// TODO: test for the warning, too
|
||||
const auto script = R"JIT(
|
||||
def forward(self):
|
||||
x = torch.rand(5, 5)
|
||||
|
|
@ -43,52 +44,50 @@ void testExtraFilesHookPreference() {
|
|||
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
|
||||
}
|
||||
|
||||
void testSaveExtraFilesHook() {
|
||||
TEST(SerializationTest, ExtraFileHooksNoSecret) {
|
||||
// no secrets
|
||||
std::stringstream ss;
|
||||
{
|
||||
std::stringstream ss;
|
||||
{
|
||||
Module m("__torch__.m");
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "abc";
|
||||
m.save(ss, extra);
|
||||
}
|
||||
ss.seekg(0);
|
||||
{
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "";
|
||||
extra["secret.json"] = "";
|
||||
jit::load(ss, c10::nullopt, extra);
|
||||
ASSERT_EQ(extra["metadata.json"], "abc");
|
||||
ASSERT_EQ(extra["secret.json"], "");
|
||||
}
|
||||
Module m("__torch__.m");
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "abc";
|
||||
m.save(ss, extra);
|
||||
}
|
||||
// some secret
|
||||
ss.seekg(0);
|
||||
{
|
||||
std::stringstream ss;
|
||||
{
|
||||
SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
|
||||
return {{"secret.json", "topsecret"}};
|
||||
});
|
||||
Module m("__torch__.m");
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "abc";
|
||||
m.save(ss, extra);
|
||||
SetExportModuleExtraFilesHook(nullptr);
|
||||
}
|
||||
ss.seekg(0);
|
||||
{
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "";
|
||||
extra["secret.json"] = "";
|
||||
jit::load(ss, c10::nullopt, extra);
|
||||
ASSERT_EQ(extra["metadata.json"], "abc");
|
||||
ASSERT_EQ(extra["secret.json"], "topsecret");
|
||||
}
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "";
|
||||
extra["secret.json"] = "";
|
||||
jit::load(ss, c10::nullopt, extra);
|
||||
ASSERT_EQ(extra["metadata.json"], "abc");
|
||||
ASSERT_EQ(extra["secret.json"], "");
|
||||
}
|
||||
}
|
||||
|
||||
void testTypeTags() {
|
||||
TEST(SerializationTest, ExtraFileHooksWithSecret) {
|
||||
std::stringstream ss;
|
||||
{
|
||||
SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
|
||||
return {{"secret.json", "topsecret"}};
|
||||
});
|
||||
Module m("__torch__.m");
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "abc";
|
||||
m.save(ss, extra);
|
||||
SetExportModuleExtraFilesHook(nullptr);
|
||||
}
|
||||
ss.seekg(0);
|
||||
{
|
||||
ExtraFilesMap extra;
|
||||
extra["metadata.json"] = "";
|
||||
extra["secret.json"] = "";
|
||||
jit::load(ss, c10::nullopt, extra);
|
||||
ASSERT_EQ(extra["metadata.json"], "abc");
|
||||
ASSERT_EQ(extra["secret.json"], "topsecret");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SerializationTest, TypeTags) {
|
||||
auto list = c10::List<c10::List<int64_t>>();
|
||||
list.push_back(c10::List<int64_t>({1, 2, 3}));
|
||||
list.push_back(c10::List<int64_t>({4, 5, 6}));
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include <torch/jit.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
#include "torch/csrc/jit/runtime/custom_operator.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
|
@ -10,80 +11,79 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testSchemaMatching() {
|
||||
{
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype(t[] a, t b) -> (t)",
|
||||
[](Stack* stack) {
|
||||
c10::List<double> list;
|
||||
double a;
|
||||
pop(stack, list, a);
|
||||
push(stack, a);
|
||||
},
|
||||
c10::AliasAnalysisKind::FROM_SCHEMA),
|
||||
});
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
TEST(SchemaMatchingTest, VarType) {
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype(t[] a, t b) -> (t)",
|
||||
[](Stack* stack) {
|
||||
c10::List<double> list;
|
||||
double a;
|
||||
pop(stack, list, a);
|
||||
push(stack, a);
|
||||
},
|
||||
c10::AliasAnalysisKind::FROM_SCHEMA),
|
||||
});
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def test(self):
|
||||
a = (1.0, 2.0)
|
||||
return torch.test_vartype(a, 2.0)
|
||||
)");
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
|
||||
|
||||
const std::string error_example = R"JIT(
|
||||
const std::string error_example = R"JIT(
|
||||
def test_2(self):
|
||||
a = (1.0, 2.0)
|
||||
non_float = (1, 1)
|
||||
return torch.test_vartype(a, non_float)
|
||||
)JIT";
|
||||
|
||||
std::string err = "";
|
||||
try {
|
||||
m.define(error_example);
|
||||
} catch (const std::exception& e) {
|
||||
err = e.what();
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
err.find("previously matched to type") != std::string::npos);
|
||||
std::string err = "";
|
||||
try {
|
||||
m.define(error_example);
|
||||
} catch (const std::exception& e) {
|
||||
err = e.what();
|
||||
}
|
||||
{
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype2(t a, t[] b) -> (t[])",
|
||||
[](Stack* stack) {
|
||||
double a;
|
||||
c10::List<double> list;
|
||||
pop(stack, a, list);
|
||||
push(stack, a);
|
||||
},
|
||||
AliasAnalysisKind::FROM_SCHEMA),
|
||||
});
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
err.find("previously matched to type") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SchemaMatchingTest, VarType2) {
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype2(t a, t[] b) -> (t[])",
|
||||
[](Stack* stack) {
|
||||
double a;
|
||||
c10::List<double> list;
|
||||
pop(stack, a, list);
|
||||
push(stack, a);
|
||||
},
|
||||
AliasAnalysisKind::FROM_SCHEMA),
|
||||
});
|
||||
Module m("m");
|
||||
m.define(R"JIT(
|
||||
def test(self):
|
||||
a = (1.0, 2.0)
|
||||
return torch.test_vartype2(3.0, a)
|
||||
)JIT");
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
|
||||
|
||||
static const auto error_exam2 = R"JIT(
|
||||
static const auto error_exam2 = R"JIT(
|
||||
def test_2(self):
|
||||
a = (1, 2)
|
||||
return torch.test_vartype2(3.0, a)
|
||||
)JIT";
|
||||
|
||||
std::string err = "";
|
||||
try {
|
||||
m.define(error_exam2);
|
||||
} catch (const std::exception& e) {
|
||||
err = e.what();
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
err.find("previously matched to type") != std::string::npos);
|
||||
std::string err = "";
|
||||
try {
|
||||
m.define(error_exam2);
|
||||
} catch (const std::exception& e) {
|
||||
err = e.what();
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
err.find("previously matched to type") != std::string::npos);
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
#include "torch/csrc/jit/ir/subgraph_matcher.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testTrivial1() {
|
||||
TEST(SubgraphMatcherTest, Trivial1) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -22,7 +23,7 @@ graph(%0):
|
|||
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
||||
}
|
||||
|
||||
void testTrivial2() {
|
||||
TEST(SubgraphMatcherTest, Trivial2) {
|
||||
Graph graph;
|
||||
auto* g_in = graph.addInput();
|
||||
auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
|
||||
|
|
@ -45,7 +46,7 @@ void testTrivial2() {
|
|||
}
|
||||
}
|
||||
|
||||
void testTrivial3() {
|
||||
TEST(SubgraphMatcherTest, Trivial3) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -64,7 +65,7 @@ graph(%a, %b):
|
|||
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
||||
}
|
||||
|
||||
void testTrivial4() {
|
||||
TEST(SubgraphMatcherTest, Trivial4) {
|
||||
Graph graph;
|
||||
auto* g_in0 = graph.addInput();
|
||||
auto* g_in1 = graph.addInput();
|
||||
|
|
@ -92,7 +93,7 @@ void testTrivial4() {
|
|||
}
|
||||
}
|
||||
|
||||
void testLinear1() {
|
||||
TEST(SubgraphMatcherTest, Linear1) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -114,7 +115,7 @@ graph(%0):
|
|||
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
||||
}
|
||||
|
||||
void testLinear2() {
|
||||
TEST(SubgraphMatcherTest, Linear2) {
|
||||
Graph graph;
|
||||
auto* g_in = graph.addInput();
|
||||
|
||||
|
|
@ -164,7 +165,7 @@ void testLinear2() {
|
|||
* |
|
||||
* eee
|
||||
*/
|
||||
void testDiamond1() {
|
||||
TEST(SubgraphMatcherTest, Diamond1) {
|
||||
Graph graph, pattern1, pattern2;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -215,7 +216,7 @@ graph(%0):
|
|||
* |
|
||||
* o1
|
||||
*/
|
||||
void testDiamond2() {
|
||||
TEST(SubgraphMatcherTest, Diamond2) {
|
||||
Graph graph;
|
||||
auto* g_in = graph.addInput();
|
||||
|
||||
|
|
@ -253,7 +254,7 @@ void testDiamond2() {
|
|||
}
|
||||
}
|
||||
|
||||
void testXPattern() {
|
||||
TEST(SubgraphMatcherTest, XPattern) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -280,7 +281,7 @@ graph(%0, %1):
|
|||
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
||||
}
|
||||
|
||||
void testMultipleMatches() {
|
||||
TEST(SubgraphMatcherTest, MultipleMatches) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -301,7 +302,7 @@ graph(%t0):
|
|||
AT_ASSERT(matches.size() == 4);
|
||||
}
|
||||
|
||||
void testOverlappingMatches() {
|
||||
TEST(SubgraphMatcherTest, OverlappingMatches) {
|
||||
Graph graph, pattern;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -323,7 +324,7 @@ graph(%t0):
|
|||
AT_ASSERT(matches.size() == 3);
|
||||
}
|
||||
|
||||
void testMatchInBasicBlocks1() {
|
||||
TEST(SubgraphMatcherTest, MatchInBasicBlocks1) {
|
||||
Graph graph;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -360,7 +361,7 @@ graph(%x, %y):
|
|||
AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
|
||||
}
|
||||
|
||||
void testMatchInBasicBlocks2() {
|
||||
TEST(SubgraphMatcherTest, MatchInBasicBlocks2) {
|
||||
Graph graph;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -395,7 +396,7 @@ graph(%x, %y):
|
|||
AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
|
||||
}
|
||||
|
||||
void testMatchesAttributes() {
|
||||
TEST(SubgraphMatcherTest, MatchesAttributes) {
|
||||
Graph graph;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -479,7 +480,7 @@ graph(%a, %b):
|
|||
}
|
||||
}
|
||||
|
||||
void testBadPattern() {
|
||||
TEST(SubgraphMatcherTest, BadPattern) {
|
||||
Graph graph, pattern1, pattern2;
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -509,23 +510,5 @@ graph(%x):
|
|||
ASSERT_ANY_THROW(findPatternMatches(pattern2, graph));
|
||||
}
|
||||
|
||||
void testSubgraphMatching() {
|
||||
testTrivial1();
|
||||
testTrivial2();
|
||||
testTrivial3();
|
||||
testTrivial4();
|
||||
testLinear1();
|
||||
testLinear2();
|
||||
testDiamond1();
|
||||
testDiamond2();
|
||||
testXPattern();
|
||||
testMultipleMatches();
|
||||
testOverlappingMatches();
|
||||
testMatchInBasicBlocks1();
|
||||
testMatchInBasicBlocks2();
|
||||
testMatchesAttributes();
|
||||
testBadPattern();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/ir/subgraph_matcher.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
|
@ -8,7 +9,7 @@ namespace torch {
|
|||
namespace jit {
|
||||
using namespace testing;
|
||||
|
||||
void testFilterMatch() {
|
||||
TEST(SubgraphRewriterTest, FilterMatch) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
|
||||
parseIR(
|
||||
|
|
@ -80,7 +81,7 @@ graph(%a, %b):
|
|||
}
|
||||
}
|
||||
|
||||
void testFilterNoMatch() {
|
||||
TEST(SubgraphRewriterTest, FilterNoMatch) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
|
|
@ -121,10 +122,5 @@ graph(%a, %b):
|
|||
FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
|
||||
}
|
||||
|
||||
void testSubgraphRewriter() {
|
||||
testFilterMatch();
|
||||
testFilterNoMatch();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include "test/cpp/jit/test_base.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
|
||||
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
||||
|
|
@ -7,7 +8,7 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testSubgraphUtils() {
|
||||
TEST(SubgraphUtilsTest, Basic) {
|
||||
auto graph = build_lstm();
|
||||
EliminateCommonSubexpression(graph);
|
||||
|
||||
|
|
@ -37,7 +38,7 @@ void testSubgraphUtils() {
|
|||
ASSERT_EQ(originalNodes.size(), newNodes.size());
|
||||
}
|
||||
|
||||
void testSubgraphUtilsVmap() {
|
||||
TEST(SubgraphUtilsTest, Vmap) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
|
||||
std::unordered_map<std::string, Value*> parse_map;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/jit/passes/clear_undefinedness.h>
|
||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -137,5 +140,22 @@ std::pair<at::Tensor, at::Tensor> lstm(
|
|||
return {hy, cy};
|
||||
}
|
||||
|
||||
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
|
||||
return c10::AliasAnalysisKind::FROM_SCHEMA;
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegisterOperators reg({
|
||||
// This operator is intended to be used in JIT analysis and transformation
|
||||
// pass unit tests in which Values with type Tensor are often required. It
|
||||
// should not be used in situations in which the graph is actually executed
|
||||
// because it always produces empty Tensors.
|
||||
Operator(
|
||||
"prim::MakeTestTensor() -> Tensor",
|
||||
[](Stack* stack) { push(stack, at::Tensor()); },
|
||||
aliasAnalysisFromSchema()),
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
#include "torch/csrc/jit/ir/irparser.h"
|
||||
#include "torch/csrc/jit/runtime/autodiff.h"
|
||||
#include "torch/csrc/jit/runtime/interpreter.h"
|
||||
|
|
|
|||
|
|
@ -1,242 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
/**
|
||||
* See README.md for instructions on how to add a new test.
|
||||
*/
|
||||
#include <c10/macros/Export.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
#define TH_FORALL_TESTS(_) \
|
||||
_(Attributes) \
|
||||
_(Blocks) \
|
||||
_(CallStack) \
|
||||
_(CallStackCaching) \
|
||||
_(ControlFlow) \
|
||||
_(IValueKWargs) \
|
||||
_(CustomFusion) \
|
||||
_(SchemaMatching) \
|
||||
_(FromQualString) \
|
||||
_(InternedStrings) \
|
||||
_(PassManagement) \
|
||||
_(Proto) \
|
||||
_(SchemaParser) \
|
||||
_(TopologicalIndex) \
|
||||
_(SubgraphUtils) \
|
||||
_(SubgraphUtilsVmap) \
|
||||
_(IRParser) \
|
||||
_(THNNConv) \
|
||||
_(ATenNativeBatchNorm) \
|
||||
_(NoneSchemaMatch) \
|
||||
_(UnifyTypes) \
|
||||
_(Profiler) \
|
||||
_(FallbackGraphs) \
|
||||
_(InsertAndEliminateRedundantGuards) \
|
||||
_(LoopPeeler) \
|
||||
_(InsertBailOuts) \
|
||||
_(PeepholeOptimize) \
|
||||
_(RecordFunction) \
|
||||
_(ThreadLocalDebugInfo) \
|
||||
_(SubgraphMatching) \
|
||||
_(SubgraphRewriter) \
|
||||
_(ModuleClone) \
|
||||
_(ModuleConstant) \
|
||||
_(ModuleParameter) \
|
||||
_(ModuleCopy) \
|
||||
_(ModuleDeepcopy) \
|
||||
_(ModuleDeepcopyString) \
|
||||
_(ModuleDeepcopyAliasing) \
|
||||
_(ModuleDefine) \
|
||||
_(QualifiedName) \
|
||||
_(ExtraFilesHookPreference) \
|
||||
_(SaveExtraFilesHook) \
|
||||
_(TypeTags) \
|
||||
_(CustomFusionNestedBlocks) \
|
||||
_(ModuleInterfaceSerialization) \
|
||||
_(ModuleCloneWithModuleInterface) \
|
||||
_(ClassTypeAddRemoveAttr) \
|
||||
_(Inliner) \
|
||||
_(LiteInterpreterAdd) \
|
||||
_(LiteInterpreterConv) \
|
||||
_(LiteInterpreterInline) \
|
||||
_(LiteInterpreterTuple) \
|
||||
_(LiteInterpreterUpsampleNearest2d) \
|
||||
_(CommonAncestor) \
|
||||
_(AutogradSymbols) \
|
||||
_(DefaultArgTypeHinting) \
|
||||
_(Futures) \
|
||||
_(TLSFutureCallbacks) \
|
||||
_(ProfilerDisableInCallback) \
|
||||
_(MobileTypeParser) \
|
||||
_(LiteInterpreterBuiltinFunction) \
|
||||
_(LiteInterpreterPrim) \
|
||||
_(LiteInterpreterPrimScalar) \
|
||||
_(LiteInterpreterLoadOrigJit) \
|
||||
_(LiteInterpreterWrongMethodName) \
|
||||
_(LiteInterpreterParams) \
|
||||
_(LiteInterpreterSetState) \
|
||||
_(LiteInterpreterModuleInfoBasic) \
|
||||
_(LiteInterpreterNotSavingModuleInfo) \
|
||||
_(LiteInterpreterOneSubmoduleModuleInfo) \
|
||||
_(LiteInterpreterTwoSubmodulesModuleInfo) \
|
||||
_(LiteInterpreterSequentialModuleInfo) \
|
||||
_(LiteInterpreterHierarchyModuleInfo) \
|
||||
_(LiteInterpreterDuplicatedClassTypeModuleInfo) \
|
||||
_(LiteInterpreterEval) \
|
||||
_(LiteInterpreterDict) \
|
||||
_(LiteInterpreterFindAndRunMethod) \
|
||||
_(LiteInterpreterFindWrongMethodName) \
|
||||
_(MobileNamedParameters) \
|
||||
_(MobileSaveLoadData) \
|
||||
_(MobileSaveLoadParameters) \
|
||||
_(MobileSaveLoadParametersEmpty) \
|
||||
_(LiteSGD) \
|
||||
_(LiteSequentialSampler)
|
||||
|
||||
#if defined(USE_CUDA)
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
_(GraphExecutor) \
|
||||
_(ModuleConversion) \
|
||||
_(Interp) \
|
||||
_(TypeCheck) \
|
||||
_(GPU_IrGraphGenerator) \
|
||||
_(GPU_FusionDispatch) \
|
||||
_(GPU_FusionClear) \
|
||||
_(GPU_FusionCopy) \
|
||||
_(GPU_FusionMove) \
|
||||
_(GPU_FusionSimpleArith) \
|
||||
_(GPU_FusionExprEvalConstants) \
|
||||
_(GPU_FusionExprEvalBindings) \
|
||||
_(GPU_FusionExprEvalBasic) \
|
||||
_(GPU_FusionExprEvalComplex) \
|
||||
_(GPU_FusionExprEvalPostLower) \
|
||||
_(GPU_FusionSimpleTypePromote) \
|
||||
_(GPU_FusionMutator) \
|
||||
_(GPU_FusionRegister) \
|
||||
_(GPU_FusionTopoSort) \
|
||||
_(GPU_FusionTensor) \
|
||||
_(GPU_FusionFilterVals) \
|
||||
_(GPU_FusionTVSplit) \
|
||||
_(GPU_FusionTVMerge) \
|
||||
_(GPU_FusionTVReorder) \
|
||||
_(GPU_FusionEquality) \
|
||||
_(GPU_FusionParser) \
|
||||
_(GPU_FusionDependency) \
|
||||
_(GPU_FusionCodeGen) \
|
||||
_(GPU_FusionCodeGen2) \
|
||||
_(GPU_FusionSimplePWise) \
|
||||
_(GPU_FusionExecKernel) \
|
||||
_(GPU_FusionForLoop) \
|
||||
_(GPU_FusionLoopUnroll) \
|
||||
_(GPU_FusionUnaryOps) \
|
||||
_(GPU_FusionBinaryOps) \
|
||||
_(GPU_FusionTernaryOps) \
|
||||
_(GPU_FusionCompoundOps) \
|
||||
_(GPU_FusionCastOps) \
|
||||
_(GPU_FusionAdvancedComputeAt) \
|
||||
_(GPU_FusionComputeAtMultiConsumers) \
|
||||
_(GPU_FusionComputeAtCommonConsumer1) \
|
||||
_(GPU_FusionComputeAtCommonConsumer2) \
|
||||
_(GPU_FusionComputeAtCommonConsumer3) \
|
||||
_(GPU_FusionComputeAtNoCommonConsumer) \
|
||||
_(GPU_FusionScalarInputs) \
|
||||
_(GPU_FusionBCastConcretizeBasic) \
|
||||
_(GPU_FusionBCastConcretizeRfactor) \
|
||||
_(GPU_FusionProveIdEqBasic) \
|
||||
_(GPU_FusionProveIdEqRfactor) \
|
||||
_(GPU_FusionRFactorReplay) \
|
||||
_(GPU_FusionReduction) \
|
||||
_(GPU_FusionReduction2) \
|
||||
_(GPU_FusionReduction3) \
|
||||
_(GPU_FusionReduction4) \
|
||||
_(GPU_FusionReduction5) \
|
||||
_(GPU_FusionReductionTFT) \
|
||||
_(GPU_FusionSimpleBCast) \
|
||||
_(GPU_FusionComplexBCast) \
|
||||
_(GPU_FusionAdvancedIndexing) \
|
||||
_(GPU_FusionSimpleGemm) \
|
||||
_(GPU_FusionSoftmax1D) \
|
||||
_(GPU_FusionSoftmax1DNormalized) \
|
||||
_(GPU_FusionSoftmax3D) \
|
||||
_(GPU_FusionSoftmax3DNormalized) \
|
||||
_(GPU_FusionSoftmaxComputeAt) \
|
||||
_(GPU_FusionGridReduction1) \
|
||||
_(GPU_FusionGridReduction2) \
|
||||
_(GPU_FusionGridReduction3dim1) \
|
||||
_(GPU_FusionGridReduction3dim0) \
|
||||
_(GPU_FusionGridReduction4) \
|
||||
_(GPU_FusionGridReduction5) \
|
||||
_(GPU_FusionGridReduction6) \
|
||||
_(GPU_FusionNonRedAxisBind) \
|
||||
_(GPU_FusionBCastInnerDim) \
|
||||
_(GPU_FusionBCastReduce) \
|
||||
_(GPU_FusionSplitBCast) \
|
||||
_(GPU_FusionComputeAtExprOrder) \
|
||||
_(GPU_FusionZeroDimComputeAt) \
|
||||
_(GPU_FusionZeroDimBroadcast) \
|
||||
_(GPU_FusionZeroDimReduction) \
|
||||
_(GPU_FusionReductionMultiConsumer) \
|
||||
_(GPU_FusionBCastAfterReduce) \
|
||||
_(GPU_FusionReductionScheduler) \
|
||||
_(GPU_FusionReductionSchedulerMultiDimNonFastest) \
|
||||
_(GPU_FusionReductionSchedulerMultiDimFastest) \
|
||||
_(GPU_FusionReductionSchedulerDimShmoo) \
|
||||
_(GPU_FusionCacheBefore) \
|
||||
_(GPU_FusionCacheAfter) \
|
||||
_(GPU_FusionCacheIndirect) \
|
||||
_(GPU_FusionCacheBcast) \
|
||||
_(GPU_FusionCacheComplex) \
|
||||
_(GPU_FusionCacheMultiConsumer) \
|
||||
_(GPU_FusionSmem) \
|
||||
_(GPU_FusionSmemReduce) \
|
||||
_(GPU_FusionSmemBlockGemm) \
|
||||
_(GPU_FusionSmemBlockGemmCache) \
|
||||
_(GPU_FusionSmemDynamicReductionSymbolic) \
|
||||
_(GPU_FusionSmemDynamicReductionSymbolicArg) \
|
||||
_(GPU_FusionSmemDynamicPwiseMulSymbolicArgWAR) \
|
||||
_(GPU_FusionSmemDynamicTiledGemm) \
|
||||
_(GPU_FusionGlobalIntermediate) \
|
||||
_(GPU_FusionGlobalIntermediateDefaultSchedule) \
|
||||
_(GPU_FusionConstCheck) \
|
||||
_(GPU_FusionSymbolicReduction) \
|
||||
_(GPU_FusionUnrollWithAlloc) \
|
||||
_(GPU_FusionIsZeroInt) \
|
||||
_(GPU_FusionIsOneInt) \
|
||||
_(GPU_FusionComputeAtNonterminatingOutput) \
|
||||
_(GPU_FusionTraversalOrder1) \
|
||||
_(GPU_FusionTraversalOrder2) \
|
||||
_(GPU_FusionTraversalOrder3) \
|
||||
_(GPU_FusionTraversalOrder4) \
|
||||
_(GPU_FusionTraversalOrder5) \
|
||||
_(GPU_FusionTraversalOrder6) \
|
||||
_(GPU_FusionTraversalOrder7) \
|
||||
_(GPU_FusionBranches) \
|
||||
_(GPU_FusionThreadPredicate) \
|
||||
_(GPU_FusionLSTMCell) \
|
||||
_(GPU_FusionComputeAtMultiBCast) \
|
||||
_(GPU_FusionReductionHalf) \
|
||||
_(GPU_FusionInputsIdLookup)
|
||||
#else
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
_(GraphExecutor) \
|
||||
_(ModuleConversion) \
|
||||
_(Interp) \
|
||||
_(TypeCheck)
|
||||
#endif
|
||||
|
||||
#define DECLARE_JIT_TEST(name) void test##name();
|
||||
TH_FORALL_TESTS(DECLARE_JIT_TEST)
|
||||
TH_FORALL_TESTS_CUDA(DECLARE_JIT_TEST)
|
||||
#undef DECLARE_JIT_TEST
|
||||
|
||||
// This test is special since it requires prior setup in python.
|
||||
// So it is not part of the general test list (which is shared between the gtest
|
||||
// and python test runners), but is instead invoked manually by the
|
||||
// torch_python_test.cpp
|
||||
void testEvalModeForLoadedModule();
|
||||
void testSerializationInterop();
|
||||
void testTorchSaveError();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Loading…
Reference in New Issue
Block a user