Make device check throw specific error (#155085)

Fixes #122757

The fix is lost after revert and rebase previous PR https://github.com/pytorch/pytorch/pull/150750 (only change of tests are merged).

## Test Result

```python
>>> import torch
>>>
>>> model_output = torch.randn(10, 5).cuda()
>>> labels = torch.randint(0, 5, (10,)).cuda()
>>> weights = torch.randn(5)
>>>
>>> loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
>>> loss = loss_fn(input=model_output, target=labels)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward
    return F.cross_entropy(
           ^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3476, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155085
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
zeshengzong 2025-06-06 07:00:04 +00:00 committed by PyTorch MergeBot
parent 07da8a469b
commit 9d59b516e9

View File

@ -5,9 +5,8 @@ namespace c10::impl {
void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
TORCH_CHECK(false,
"Expected all tensors to be on the same device, but "
"found at least two devices, ", common_device, " and ", tensor.device(), "! "
"(when checking argument for argument ", argName, " in method ", methodName, ")");
"Expected all tensors to be on the same device, but got ", argName, " is on ", tensor.device(),
", different from other tensors on ", common_device, " (when checking argument in method ", methodName, ")");
}
} // namespace c10::impl