mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Functorch] Support Functorch for PrivateUse1 backend (#154700)
This PR enables functorch to be used by 3rd-party backends. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154700 Approved by: https://github.com/zou3519
This commit is contained in:
parent
15e9119a69
commit
daff263062
|
|
@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
|||
DispatchKey::XLA,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
});
|
||||
|
||||
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) {
|
|||
}
|
||||
|
||||
static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
|
||||
return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
|
||||
return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0;
|
||||
}
|
||||
|
||||
// NB: sure, we could have used different overloads here, but I would feel insecure
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) {
|
|||
|
||||
static c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
|
||||
auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
|
||||
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
|
||||
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
|
||||
DispatchKey::AutogradPrivateUse1});
|
||||
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
|
||||
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
|
||||
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
|
||||
|
|
@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper(
|
|||
}
|
||||
|
||||
auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
|
||||
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
|
||||
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
|
||||
DispatchKey::AutogradPrivateUse1});
|
||||
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
|
||||
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
|
||||
auto result = at::detail::make_tensor<TensorWrapper>(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user