mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT] Introduce a fake Tensor creation node for IR unit tests (#33595)
Summary: **Summary** There is often a need to create a Tensor when writing IR by hand for JIT optimisation pass unit tests. The only options for this today are real Tensor creation functions like `aten::ones`. Any test that uses these functions must also use the same default arguments as the Python/C++ API, which means that all of the tests have to be updated when the API is updated. This commit introduces a new primitive, `prim::MakeTestTensor` with schema `() -> Tensor` that should be used in unit tests instead of real Tensor creation functions. This new primitive has no public-facing API, so the maintenance burden is much lower. **Testing** This commit updates the alias analysis and DCE tests to use `prim::MakeTestTensor` instead of `aten::rand`, `aten::ones`, and `aten::zeros`. ``` $ ./bin/test_jit CUDA not available. Disabling CUDA and MultiCUDA tests Note: Google Test filter = *-*_CUDA:*_MultiCUDA [==========] Running 75 tests from 1 test case. [----------] Global test environment set-up. [----------] 75 tests from JitTest [ RUN ] JitTest.ADFormulas [ OK ] JitTest.ADFormulas (82 ms) [ RUN ] JitTest.Attributes [ OK ] JitTest.Attributes (0 ms) ... ... ... [ RUN ] JitTest.LiteInterpreterPrim [ OK ] JitTest.LiteInterpreterPrim (0 ms) [ RUN ] JitTest.LiteInterpreterLoadOrigJit [ OK ] JitTest.LiteInterpreterLoadOrigJit (2 ms) [----------] 75 tests from JitTest (150 ms total) [----------] Global test environment tear-down [==========] 75 tests from 1 test case ran. (150 ms total) [ PASSED ] 75 tests. ``` **Fixes** This pull request fixes https://github.com/pytorch/pytorch/issues/33500. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33595 Differential Revision: D20127441 Pulled By: SplitInfinity fbshipit-source-id: 56da4f23ac46335227254f606c6481718108f378
This commit is contained in:
parent
dbe850af5b
commit
390d4d6df3
|
|
@ -78,6 +78,7 @@ namespace c10 {
|
|||
_(prim, dtype) \
|
||||
_(prim, shape) \
|
||||
_(prim, requires_grad) \
|
||||
_(prim, MakeTestTensor) /* test */ \
|
||||
_(prim, AutogradAdd) \
|
||||
_(prim, GradOf) \
|
||||
_(aten, grad) \
|
||||
|
|
|
|||
|
|
@ -659,13 +659,9 @@ graph(%x : int,
|
|||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%4 : Device? = prim::Constant()
|
||||
%2 : int? = prim::Constant()
|
||||
%0 : float = prim::Constant[value=1]()
|
||||
%20 : bool = prim::Constant[value=0]()
|
||||
%a : Tensor = aten::tensor(%0, %2, %4, %20)
|
||||
%a : Tensor = prim::MakeTestTensor()
|
||||
%a_list : Tensor[] = prim::ListConstruct(%a)
|
||||
%b : Tensor = aten::tensor(%0, %2, %4, %20)
|
||||
%b : Tensor = prim::MakeTestTensor()
|
||||
%b_list : Tensor[] = prim::ListConstruct(%b)
|
||||
%13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
|
||||
return (%13)
|
||||
|
|
@ -746,19 +742,16 @@ graph():
|
|||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %1)
|
||||
%x : Tensor = aten::rand(%2, %4, %4, %8, %10)
|
||||
%x : Tensor = prim::MakeTestTensor()
|
||||
%12 : int[] = prim::ListConstruct(%0, %1)
|
||||
%y : Tensor = aten::rand(%12, %4, %4, %8, %10)
|
||||
%y : Tensor = prim::MakeTestTensor()
|
||||
%22 : int[] = prim::ListConstruct(%0, %1)
|
||||
%z : Tensor = aten::rand(%22, %4, %4, %8, %10)
|
||||
%z : Tensor = prim::MakeTestTensor()
|
||||
%32 : int[] = prim::ListConstruct(%0, %1)
|
||||
%fresh : Tensor = aten::rand(%32, %4, %4, %8, %10)
|
||||
%fresh : Tensor = prim::MakeTestTensor()
|
||||
%foo : Tensor[] = prim::ListConstruct(%x, %y)
|
||||
%43 : Tensor[] = aten::append(%foo, %z)
|
||||
return ()
|
||||
|
|
@ -791,13 +784,10 @@ graph():
|
|||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %1)
|
||||
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
|
||||
%11 : Tensor = prim::MakeTestTensor()
|
||||
%12 : Tensor[] = prim::ListConstruct(%11)
|
||||
%out : Tensor[] = custom::conservative(%12)
|
||||
%ret.2 : Tensor = aten::div(%11, %11)
|
||||
|
|
@ -826,20 +816,17 @@ graph():
|
|||
R"IR(
|
||||
graph():
|
||||
%35 : int = prim::Constant[value=1]()
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%23 : int = prim::Constant[value=0]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %1)
|
||||
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
|
||||
%11 : Tensor = prim::MakeTestTensor()
|
||||
%12 : int[] = prim::ListConstruct(%0, %1)
|
||||
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
|
||||
%21 : Tensor = prim::MakeTestTensor()
|
||||
%l : Tensor[] = prim::ListConstruct(%11, %21)
|
||||
%24 : Tensor = aten::select(%l, %23)
|
||||
%25 : int[] = prim::ListConstruct(%0, %1)
|
||||
%34 : Tensor = aten::rand(%25, %4, %4, %8, %10)
|
||||
%34 : Tensor = prim::MakeTestTensor()
|
||||
%36 : Tensor = aten::add_(%24, %34, %35)
|
||||
%37 : Tensor = uses::list(%l)
|
||||
return (%37)
|
||||
|
|
@ -868,21 +855,18 @@ graph():
|
|||
R"IR(
|
||||
graph():
|
||||
%38 : int = prim::Constant[value=1]()
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%24 : int = prim::Constant[value=0]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %1)
|
||||
%11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
|
||||
%11 : Tensor = prim::MakeTestTensor()
|
||||
%12 : int[] = prim::ListConstruct(%0, %1)
|
||||
%21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
|
||||
%21 : Tensor = prim::MakeTestTensor()
|
||||
%l : Tensor[] = prim::ListConstruct(%11, %21)
|
||||
%25 : Tensor = aten::select(%l, %24)
|
||||
%27 : Tensor = aten::select(%25, %24, %24)
|
||||
%28 : int[] = prim::ListConstruct(%0, %1)
|
||||
%37 : Tensor = aten::rand(%28, %4, %4, %8, %10)
|
||||
%37 : Tensor = prim::MakeTestTensor()
|
||||
%39 : Tensor = aten::add_(%27, %37, %38)
|
||||
%40 : Tensor = uses::list(%l)
|
||||
return (%40)
|
||||
|
|
|
|||
25
test/cpp/jit/test_base.cpp
Normal file
25
test/cpp/jit/test_base.cpp
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include "torch/csrc/jit/custom_operator.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
inline c10::OperatorOptions aliasAnalysisFromSchema() {
|
||||
c10::OperatorOptions result;
|
||||
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
|
||||
return result;
|
||||
}
|
||||
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"prim::MakeTestTensor() -> Tensor",
|
||||
[](Stack& stack) {
|
||||
push(stack, at::Tensor());
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
});
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -22,17 +22,14 @@ void testDCE() {
|
|||
graph():
|
||||
%48 : None = prim::Constant()
|
||||
%50 : bool = prim::Constant[value=1]()
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%12 : int = prim::Constant[value=1]()
|
||||
%24 : int = prim::Constant[value=3]()
|
||||
%31 : int = prim::Constant[value=0]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %0)
|
||||
%a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10)
|
||||
%a.1 : Tensor = prim::MakeTestTensor()
|
||||
%14 : int[] = prim::ListConstruct(%12)
|
||||
%tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10)
|
||||
%tot.1 : Tensor = prim::MakeTestTensor()
|
||||
%tot : Tensor = prim::Loop(%24, %50, %tot.1)
|
||||
block0(%i : int, %tot.6 : Tensor):
|
||||
%33 : Tensor = aten::select(%a.1, %31, %31)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user