Add a warning when a tensor with requires_grad=True is converted to a scalar (#143261)

Fixes #143071

Operations performed on tensors with `requires_grad=True` such as
```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
```
and
```python
x = torch.tensor(2.0, requires_grad=True)
y = torch.pow(x, 3)
```
are both valid, and `y` stays attached to the autograd graph.
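
A quick backward pass, sketched below with nothing assumed beyond the snippet above, confirms that the gradient actually flows:
```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
y.backward()
print(x.grad)  # tensor(12.), since d/dx x**3 = 3 * x**2 at x = 2
```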

An operation using `numpy`, such as
```python
import numpy as np

x = torch.tensor(2.0, requires_grad=True)
y = np.pow(x, 3)
# > RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
```
leads to an error.
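
The error message already points at the supported escape hatch: detach the tensor before handing it to NumPy. A minimal sketch of that workaround (using `np.power`, which is available in all NumPy versions; the result is of course cut off from autograd):
```python
import numpy as np
import torch

x = torch.tensor(2.0, requires_grad=True)
y = np.power(x.detach().numpy(), 3)  # detach first, as the error message suggests
print(y)  # 8.0, a NumPy scalar with no connection to the autograd graph
```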

However, an operation that uses `math` like
```python
import math

x = torch.tensor(2.0, requires_grad=True)
y = math.pow(x, 3)
```
does not raise an error, and `y` is silently converted to a plain Python `float` that no longer carries a gradient!
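
The silent coercion is what makes this dangerous: `math.pow` calls `float()` on the tensor, so the result is an ordinary Python float with no autograd history. A small illustration, assuming nothing beyond the snippet above:
```python
import math
import torch

x = torch.tensor(2.0, requires_grad=True)
y = math.pow(x, 3)  # silently coerces x to a Python float
print(type(y))      # <class 'float'>, so there is no grad_fn and no backward()
```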

This is a [footgun](https://en.wiktionary.org/wiki/footgun#Noun) for some users, myself included, for example when training small, custom, non-neural-network models.

To guard against this undesired behavior in the future, I added a warning when a tensor with `requires_grad=True` is converted to a scalar. Now, calling `math.pow` on such a tensor emits a single warning:
```python
x = torch.tensor(2.0, requires_grad=True)
y = math.pow(x, 3)
# > UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
# Consider using tensor.detach() first.
```
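
Acting on the warning is straightforward in either direction; a minimal sketch of both options (variable names are only illustrative):
```python
import math
import torch

x = torch.tensor(2.0, requires_grad=True)

# The conversion is intentional: detach first, and no warning is emitted.
y_plain = math.pow(x.detach(), 3)

# The gradient is needed: stick to tensor ops so autograd keeps tracking.
y_tracked = x ** 3
```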

Please let me know if you have any questions 👍
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143261
Approved by: https://github.com/malfet

Co-authored-by: albanD <desmaison.alban@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Author: Joshua Hamilton
Date: 2025-04-01 00:42:46 +00:00
Committed by: PyTorch MergeBot
Parent: 49b7d0d84d
Commit: 4ce0b959ff
2 changed files with 23 additions and 0 deletions


```diff
@@ -11,11 +11,17 @@
 #include <ATen/ops/item_native.h>
 #endif
 
+#include <ATen/core/grad_mode.h>
+
 namespace at::native {
 
 Scalar item(const Tensor& self) {
   auto numel = self.sym_numel();
   TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
+  if (at::GradMode::is_enabled() && self.requires_grad()) {
+    TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
+        "Consider using tensor.detach() first.");
+  }
   if (self.is_sparse()) {
     if (self._nnz() == 0) return Scalar(0);
     if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
```
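
Because of the `at::GradMode::is_enabled()` guard, the warning only fires while autograd is actually recording, so a conversion inside `torch.no_grad()` stays silent. A hedged Python-side sketch of that behavior (this check is not part of the test added below):
```python
import math
import warnings
import torch

x = torch.tensor(2.0, requires_grad=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.no_grad():
        math.pow(x, 3)   # grad mode is disabled, so the new warning should not fire
    print(len(caught))   # expected: 0
```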


```diff
@@ -10829,6 +10829,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
     def test_bf16_supported_on_cpu(self):
         self.assertFalse(torch.cuda.is_bf16_supported())
 
+    def test_tensor_with_grad_to_scalar_warning(self) -> None:
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+
+            x = torch.tensor(2.0, requires_grad=True)
+            math.pow(x, 3)  # calling this results in a warning
+            self.assertEqual(len(w), 1)
+            self.assertTrue(issubclass(w[0].category, UserWarning))
+            self.assertIn(
+                "Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.",
+                str(w[0].message)
+            )
+
+            _ = math.pow(x, 3)  # calling it again does not result in a second warning
+            self.assertEqual(len(w), 1)
+
 # The following block extends TestTorch with negative dim wrapping tests
 # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests
```
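
Since the message surfaces as an ordinary `UserWarning`, users who would rather have the old footgun fail loudly can promote it to an exception with the standard `warnings` filters. A hedged sketch (the `message` pattern is just the prefix of the warning text; note that `TORCH_WARN_ONCE` only emits the warning on the first conversion per process):
```python
import math
import warnings
import torch

# Escalate the new warning to a hard error (matched by message prefix).
warnings.filterwarnings(
    "error",
    message="Converting a tensor with requires_grad=True to a scalar",
)

x = torch.tensor(2.0, requires_grad=True)
try:
    math.pow(x, 3)
except UserWarning as exc:
    print(f"caught: {exc}")
```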