[JIT] Introduce a fake Tensor creation node for IR unit tests (#33595)

Summary:
**Summary**
There is often a need to create a Tensor when writing IR by hand for JIT
optimisation pass unit tests. The only options for this today are real
Tensor creation functions like `aten::ones`. Any test that uses these functions
must also use the same default arguments as the Python/C++ API, which means
that all of the tests have to be updated when the API is updated. This commit
introduces a new primitive, `prim::MakeTestTensor` with schema `() -> Tensor` that
should be used in unit tests instead of real Tensor creation functions. This new
primitive has no public-facing API, so the maintenance burden is much lower.

**Testing**
This commit updates the alias analysis and DCE tests to use `prim::MakeTestTensor` instead of
`aten::rand`, `aten::ones`, and `aten::zeros`.

```
$ ./bin/test_jit
CUDA not available. Disabling CUDA and MultiCUDA tests
Note: Google Test filter = *-*_CUDA:*_MultiCUDA
[==========] Running 75 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 75 tests from JitTest
[ RUN      ] JitTest.ADFormulas
[       OK ] JitTest.ADFormulas (82 ms)
[ RUN      ] JitTest.Attributes
[       OK ] JitTest.Attributes (0 ms)
...
...
...
[ RUN      ] JitTest.LiteInterpreterPrim
[       OK ] JitTest.LiteInterpreterPrim (0 ms)
[ RUN      ] JitTest.LiteInterpreterLoadOrigJit
[       OK ] JitTest.LiteInterpreterLoadOrigJit (2 ms)
[----------] 75 tests from JitTest (150 ms total)

[----------] Global test environment tear-down
[==========] 75 tests from 1 test case ran. (150 ms total)
[  PASSED  ] 75 tests.
```

**Fixes**
This pull request fixes https://github.com/pytorch/pytorch/issues/33500.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33595

Differential Revision: D20127441

Pulled By: SplitInfinity

fbshipit-source-id: 56da4f23ac46335227254f606c6481718108f378
This commit is contained in:
Meghan Lele 2020-02-27 12:29:00 -08:00 committed by Facebook Github Bot
parent dbe850af5b
commit 390d4d6df3
4 changed files with 41 additions and 34 deletions

View File

@ -78,6 +78,7 @@ namespace c10 {
_(prim, dtype) \
_(prim, shape) \
_(prim, requires_grad) \
_(prim, MakeTestTensor) /* test */ \
_(prim, AutogradAdd) \
_(prim, GradOf) \
_(aten, grad) \

View File

@ -659,13 +659,9 @@ graph(%x : int,
script::parseIR(
R"IR(
graph():
%4 : Device? = prim::Constant()
%2 : int? = prim::Constant()
%0 : float = prim::Constant[value=1]()
%20 : bool = prim::Constant[value=0]()
%a : Tensor = aten::tensor(%0, %2, %4, %20)
%a : Tensor = prim::MakeTestTensor()
%a_list : Tensor[] = prim::ListConstruct(%a)
%b : Tensor = aten::tensor(%0, %2, %4, %20)
%b : Tensor = prim::MakeTestTensor()
%b_list : Tensor[] = prim::ListConstruct(%b)
%13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
return (%13)
@ -746,19 +742,16 @@ graph():
script::parseIR(
R"IR(
graph():
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
%x : Tensor = aten::rand(%2, %4, %4, %8, %10)
%x : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%y : Tensor = aten::rand(%12, %4, %4, %8, %10)
%y : Tensor = prim::MakeTestTensor()
%22 : int[] = prim::ListConstruct(%0, %1)
%z : Tensor = aten::rand(%22, %4, %4, %8, %10)
%z : Tensor = prim::MakeTestTensor()
%32 : int[] = prim::ListConstruct(%0, %1)
%fresh : Tensor = aten::rand(%32, %4, %4, %8, %10)
%fresh : Tensor = prim::MakeTestTensor()
%foo : Tensor[] = prim::ListConstruct(%x, %y)
%43 : Tensor[] = aten::append(%foo, %z)
return ()
@ -791,13 +784,10 @@ graph():
script::parseIR(
R"IR(
graph():
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
%11 : Tensor = prim::MakeTestTensor()
%12 : Tensor[] = prim::ListConstruct(%11)
%out : Tensor[] = custom::conservative(%12)
%ret.2 : Tensor = aten::div(%11, %11)
@ -826,20 +816,17 @@ graph():
R"IR(
graph():
%35 : int = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%23 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
%11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
%21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%24 : Tensor = aten::select(%l, %23)
%25 : int[] = prim::ListConstruct(%0, %1)
%34 : Tensor = aten::rand(%25, %4, %4, %8, %10)
%34 : Tensor = prim::MakeTestTensor()
%36 : Tensor = aten::add_(%24, %34, %35)
%37 : Tensor = uses::list(%l)
return (%37)
@ -868,21 +855,18 @@ graph():
R"IR(
graph():
%38 : int = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%24 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
%11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
%21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%25 : Tensor = aten::select(%l, %24)
%27 : Tensor = aten::select(%25, %24, %24)
%28 : int[] = prim::ListConstruct(%0, %1)
%37 : Tensor = aten::rand(%28, %4, %4, %8, %10)
%37 : Tensor = prim::MakeTestTensor()
%39 : Tensor = aten::add_(%27, %37, %38)
%40 : Tensor = uses::list(%l)
return (%40)

View File

@ -0,0 +1,25 @@
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include "torch/csrc/jit/custom_operator.h"
namespace torch {
namespace jit {
inline c10::OperatorOptions aliasAnalysisFromSchema() {
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
}
RegisterOperators reg({
Operator(
"prim::MakeTestTensor() -> Tensor",
[](Stack& stack) {
push(stack, at::Tensor());
return 0;
},
aliasAnalysisFromSchema()),
});
} // namespace jit
} // namespace torch

View File

@ -22,17 +22,14 @@ void testDCE() {
graph():
%48 : None = prim::Constant()
%50 : bool = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%12 : int = prim::Constant[value=1]()
%24 : int = prim::Constant[value=3]()
%31 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %0)
%a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10)
%a.1 : Tensor = prim::MakeTestTensor()
%14 : int[] = prim::ListConstruct(%12)
%tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10)
%tot.1 : Tensor = prim::MakeTestTensor()
%tot : Tensor = prim::Loop(%24, %50, %tot.1)
block0(%i : int, %tot.6 : Tensor):
%33 : Tensor = aten::select(%a.1, %31, %31)