mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Intel GPU] Support SparseCsrXPU codegen (#144722)
Adding a new dispatch key - `SparseCsrXPU` to enable Intel GPU support for SparseCsr Tensor. Similar PR: https://github.com/pytorch/pytorch/pull/139267 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144722 Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/albanD Co-authored-by: Kanya-Mo <kanya.mo@intel.com>
This commit is contained in:
parent
1677a31019
commit
8f20026bcb
|
|
@ -56,9 +56,11 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
|
||||||
|
|
||||||
TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
|
TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
|
||||||
|| (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)
|
|| (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)
|
||||||
|
|| (key_set.has(DispatchKey::SparseCsrXPU) && device().type() == kXPU)
|
||||||
|| (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta)
|
|| (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta)
|
||||||
|| (key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kMeta) // fake tensor
|
|| (key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kMeta) // fake tensor
|
||||||
|| (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kMeta) // fake tensor
|
|| (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kMeta) // fake tensor
|
||||||
|
|| (key_set.has(DispatchKey::SparseCsrXPU) && device().type() == kMeta) // fake tensor
|
||||||
|| (key_set.has(DispatchKey::SparseCsrPrivateUse1) && device().type() == kPrivateUse1)),
|
|| (key_set.has(DispatchKey::SparseCsrPrivateUse1) && device().type() == kPrivateUse1)),
|
||||||
"Inconsistent key_set (=", key_set, ") and device (=", device(), ")");
|
"Inconsistent key_set (=", key_set, ") and device (=", device(), ")");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -360,6 +360,9 @@ static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
|
||||||
case kCUDA:
|
case kCUDA:
|
||||||
dispatch_key = DispatchKey::SparseCsrCUDA;
|
dispatch_key = DispatchKey::SparseCsrCUDA;
|
||||||
break;
|
break;
|
||||||
|
case kXPU:
|
||||||
|
dispatch_key = DispatchKey::SparseCsrXPU;
|
||||||
|
break;
|
||||||
case kMeta:
|
case kMeta:
|
||||||
dispatch_key = DispatchKey::SparseCsrMeta;
|
dispatch_key = DispatchKey::SparseCsrMeta;
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -283,6 +283,7 @@ dispatch_keys = [
|
||||||
DispatchKey.MPS,
|
DispatchKey.MPS,
|
||||||
DispatchKey.XPU,
|
DispatchKey.XPU,
|
||||||
DispatchKey.SparseXPU,
|
DispatchKey.SparseXPU,
|
||||||
|
DispatchKey.SparseCsrXPU,
|
||||||
DispatchKey.SparseCUDA,
|
DispatchKey.SparseCUDA,
|
||||||
DispatchKey.SparseCsrCUDA,
|
DispatchKey.SparseCsrCUDA,
|
||||||
DispatchKey.QuantizedCPU,
|
DispatchKey.QuantizedCPU,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user