Enable oneDNN for tanh based GELU on aarch64 (#130925)
Provides a speedup for GELU on aarch64 compared to the native PyTorch implementation, e.g. an 8.5x speedup over the native implementation for a 1x1x16384 input on 32 threads on Graviton 3.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130925
Approved by: https://github.com/malfet
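As a rough sketch of how a figure like the one quoted above could be reproduced from Python, the snippet below times the tanh-approximated GELU at the quoted shape and thread count. The warm-up and iteration counts and the timing harness are assumptions chosen for illustration, not the PR's actual benchmark code.

import time
import torch

torch.set_num_threads(32)     # the quoted result used 32 threads
x = torch.randn(1, 1, 16384)  # the quoted 1x1x16384 input

# Warm-up runs so one-time dispatch/allocation costs are excluded.
for _ in range(10):
    torch.nn.functional.gelu(x, approximate="tanh")

iters = 1000  # assumed iteration count, for illustration only
start = time.perf_counter()
for _ in range(iters):
    torch.nn.functional.gelu(x, approximate="tanh")
elapsed = time.perf_counter() - start
print(f"tanh GELU: {elapsed / iters * 1e6:.2f} us/iter")

On an aarch64 CPU build with oneDNN enabled, this call should dispatch to the ideep eltwise_gelu_tanh path added in the diff below whenever use_mkldnn(self) holds for the input.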
parent 97eba8e174
commit ce61300141
@@ -396,6 +396,13 @@ auto approximate_type = get_gelutype_enum(approximate);
     ideep::tensor y = itensor_from_tensor(result);
     ideep::eltwise_forward::compute(
         x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
+#ifdef __aarch64__
+  } else if (use_mkldnn(self) && (approximate_type == GeluType::Tanh)) {
+    const ideep::tensor& x = itensor_from_tensor(self);
+    ideep::tensor y = itensor_from_tensor(result);
+    ideep::eltwise_forward::compute(
+        x, y, ideep::algorithm::eltwise_gelu_tanh, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
+#endif // ifdef __aarch64__
   } else {
     GeluKernel(kCPU, *this, approximate_type);
   }