[PrivateUse1] Allow out-of-tree devices to pass check when validating csr tensor args (#149374)

Fixes #149303
Fllow-up: #147306

Because we have a dispatch key named `DispatchKey::SparseCsrPrivateUse1` for this case, we allow users to create a csr tensor on out-of-tree devices, so we should also let that pass the check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149374
Approved by: https://github.com/FFFrog, https://github.com/albanD
This commit is contained in:
Yuanhao Ji 2025-04-11 09:05:16 +00:00 committed by PyTorch MergeBot
parent 5590a0692c
commit d6f1c72354

View File

@ -276,10 +276,10 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
// Device Invariants
// 4.1
TORCH_CHECK(
values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kXPU || values.device().type() == kMeta,
values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kXPU || values.device().type() == kMeta || values.device().type() == kPrivateUse1,
"device type of values (",
values.device().type(),
") must be CPU or CUDA or XPU or Meta");
") must be one of CPU, CUDA, XPU, Meta or PrivateUse1")
// 4.2, 4.3, 4.4
TORCH_CHECK(
compressed_indices.get_device() == values.get_device(),