mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make device check throw specific error (#155085)
Fixes #122757 The fix is lost after revert and rebase previous PR https://github.com/pytorch/pytorch/pull/150750 (only change of tests are merged). ## Test Result ```python >>> import torch >>> >>> model_output = torch.randn(10, 5).cuda() >>> labels = torch.randint(0, 5, (10,)).cuda() >>> weights = torch.randn(5) >>> >>> loss_fn = torch.nn.CrossEntropyLoss(weight=weights) >>> loss = loss_fn(input=model_output, target=labels) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward return F.cross_entropy( ^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/nn/functional.py", line 3476, in cross_entropy return torch._C._nn.cross_entropy_loss( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155085 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent
07da8a469b
commit
9d59b516e9
|
|
@ -5,9 +5,8 @@ namespace c10::impl {
|
|||
|
||||
void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
|
||||
TORCH_CHECK(false,
|
||||
"Expected all tensors to be on the same device, but "
|
||||
"found at least two devices, ", common_device, " and ", tensor.device(), "! "
|
||||
"(when checking argument for argument ", argName, " in method ", methodName, ")");
|
||||
"Expected all tensors to be on the same device, but got ", argName, " is on ", tensor.device(),
|
||||
", different from other tensors on ", common_device, " (when checking argument in method ", methodName, ")");
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user