Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary:

**Summary**

There is often a need to create a Tensor when writing IR by hand for JIT optimisation pass unit tests. Today the only options are real Tensor creation functions such as `aten::ones`. Any test that uses these functions must also pass the same default arguments as the Python/C++ API, so every such test has to be updated whenever the API changes.

This commit introduces a new primitive, `prim::MakeTestTensor`, with schema `() -> Tensor`, to be used in unit tests instead of real Tensor creation functions. Because this primitive has no public-facing API, the maintenance burden is much lower.

**Testing**

This commit updates the alias analysis and DCE tests to use `prim::MakeTestTensor` instead of `aten::rand`, `aten::ones`, and `aten::zeros`.

```
$ ./bin/test_jit
CUDA not available. Disabling CUDA and MultiCUDA tests
Note: Google Test filter = *-*_CUDA:*_MultiCUDA
[==========] Running 75 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 75 tests from JitTest
[ RUN ] JitTest.ADFormulas
[ OK ] JitTest.ADFormulas (82 ms)
[ RUN ] JitTest.Attributes
[ OK ] JitTest.Attributes (0 ms)
...
...
...
[ RUN ] JitTest.LiteInterpreterPrim
[ OK ] JitTest.LiteInterpreterPrim (0 ms)
[ RUN ] JitTest.LiteInterpreterLoadOrigJit
[ OK ] JitTest.LiteInterpreterLoadOrigJit (2 ms)
[----------] 75 tests from JitTest (150 ms total)
[----------] Global test environment tear-down
[==========] 75 tests from 1 test case ran. (150 ms total)
[ PASSED ] 75 tests.
```

**Fixes**

This pull request fixes https://github.com/pytorch/pytorch/issues/33500.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33595
Differential Revision: D20127441
Pulled By: SplitInfinity
fbshipit-source-id: 56da4f23ac46335227254f606c6481718108f378
```cpp
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>

#include "torch/csrc/jit/custom_operator.h"

namespace torch {
namespace jit {
// Operator options that tell alias analysis to derive the operator's aliasing
// and mutation behavior from its schema string instead of treating it
// conservatively.
inline c10::OperatorOptions aliasAnalysisFromSchema() {
  c10::OperatorOptions result;
  result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
  return result;
}

RegisterOperators reg({
    Operator(
        // Test-only primitive: takes no inputs and produces one Tensor.
        "prim::MakeTestTensor() -> Tensor",
        [](Stack& stack) {
          // Hand-written test IR only needs a Tensor-typed value, not real
          // data, so an undefined (default-constructed) tensor is pushed.
          push(stack, at::Tensor());
          return 0;
        },
        aliasAnalysisFromSchema()),
});

} // namespace jit
} // namespace torch
```
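The registration above pairs a schema string with a function that manipulates the interpreter stack. The standalone toy model below sketches that pattern together with the maintenance argument from the summary. It does not use PyTorch: `ToyValue`, `ToyOp`, `registry`, `run`, `registerToyFactory`, `demoDefaultArgumentDrift`, and the `toy::ones` operator with its input counts are all hypothetical, and it assumes, as a simplification, that a hand-written IR call resolves only when its input count matches the registered schema exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins, not PyTorch types: a stack slot is either a
// placeholder input or a tensor that may be undefined, like at::Tensor().
struct ToyValue {
  bool is_tensor = false;
  bool defined = false;
};
using ToyStack = std::vector<ToyValue>;

// Hypothetical registry entry: declared input count plus an int(Stack&)
// body, loosely modeled on the lambda passed to Operator above.
struct ToyOp {
  std::size_t num_inputs = 0;
  std::function<int(ToyStack&)> fn;
};

std::map<std::string, ToyOp>& registry() {
  static std::map<std::string, ToyOp> ops;
  return ops;
}

// Runs an op the way a hand-written IR node would: the call site spells out
// every input, so its count must match the registered schema exactly.
ToyValue run(const std::string& name, std::size_t inputs_written) {
  auto it = registry().find(name);
  if (it == registry().end() || it->second.num_inputs != inputs_written) {
    throw std::runtime_error("no schema matches " + name);
  }
  ToyStack stack(inputs_written);  // placeholder inputs
  it->second.fn(stack);            // op pops its inputs, pushes one output
  return stack.back();
}

// Registers a toy tensor factory with a given (hypothetical) input count.
void registerToyFactory(const std::string& name, std::size_t num_inputs) {
  registry()[name] = ToyOp{num_inputs, [num_inputs](ToyStack& stack) {
    stack.resize(stack.size() - num_inputs);  // pop the inputs
    stack.push_back(ToyValue{true, true});     // push a defined tensor
    return 0;
  }};
}

// Walks through the maintenance problem described in the summary.
bool demoDefaultArgumentDrift() {
  // Mirrors "prim::MakeTestTensor() -> Tensor": no inputs, and the output is
  // an undefined tensor, as with push(stack, at::Tensor()).
  registry()["prim::MakeTestTensor"] = ToyOp{0, [](ToyStack& stack) {
    stack.push_back(ToyValue{true, false});
    return 0;
  }};
  registerToyFactory("toy::ones", 5);  // arity chosen for illustration only

  // Tests written today: both calls resolve.
  bool ok = run("toy::ones", 5).defined &&
            run("prim::MakeTestTensor", 0).is_tensor;

  // The public factory gains one more defaulted parameter.
  registerToyFactory("toy::ones", 6);
  bool stale_call_failed = false;
  try {
    run("toy::ones", 5);  // hand-written IR from before the API change
  } catch (const std::runtime_error&) {
    stale_call_failed = true;
  }
  // The zero-input test primitive is unaffected by the change.
  return ok && stale_call_failed && run("prim::MakeTestTensor", 0).is_tensor;
}
```

In this sketch, `assert(demoDefaultArgumentDrift());` holds: the stale five-input call to `toy::ones` stops resolving once the toy signature grows, while `prim::MakeTestTensor` keeps working because its schema has no inputs to drift. Real JIT schema matching also considers argument types, which this model ignores.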