diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index 8dc1fd05452..0ec3c97a2da 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -56,9 +56,11 @@ SparseCsrTensorImpl::SparseCsrTensorImpl( TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU) || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA) + || (key_set.has(DispatchKey::SparseCsrXPU) && device().type() == kXPU) || (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta) || (key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kMeta) // fake tensor || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kMeta) // fake tensor + || (key_set.has(DispatchKey::SparseCsrXPU) && device().type() == kMeta) // fake tensor || (key_set.has(DispatchKey::SparseCsrPrivateUse1) && device().type() == kPrivateUse1)), "Inconsistent key_set (=", key_set, ") and device (=", device(), ")"); diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index ca5447c6a80..13f54749f8d 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -360,6 +360,9 @@ static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) { case kCUDA: dispatch_key = DispatchKey::SparseCsrCUDA; break; + case kXPU: + dispatch_key = DispatchKey::SparseCsrXPU; + break; case kMeta: dispatch_key = DispatchKey::SparseCsrMeta; break; diff --git a/torchgen/model.py b/torchgen/model.py index f5fd1a5a90e..54bb8087dc0 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -283,6 +283,7 @@ dispatch_keys = [ DispatchKey.MPS, DispatchKey.XPU, DispatchKey.SparseXPU, + DispatchKey.SparseCsrXPU, DispatchKey.SparseCUDA, DispatchKey.SparseCsrCUDA, DispatchKey.QuantizedCPU,