mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PrivateUse1] Allow out-of-tree devices to pass check when validating csr tensor args (#149374)
Fixes #149303 Fllow-up: #147306 Because we have a dispatch key named `DispatchKey::SparseCsrPrivateUse1` for this case, we allow users to create a csr tensor on out-of-tree devices, so we should also let that pass the check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149374 Approved by: https://github.com/FFFrog, https://github.com/albanD
This commit is contained in:
parent
5590a0692c
commit
d6f1c72354
|
|
@ -276,10 +276,10 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
|
|||
// Device Invariants
|
||||
// 4.1
|
||||
TORCH_CHECK(
|
||||
values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kXPU || values.device().type() == kMeta,
|
||||
values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kXPU || values.device().type() == kMeta || values.device().type() == kPrivateUse1,
|
||||
"device type of values (",
|
||||
values.device().type(),
|
||||
") must be CPU or CUDA or XPU or Meta");
|
||||
") must be one of CPU, CUDA, XPU, Meta or PrivateUse1")
|
||||
// 4.2, 4.3, 4.4
|
||||
TORCH_CHECK(
|
||||
compressed_indices.get_device() == values.get_device(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user