#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

namespace {

torch::Tensor functional_op(torch::Tensor& x) {
  return x * x;
}

void inplace_op(torch::Tensor& x) {
  x.mul_(1);
}

torch::Tensor view_op(torch::Tensor& x) {
  return x.view({2, 3});
}

/*
  Only the following combos of Autograd & ADInplaceOrView keys on tensors are
  valid:
    - Autograd=true, ADInplaceOrView=true (normal tensor)
    - Autograd=false, ADInplaceOrView=false (inference tensor)

  Tensors created in InferenceMode are mostly inference tensors. The only
  exception is that a view of a normal tensor created in InferenceMode still
  produces a normal tensor.
*/
bool is_inference_tensor(torch::Tensor& x) {
  c10::DispatchKeySet ks = x.key_set();
  bool has_Autograd = ks.has(c10::DispatchKey::AutogradCPU);
  bool has_ADInplaceOrView = ks.has(c10::DispatchKey::ADInplaceOrView);
  // They must be either both true or both false.
  bool is_inference_tensor =
      !has_Autograd && !has_ADInplaceOrView && x.is_leaf();
  return is_inference_tensor;
}

void assert_TLS_states(bool inference_mode) {
  ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
  ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
      c10::DispatchKey::ADInplaceOrView));
  ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
      c10::autograd_dispatch_keyset));
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
      inference_mode);
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_key_included(
          c10::DispatchKey::ADInplaceOrView),
      !inference_mode);
  ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
}

} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestTLSState) {
  assert_TLS_states(false);
  {
    InferenceMode guard;
    assert_TLS_states(true);
    {
      InferenceMode guard(false);
      assert_TLS_states(false);
    }
    assert_TLS_states(true);
  }
  assert_TLS_states(false);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorCreation) {
  {
    InferenceMode guard;

    // New tensors created through constructors are inference tensors.
    torch::Tensor c = torch::ones({1, 2, 3});
    ASSERT_FALSE(c.requires_grad());
    ASSERT_TRUE(is_inference_tensor(c));

    // requires_grad doesn't change inference tensor behavior inside
    // InferenceMode.
    torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
    ASSERT_TRUE(tmp.requires_grad());
    ASSERT_TRUE(is_inference_tensor(tmp));

    tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
    ASSERT_FALSE(tmp.requires_grad());
    ASSERT_TRUE(is_inference_tensor(tmp));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestExistingAutogradSession) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  // Save `a` in an existing autograd session.
  torch::Tensor out = a * a;
  {
    InferenceMode guard;
    inplace_op(a);
  }
  // Performing backward should trigger an error since `a`'s version has been
  // bumped.
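  // (Although the inplace update ran under InferenceMode, `a` is a normal
  // tensor, so its version counter is still bumped; the saved-variable check
  // in the autograd engine then fails exactly as it would outside the guard.)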
  ASSERT_THROWS_WITH(
      out.backward(torch::ones_like(out)),
      "one of the variables needed for gradient computation has been modified by an inplace operation");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor func_out = functional_op(c); // go through kernels: CPU
    ASSERT_TRUE(is_inference_tensor(func_out));
    ASSERT_FALSE(func_out.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    inplace_op(c); // go through kernels: CPU
    ASSERT_TRUE(is_inference_tensor(c));
    ASSERT_EQ(c.requires_grad(), requires_grad);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor view_out = view_op(c); // go through kernels: CPU
    ASSERT_TRUE(is_inference_tensor(view_out));
    // Note this is different from NoGradMode but makes sense.
    ASSERT_FALSE(view_out.requires_grad());
    ASSERT_FALSE(view_out.is_view());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // Due to issue #54614, this might run slower compared to InferenceMode
    // since intermediate tensors are normal tensors, and they might dispatch
    // to VariableType kernels. This is fine since users can easily fix it by
    // moving the computation inside an InferenceMode block.
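    // Outside InferenceMode the ADInplaceOrView key is back in the TLS
    // included set (see assert_TLS_states above), so the functional op's
    // output is created with the full default key set, i.e. as a normal
    // tensor, even though its input is an inference tensor.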
    torch::Tensor tmp = functional_op(
        inference_tensor); // go through kernels: ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(is_inference_tensor(tmp));
    ASSERT_FALSE(tmp.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    ASSERT_THROWS_WITH(
        inplace_op(inference_tensor), // go through kernels: ADInplaceOrView, CPU
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    torch::Tensor out = view_op(
        inference_tensor); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(is_inference_tensor(out));
    ASSERT_FALSE(out.requires_grad());
    ASSERT_FALSE(out.is_view());
    ASSERT_TRUE(out.is_leaf());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(a));
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace
      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(a));
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace -> view
      torch::Tensor view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(view_out));
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(a));
      ASSERT_EQ(a.requires_grad(), requires_grad);
    }

    torch::Tensor tmp = functional_op(
        a); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(is_inference_tensor(tmp));
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(is_inference_tensor(a));
    ASSERT_EQ(a.requires_grad(), requires_grad);

    tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(is_inference_tensor(tmp));
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      // View ops on normal tensors produce normal tensors as output.
      // - The view output has both dispatch keys due to the way we create
      //   view Tensors in alias_with_sizes_and_strides:
      //   ```
      //   auto impl = c10::make_intrusive<TensorImpl>(
      //       Storage(self.storage()), self.key_set(), self.dtype());
      //   ```
      //   These view output tensors are normal in the sense that they have
      //   both the Autograd and ADInplaceOrView keys. But they're still
      //   special since they carry CreationMeta::INFERENCE_MODE; in other
      //   words, they behave exactly like view tensors created in no_grad
      //   mode.
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(view_out));
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());

      // view -> view
      tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(tmp));
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());

      // view -> view -> inplace
      inplace_op(tmp); // go through kernels: ADInplaceOrView, CPU
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_FALSE(is_inference_tensor(tmp));
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());
      ASSERT_EQ(a._version(), tmp._version());
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(is_inference_tensor(view_out));
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());
    }

    tmp = functional_op(view_out);
    ASSERT_FALSE(is_inference_tensor(view_out));
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out), // go through kernels: VariableType, ADInplaceOrView, CPU
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }

    tmp = view_op(view_out);
    ASSERT_FALSE(is_inference_tensor(view_out));
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // add(Tensor, Tensor) is safe with an inference tensor since it doesn't
    // save any variable for backward.
    torch::Tensor out = c.add(
        s); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(is_inference_tensor(out));
    ASSERT_EQ(out.requires_grad(), requires_grad);

    if (requires_grad) {
      // A leaf inference tensor with requires_grad=true can still have a
      // gradient. Note this behavior is different from NoGradMode, which
      // leaves grad empty.
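      // (add's backward only needs grad_output, not saved inputs, so the
      // engine can still produce a gradient for the inference tensor `c`.)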
      out.backward(torch::ones_like(out));
      assert_tensor_equal(c.grad(), torch::ones_like(c));
    }

    if (requires_grad) {
      // mul(self, other) saves variable when requires_grad=true.
      ASSERT_THROWS_WITH(
          c.mul(s), "Inference tensors cannot be saved for backward.");

      // Inference tensor in a TensorList input.
      std::vector<torch::Tensor> inputs = {s, c};
      ASSERT_THROWS_WITH(
          torch::stack(inputs), // go through kernels: VariableType(ERROR!), ADInplaceOrView(fallthrough), CPU
          "Inference tensors cannot be saved for backward.");
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          a.mul_(c), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
          "Inference tensors cannot be saved for backward.");

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
          "out=... arguments don't support automatic differentiation, but one of the arguments requires grad");
    } else {
      a.mul_(c);

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU
          "Inplace update to inference tensor outside InferenceMode is not allowed");
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    // view_as is a composite op which calls view() with only one tensor
    // argument, so there's no mixed inference-and-normal-tensor input
    // scenario for view ops.
    torch::Tensor tmp1 = c.view_as(s); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(is_inference_tensor(tmp1));
    ASSERT_FALSE(tmp1.requires_grad());

    // This is fine since it's equivalent to s.view(c.sizes()), which
    // isn't a mixed-input scenario.
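    // Only `s` (a normal tensor) actually flows through the view kernel;
    // `c` merely supplies the target sizes.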
    torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(is_inference_tensor(tmp2));
    ASSERT_EQ(tmp2.requires_grad(), requires_grad);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out),
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    inplace_op(a);
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          view_out.grad_fn(),
          "A view was created in inference mode and its base or another view of its base has been modified inplace");
    } else {
      view_out.grad_fn();
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestCreationMetaPropagation) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor b, c;
  {
    InferenceMode guard;
    b = s.view_as(s);
  }
  ASSERT_THROWS_WITH(
      b.add_(1),
      "A view was created in inference mode and is being modified inplace");
  {
    AutoGradMode mode(false);
    c = b.view_as(b);
  }
  ASSERT_THROWS_WITH(
      c.add_(1),
      "A view was created in inference mode and is being modified inplace");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor t;
    {
      InferenceMode guard;
      t = torch::ones({1, 2, 3});
      t.copy_(s);
      ASSERT_TRUE(is_inference_tensor(t));
      ASSERT_FALSE(t.requires_grad());
    }

    ASSERT_THROWS_WITH(
        t.copy_(s),
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
  }
  t.set_requires_grad(false);
  ASSERT_THROWS_WITH(
      t.set_requires_grad(true),
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestAccessVersionCounter) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    ASSERT_THROWS_WITH(
        t.unsafeGetTensorImpl()->version_counter().current_version(),
        "Inference tensors do not track version counter.");
    t.unsafeGetTensorImpl()->bump_version();
  }
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->version_counter().current_version(),
      "Inference tensors do not track version counter.");
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->bump_version(),
      "Inplace update to inference tensor outside InferenceMode is not allowed.");

  // Suggested workaround: clone the inference tensor first.
  torch::Tensor c = t.clone();
  uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
  c.unsafeGetTensorImpl()->bump_version();
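  // The clone is a normal tensor with a functional version counter, so both
  // reading and bumping the version work again.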
  ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
  torch::Tensor s = torch::ones({1, 2, 3});
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});

    // Testing both copy_ from VariableTypeManual and add_ from generated code.
    s.copy_(t);
    s.add_(t);
    t.add_(s);
    t.copy_(s);
  }

  s.copy_(t);
  s.add_(t);

  ASSERT_THROWS_WITH(
      t.copy_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
  ASSERT_THROWS_WITH(
      t.add_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
  torch::Tensor s = torch::ones({3, 3, 2});
  torch::Tensor t = torch::view_as_complex(s);
  {
    InferenceMode guard;
    torch::Tensor tmp;

    tmp = torch::view_as_real(t);
    ASSERT_FALSE(is_inference_tensor(tmp));
    tmp = torch::view_as_complex(s);
    ASSERT_FALSE(is_inference_tensor(tmp));

    torch::Tensor e = torch::ones({3, 3, 2});
    tmp = torch::view_as_complex(e);
    ASSERT_TRUE(is_inference_tensor(tmp));
    tmp = torch::view_as_real(tmp);
    ASSERT_TRUE(is_inference_tensor(tmp));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestComplexViewInNormalMode) {
  torch::Tensor s;
  {
    InferenceMode guard;
    s = torch::ones({3, 3, 2});
  }
  torch::Tensor tmp = torch::view_as_complex(s);
  ASSERT_TRUE(is_inference_tensor(tmp));
  tmp = torch::view_as_real(tmp);
  ASSERT_TRUE(is_inference_tensor(tmp));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestCustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  {
    InferenceMode guard;
    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
    auto var2 = var1.clone();
    int mul = 2;
    // If InferenceMode didn't set NoGradGuard automatically, this line
    // would error out when trying to save `var1` and `var2` for backward.
    auto y = MyFunction::apply(var1, mul, var2);
    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
    assert_tensor_equal(y, expected);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
  bool prev = c10::Warning::get_warnAlways();
  c10::Warning::set_warnAlways(true);
  {
    WarningCapture warnings;
    at::AutoNonVariableTypeMode guard;
    ASSERT_TRUE(
        warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
        std::string::npos);
  }
  c10::Warning::set_warnAlways(prev);
}
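
// A minimal usage sketch, not part of the original suite: it restates the
// basic contract exercised above (c10::InferenceMode is an RAII guard, and
// ops run under it create inference tensors). The test name and tensor names
// here are illustrative only.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InferenceModeTest, TestBasicUsageSketch) {
  torch::Tensor input = torch::ones({2, 3}); // normal tensor
  torch::Tensor output;
  {
    c10::InferenceMode guard;
    // Functional op under the guard: no grad_fn and no version counter are
    // set up, so the result is an inference tensor.
    output = functional_op(input);
  }
  ASSERT_TRUE(is_inference_tensor(output));
  ASSERT_FALSE(output.requires_grad());
}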