Add a warning when a tensor with requires_grad=True is converted to a scalar (#143261)
Fixes #143071

Operations performed on tensors with `requires_grad=True`, such as

```python
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
```

and

```python
x = torch.tensor(2.0, requires_grad=True)
y = torch.pow(x, 3)
```

are valid. By contrast, an operation that goes through `numpy`, like

```python
import numpy as np
x = torch.tensor(2.0, requires_grad=True)
y = np.pow(x, 3)
# > RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
```

raises an error. However, an operation that uses `math`, like

```python
import math
x = torch.tensor(2.0, requires_grad=True)
y = math.pow(x, 3)
```

does not raise at all, and `y` is no longer a tensor that carries a gradient! This is a [footgun](https://en.wiktionary.org/wiki/footgun#Noun) for some users, like myself when training small, custom, non-neural-network models.

To prevent this undesired behavior, I added a warning when a tensor with `requires_grad=True` is converted to a scalar. Now, calling `math.pow` on such a tensor emits a single warning:

```python
x = torch.tensor(2.0, requires_grad=True)
y = math.pow(x, 3)
# > UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
# Consider using tensor.detach() first.
```

Please let me know if you have any questions 👍

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143261
Approved by: https://github.com/malfet

Co-authored-by: albanD <desmaison.alban@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
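For anyone who hits the new warning, here is a minimal sketch (not part of this patch) of the two obvious ways to respond: keep the computation in `torch` so the gradient survives, or `detach()` explicitly when a plain Python number is really what you want.

```python
import math
import torch

x = torch.tensor(2.0, requires_grad=True)

# Option 1: keep the computation in torch so autograd still tracks it.
y = x ** 3
y.backward()
print(x.grad)  # tensor(12.)

# Option 2: if a plain Python float is really what you want, detach first;
# the break from autograd is then explicit and no warning is emitted.
z = math.pow(x.detach(), 3)  # z is an ordinary float, no gradient flows
```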
This commit is contained in:
parent 49b7d0d84d
commit 4ce0b959ff
@@ -11,11 +11,17 @@
 #include <ATen/ops/item_native.h>
 #endif

+#include <ATen/core/grad_mode.h>
+
 namespace at::native {

 Scalar item(const Tensor& self) {
   auto numel = self.sym_numel();
   TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
+  if (at::GradMode::is_enabled() && self.requires_grad()) {
+    TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
+                    "Consider using tensor.detach() first.");
+  }
   if (self.is_sparse()) {
     if (self._nnz() == 0) return Scalar(0);
     if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
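As far as I can tell, the check lives in `item()` because the implicit scalar conversions all funnel through it: `float(x)`, stdlib functions such as `math.pow` (which coerce their argument with `float()`), and the explicit `x.item()` call. A small sketch of conversions that should now trigger the warning, assuming a build that includes this patch:

```python
import math
import warnings
import torch

x = torch.tensor(2.0, requires_grad=True)

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    float(x)       # Tensor.__float__ ends up in item()
    math.sqrt(x)   # math functions coerce their argument with float()
    x.item()       # the explicit conversion takes the same path

# TORCH_WARN_ONCE deduplicates, so a fresh patched process should record one warning.
print(len(w))
```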
@@ -10829,6 +10829,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
     def test_bf16_supported_on_cpu(self):
         self.assertFalse(torch.cuda.is_bf16_supported())

+    def test_tensor_with_grad_to_scalar_warning(self) -> None:
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+
+            x = torch.tensor(2.0, requires_grad=True)
+            math.pow(x, 3) # calling this results in a warning
+
+            self.assertEqual(len(w), 1)
+            self.assertTrue(issubclass(w[0].category, UserWarning))
+            self.assertIn(
+                "Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.",
+                str(w[0].message)
+            )
+
+            _ = math.pow(x, 3) # calling it again does not result in a second warning
+            self.assertEqual(len(w), 1)

     # The following block extends TestTorch with negative dim wrapping tests
     # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests
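One note on the test's design: because the C++ side uses `TORCH_WARN_ONCE`, the second `math.pow` call is deduplicated regardless of Python's warnings filter, which is why `simplefilter("always")` still yields exactly one recorded warning. If a warning on every conversion is wanted while debugging, my understanding is that `torch.set_warn_always(True)` disables the warn-once deduplication; a rough sketch, again assuming a build with this patch:

```python
import math
import torch

torch.set_warn_always(True)   # make TORCH_WARN_ONCE behave like an every-time warning

x = torch.tensor(2.0, requires_grad=True)
math.pow(x, 3)   # warns
math.pow(x, 3)   # warns again, since warn-once deduplication is disabled

torch.set_warn_always(False)  # restore the default warn-once behavior
```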