mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve reciprocal() and rsqrt() accuracy on arm64 (#47478)
Summary: Neither `vrecpeq_f32` nor `vrsqrteq_f32` yield accurate results but just perform first of two steps in an iteration of the Newton-Raphson method, as documented at https://developer.arm.com/documentation/dui0472/j/using-neon-support/neon-intrinsics-for-reciprocal-and-sqrt Use appropriate NEON instruction to run two more steps of the Newton's method to improve results Before: ``` $ python -c "import torch;print(torch.arange(1.0, 17.0, 1.0, dtype=torch.float32).reciprocal())" tensor([0.9980, 0.4990, 0.3330, 0.2495, 0.1997, 0.1665, 0.1426, 0.1248, 0.1108, 0.0999, 0.0908, 0.0833, 0.0769, 0.0713, 0.0667, 0.0624]) $ python -c "import torch;print(torch.arange(1.0, 17.0, 1.0, dtype=torch.float32).rsqrt())" tensor([0.9980, 0.7051, 0.5762, 0.4990, 0.4463, 0.4082, 0.3779, 0.3525, 0.3330, 0.3154, 0.3008, 0.2881, 0.2773, 0.2666, 0.2578, 0.2495]) ``` After: ``` $ python -c "import torch;print(torch.arange(1.0, 17.0, 1.0, dtype=torch.float32).reciprocal())" tensor([1.0000, 0.5000, 0.3333, 0.2500, 0.2000, 0.1667, 0.1429, 0.1250, 0.1111, 0.1000, 0.0909, 0.0833, 0.0769, 0.0714, 0.0667, 0.0625]) $ python -c "import torch;print(torch.arange(1.0, 17.0, 1.0, dtype=torch.float32).rsqrt())" tensor([1.0000, 0.7071, 0.5774, 0.5000, 0.4472, 0.4082, 0.3780, 0.3536, 0.3333, 0.3162, 0.3015, 0.2887, 0.2774, 0.2673, 0.2582, 0.2500]) ``` Partially addresses https://github.com/pytorch/pytorch/issues/47476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47478 Reviewed By: walterddr Differential Revision: D24773443 Pulled By: malfet fbshipit-source-id: 224dca9725601d29fb229f8d71d968a30f25c829
This commit is contained in:
parent
5614f72534
commit
d0d673b043
|
|
@ -444,14 +444,23 @@ public:
|
|||
vsqrtq_f32(values.val[1]));
|
||||
}
|
||||
Vec256<float> reciprocal() const {
|
||||
return Vec256<float>(
|
||||
vrecpeq_f32(values.val[0]),
|
||||
vrecpeq_f32(values.val[1]));
|
||||
float32x4_t r0 = vrecpeq_f32(values.val[0]);
|
||||
float32x4_t r1 = vrecpeq_f32(values.val[1]);
|
||||
// Run two more Netwon's method iterations to get more accurate results
|
||||
r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0);
|
||||
r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0);
|
||||
r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1);
|
||||
r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1);
|
||||
return Vec256<float>(r0, r1);
|
||||
}
|
||||
Vec256<float> rsqrt() const {
|
||||
return Vec256<float>(
|
||||
vrsqrteq_f32(values.val[0]),
|
||||
vrsqrteq_f32(values.val[1]));
|
||||
float32x4_t r0 = vrsqrteq_f32(values.val[0]);
|
||||
float32x4_t r1 = vrsqrteq_f32(values.val[1]);
|
||||
r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0);
|
||||
r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0);
|
||||
r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1);
|
||||
r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1);
|
||||
return Vec256<float>(r0, r1);
|
||||
}
|
||||
Vec256<float> pow(const Vec256<float> &exp) const {
|
||||
__at_align32__ float tmp[size()];
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user