mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54403 A few important points about InferenceMode behavior: 1. All tensors created in InferenceMode are inference tensors except for view ops. - view ops produce output has the same is_inference_tensor property as their input. Namely view of normal tensor inside InferenceMode produce a normal tensor, which is exactly the same as creating a view inside NoGradMode. And view of inference tensor outside InferenceMode produce inference tensor as output. 2. All ops are allowed inside InferenceMode, faster than normal mode. 3. Inference tensor cannot be saved for backward. Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D27316483 Pulled By: ailzhang fbshipit-source-id: e03248a66d42e2d43cfe7ccb61e49cc4afb2923b
537 lines
19 KiB
C++
537 lines
19 KiB
C++
#include <torch/script.h>
|
|
#include <gtest/gtest.h>
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::test;
|
|
|
|
namespace {
|
|
torch::Tensor functional_op(torch::Tensor& x) {
|
|
return x * x;
|
|
}
|
|
|
|
void inplace_op(torch::Tensor& x) {
|
|
x.mul_(1);
|
|
}
|
|
|
|
torch::Tensor view_op(torch::Tensor& x) {
|
|
return x.view({2, 3});
|
|
}
|
|
|
|
/*
|
|
Only the following combos of Autograd & InplaceOrView keys on tensors are valid:
|
|
- Autograd=true, InplaceOrView=true (normal tensor)
|
|
- Autograd=false, InplaceOrView=false (inference tensor)
|
|
Tensors created in InferenceMode are mostly inference tensors. The only exception
|
|
is that view of normal tensors created in InferenceMode still produce normal tensor.
|
|
*/
|
|
bool is_inference_tensor(torch::Tensor& x) {
|
|
c10::DispatchKeySet ks = x.key_set();
|
|
bool has_Autograd = ks.has(c10::DispatchKey::AutogradCPU);
|
|
bool has_InplaceOrView = ks.has(c10::DispatchKey::InplaceOrView);
|
|
// They must be either both true or false.
|
|
bool is_inference_tensor = !has_Autograd && !has_InplaceOrView && x.is_leaf();
|
|
return is_inference_tensor;
|
|
}
|
|
|
|
void assert_TLS_states(bool inference_mode) {
|
|
ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
|
|
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::InplaceOrView));
|
|
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
|
|
ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
|
|
ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), !inference_mode);
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestTLSState) {
|
|
assert_TLS_states(false);
|
|
{
|
|
InferenceMode guard;
|
|
assert_TLS_states(true);
|
|
{
|
|
InferenceMode guard(false);
|
|
assert_TLS_states(false);
|
|
}
|
|
assert_TLS_states(true);
|
|
}
|
|
assert_TLS_states(false);
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorCreation) {
|
|
{
|
|
InferenceMode guard;
|
|
// New tensor created through constructors are inference tensors.
|
|
torch::Tensor c = torch::ones({1, 2, 3});
|
|
ASSERT_FALSE(c.requires_grad());
|
|
ASSERT_TRUE(is_inference_tensor(c));
|
|
|
|
// requires_grad doesn't change inference tensor behavior inside InferenceMode.
|
|
torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
|
|
ASSERT_TRUE(tmp.requires_grad());
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
|
|
tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
|
|
ASSERT_FALSE(tmp.requires_grad());
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestExistingAutogradSession) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
|
|
torch::Tensor a = s.clone();
|
|
|
|
// Save `a` in an existing autograd session
|
|
torch::Tensor out = a * a;
|
|
{
|
|
InferenceMode guard;
|
|
inplace_op(a);
|
|
}
|
|
// Performing backward should trigger error since `a`'s version has been bumped.
|
|
ASSERT_THROWS_WITH(out.backward(torch::ones_like(out)),
|
|
"one of the variables needed for gradient computation has been modified by an inplace operation")
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
|
|
c10::InferenceMode guard;
|
|
for (bool requires_grad : {true, false}) {
|
|
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
|
|
torch::Tensor func_out = functional_op(c); // go through kernels: CPU
|
|
ASSERT_TRUE(is_inference_tensor(func_out));
|
|
ASSERT_FALSE(func_out.requires_grad());
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
|
|
c10::InferenceMode guard;
|
|
for (bool requires_grad : {true, false}) {
|
|
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
|
|
inplace_op(c); // go through kernels: CPU
|
|
ASSERT_TRUE(is_inference_tensor(c));
|
|
ASSERT_EQ(c.requires_grad(), requires_grad);
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
|
|
c10::InferenceMode guard;
|
|
for (bool requires_grad : {true, false}) {
|
|
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
|
|
torch::Tensor view_out = view_op(c); // go through kernels: CPU
|
|
ASSERT_TRUE(is_inference_tensor(view_out));
|
|
// Note this is different from NoGradMode but makes sense.
|
|
ASSERT_FALSE(view_out.requires_grad());
|
|
ASSERT_FALSE(view_out.is_view());
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
|
|
torch::Tensor inference_tensor;
|
|
for (bool requires_grad: {true, false}) {
|
|
{
|
|
InferenceMode guard;
|
|
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
}
|
|
|
|
// Due to issue #54614, this might run slower compared to InferenceMode since
|
|
// intermediate tensors are normal tensors, and they might dispatch to VariableType
|
|
// kernels. This is fine since users can easily fix it by moving
|
|
// it inside InferenceMode block.
|
|
torch::Tensor tmp = functional_op(inference_tensor); // go through kernels: InplaceOrView(fallthrough), CPU
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
ASSERT_FALSE(tmp.requires_grad());
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
|
|
torch::Tensor inference_tensor;
|
|
for (bool requires_grad: {true, false}) {
|
|
{
|
|
InferenceMode guard;
|
|
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
}
|
|
ASSERT_THROWS_WITH(inplace_op(inference_tensor), // go through kernels: InplaceOrView, CPU
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
|
|
torch::Tensor inference_tensor;
|
|
for (bool requires_grad: {true, false}) {
|
|
{
|
|
InferenceMode guard;
|
|
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
}
|
|
torch::Tensor out = view_op(inference_tensor); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_TRUE(is_inference_tensor(out));
|
|
ASSERT_FALSE(out.requires_grad());
|
|
ASSERT_FALSE(out.is_view());
|
|
ASSERT_TRUE(out.is_leaf());
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
|
|
{
|
|
c10::InferenceMode guard;
|
|
|
|
inplace_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(a));
|
|
ASSERT_EQ(a.requires_grad(), requires_grad);
|
|
|
|
// inplace -> inplace
|
|
inplace_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(a));
|
|
ASSERT_EQ(a.requires_grad(), requires_grad);
|
|
|
|
// inplace -> inplace -> view
|
|
torch::Tensor view_out = view_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(view_out));
|
|
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
|
|
{
|
|
c10::InferenceMode guard;
|
|
|
|
inplace_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(a));
|
|
ASSERT_EQ(a.requires_grad(), requires_grad);
|
|
}
|
|
|
|
torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, InplaceOrView(fallthrough), CPU
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
|
|
inplace_op(a); // go through kernels: VariableType, InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(a));
|
|
ASSERT_EQ(a.requires_grad(), requires_grad);
|
|
|
|
tmp = view_op(a); // go through kernels: VariableType, InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
torch::Tensor view_out, tmp;
|
|
|
|
{
|
|
c10::InferenceMode guard;
|
|
// View ops on normal tensor produce normal tensors as output.
|
|
// - For view ops it has both dispatch keys since due to the way we create
|
|
// view Tensors in alias_with_sizes_and_strides:
|
|
// ```
|
|
// auto impl = c10::make_intrusive<TensorImpl>(
|
|
// Storage(self.storage()), self.key_set(), self.dtype());
|
|
// ```
|
|
// In addition, these view output tensors are normal in the sense they
|
|
// have both Autograd and InplaceOrView keys. But they're still special
|
|
// since they'll have CreationMeta::INFERENCE_MODE. In other words they behave
|
|
// exactly the same as a view tensor created in no_grad mode.
|
|
|
|
view_out = view_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(view_out));
|
|
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
|
|
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
|
ASSERT_TRUE(view_out.is_leaf());
|
|
|
|
// view -> view
|
|
tmp = view_op(view_out); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
ASSERT_TRUE(tmp.is_leaf());
|
|
|
|
// view -> view -> inplace
|
|
inplace_op(tmp); // kernels: InplaceOrView, CPU
|
|
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
ASSERT_TRUE(tmp.is_leaf());
|
|
ASSERT_EQ(a._version(), tmp._version());
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
torch::Tensor view_out, tmp;
|
|
|
|
{
|
|
c10::InferenceMode guard;
|
|
view_out = view_op(a); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(view_out));
|
|
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
|
|
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
|
ASSERT_TRUE(view_out.is_leaf());
|
|
}
|
|
|
|
tmp = functional_op(view_out);
|
|
ASSERT_FALSE(is_inference_tensor(view_out));
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
|
|
if (requires_grad) {
|
|
ASSERT_THROWS_WITH(inplace_op(view_out), // go through kernels: VariableType, InplaceOrView, CPU
|
|
"A view was created in inference mode and is being modified inplace")
|
|
} else {
|
|
inplace_op(view_out);
|
|
}
|
|
|
|
tmp = view_op(view_out);
|
|
ASSERT_FALSE(is_inference_tensor(view_out));
|
|
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor c;
|
|
{
|
|
InferenceMode guard;
|
|
c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
}
|
|
|
|
// add(Tensor, Tensor) is safe with inference tensor since it doesn't save any variable for backward.
|
|
torch::Tensor out = c.add(s); // go through kernels: VariableType, InplaceOrView(fallthrough), CPU
|
|
ASSERT_FALSE(is_inference_tensor(out));
|
|
ASSERT_EQ(out.requires_grad(), requires_grad);
|
|
if (requires_grad) {
|
|
// leaf inference tensor with requires_grad=true can still have gradient.
|
|
// Note this behavior is different from NoGradMode which has empty grad.
|
|
out.backward(torch::ones_like(out));
|
|
assert_tensor_equal(c.grad(), torch::ones_like(c));
|
|
}
|
|
|
|
if (requires_grad) {
|
|
// mul(self, other) saves variable when requires_grad=true
|
|
ASSERT_THROWS_WITH(c.mul(s),
|
|
"Inference tensors cannot be saved for backward.");
|
|
|
|
// Inference tensor in TensorList input
|
|
std::vector<torch::Tensor> inputs = {s, c};
|
|
ASSERT_THROWS_WITH(torch::stack(inputs), // go through kernels: VariableType(ERROR)!, InplaceOrView(fallthrough), CPU
|
|
"Inference tensors cannot be saved for backward.")
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
torch::Tensor c;
|
|
{
|
|
InferenceMode guard;
|
|
c = torch::ones({1, 2, 3});
|
|
}
|
|
|
|
if (requires_grad) {
|
|
ASSERT_THROWS_WITH(a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, CPU
|
|
"Inference tensors cannot be saved for backward.");
|
|
|
|
ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), InplaceOrView, CPU
|
|
"out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
|
|
} else {
|
|
a.mul_(c);
|
|
|
|
ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, InplaceOrView(ERROR!), CPU
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor c;
|
|
{
|
|
InferenceMode guard;
|
|
c = torch::ones({1, 2, 3});
|
|
}
|
|
|
|
// view_as is a composite op which calls view() with only one tensor argument.
|
|
// So there isn't a mixed inference tensor and normal tensor inputs for view ops.
|
|
torch::Tensor tmp1 = c.view_as(s); // go through kernels: InplaceOrView, CPU
|
|
ASSERT_TRUE(is_inference_tensor(tmp1));
|
|
ASSERT_FALSE(tmp1.requires_grad());
|
|
|
|
// This is fine since it's equivalent as s.view(c.sizes()) which
|
|
// isn't a mixed input scenario.
|
|
torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, InplaceOrView, CPU
|
|
ASSERT_FALSE(is_inference_tensor(tmp2));
|
|
ASSERT_EQ(tmp2.requires_grad(), requires_grad);
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
torch::Tensor view_out;
|
|
{
|
|
InferenceMode guard;
|
|
view_out = view_op(a); // go through kernels: InplaceOrView, CPU
|
|
}
|
|
if (requires_grad) {
|
|
ASSERT_THROWS_WITH(inplace_op(view_out),
|
|
"A view was created in inference mode and is being modified inplace")
|
|
} else {
|
|
inplace_op(view_out);
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor a = s.clone();
|
|
torch::Tensor view_out;
|
|
{
|
|
InferenceMode guard;
|
|
view_out = view_op(a); // go through kernels: InplaceOrView, CPU
|
|
}
|
|
inplace_op(a);
|
|
if (requires_grad) {
|
|
ASSERT_THROWS_WITH(view_out.grad_fn(),
|
|
"A view was created in inference mode and its base or another view of its base has been modified inplace");
|
|
} else {
|
|
view_out.grad_fn();
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestCreationMetaPropagation) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
|
|
torch::Tensor b, c;
|
|
{
|
|
InferenceMode guard;
|
|
b = s.view_as(s);
|
|
}
|
|
ASSERT_THROWS_WITH(b.add_(1),
|
|
"A view was created in inference mode and is being modified inplace");
|
|
{
|
|
AutoGradMode mode(false);
|
|
c = b.view_as(b);
|
|
}
|
|
ASSERT_THROWS_WITH(c.add_(1),
|
|
"A view was created in inference mode and is being modified inplace");
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
|
|
for (bool requires_grad: {true, false}) {
|
|
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
|
torch::Tensor t;
|
|
{
|
|
InferenceMode guard;
|
|
t = torch::ones({1, 2, 3});
|
|
t.copy_(s);
|
|
ASSERT_TRUE(is_inference_tensor(t));
|
|
ASSERT_FALSE(t.requires_grad());
|
|
}
|
|
|
|
ASSERT_THROWS_WITH(t.copy_(s),
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
|
|
torch::Tensor t;
|
|
{
|
|
InferenceMode guard;
|
|
t = torch::ones({1, 2, 3});
|
|
}
|
|
t.set_requires_grad(false);
|
|
ASSERT_THROWS_WITH(t.set_requires_grad(true),
|
|
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestAccessVersionCounter) {
|
|
torch::Tensor t;
|
|
{
|
|
InferenceMode guard;
|
|
t = torch::ones({1, 2, 3});
|
|
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
|
|
"Inference tensor do not track version counter.");
|
|
t.unsafeGetTensorImpl()->bump_version();
|
|
}
|
|
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
|
|
"Inference tensor do not track version counter.");
|
|
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->bump_version(),
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed.");
|
|
// Suggested workaround
|
|
torch::Tensor c = t.clone();
|
|
uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
|
|
c.unsafeGetTensorImpl()->bump_version();
|
|
ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
|
|
torch::Tensor s = torch::ones({1, 2, 3});
|
|
torch::Tensor t;
|
|
{
|
|
InferenceMode guard;
|
|
t = torch::ones({1, 2, 3});
|
|
// Testing both copy_ from VariableTypeManual and add_ from generated code.
|
|
s.copy_(t);
|
|
s.add_(t);
|
|
t.add_(s);
|
|
t.copy_(s);
|
|
}
|
|
s.copy_(t);
|
|
s.add_(t);
|
|
ASSERT_THROWS_WITH(t.copy_(s),
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
|
|
|
ASSERT_THROWS_WITH(t.add_(s),
|
|
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
|
|
torch::Tensor s = torch::ones({3, 3, 2});
|
|
torch::Tensor t = torch::view_as_complex(s);
|
|
{
|
|
InferenceMode guard;
|
|
torch::Tensor tmp;
|
|
|
|
tmp = torch::view_as_real(t);
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
tmp = torch::view_as_complex(s);
|
|
ASSERT_FALSE(is_inference_tensor(tmp));
|
|
|
|
torch::Tensor e = torch::ones({3, 3, 2});
|
|
tmp = torch::view_as_complex(e);
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
tmp = torch::view_as_real(tmp);
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
}
|
|
}
|
|
|
|
TEST(InferenceModeTest, TestComplexViewInNormalMode) {
|
|
torch::Tensor s;
|
|
{
|
|
InferenceMode guard;
|
|
s = torch::ones({3, 3, 2});
|
|
}
|
|
torch::Tensor tmp = torch::view_as_complex(s);
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
tmp = torch::view_as_real(tmp);
|
|
ASSERT_TRUE(is_inference_tensor(tmp));
|
|
}
|