Make device check error message more descriptive (#150750)
Fixes #122757

## Test Result

```python
import torch

model_output = torch.randn(10, 5).cuda()
labels = torch.randint(0, 5, (10,)).cuda()
weights = torch.randn(5)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
print(loss)
```

```
Traceback (most recent call last):
  File "/home/zong/code/pytorch/../loss2.py", line 17, in <module>
    loss = loss_fn(input=model_output, target=labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward
    return F.cross_entropy(
           ^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3494, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150750
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent 1d7728056b
commit 8253970a1f
```diff
@@ -5,9 +5,8 @@ namespace c10::impl {
 
 void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
   TORCH_CHECK(false,
-    "Expected all tensors to be on the same device, but "
-    "found at least two devices, ", common_device, " and ", tensor.device(), "! "
-    "(when checking argument for argument ", argName, " in method ", methodName, ")");
+    "Expected all tensors to be on the same device, but got ", argName, " is on ", tensor.device(),
+    ", different from other tensors on ", common_device, " (when checking argument in method ", methodName, ")");
 }
 
 } // namespace c10::impl
```
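For context only (not part of this commit): a minimal sketch of how the failing call from the Test Result above would be resolved on the caller side, assuming a CUDA device is available. The `weight` tensor simply needs to live on the same device as the input and target, which is exactly what the new error message points to.

```python
import torch

model_output = torch.randn(10, 5).cuda()
labels = torch.randint(0, 5, (10,)).cuda()

# Move the class weights to the same device as the logits and targets,
# so the device check in nll_loss no longer fires.
weights = torch.randn(5).cuda()

loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
print(loss)
```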