mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] fix hardsigmoid op (#162758)
Currently std::min -> ::min did not work as expected on ROCm when input values >= 2147483648 It can be fixed by explicit typing std::min<opmath_t> Pull Request resolved: https://github.com/pytorch/pytorch/pull/162758 Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
parent
7357eb66c5
commit
1e9ddf510f
|
|
@ -36,7 +36,7 @@ void hardsigmoid_kernel(TensorIteratorBase& iter) {
|
||||||
[zero, one_sixth, three, six] GPU_LAMBDA(
|
[zero, one_sixth, three, six] GPU_LAMBDA(
|
||||||
scalar_t self_val) -> scalar_t {
|
scalar_t self_val) -> scalar_t {
|
||||||
opmath_t x = static_cast<opmath_t>(self_val);
|
opmath_t x = static_cast<opmath_t>(self_val);
|
||||||
return std::min(std::max(x + three, zero), six) * one_sixth;
|
return std::min<opmath_t>(std::max<opmath_t>(x + three, zero), six) * one_sixth;
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16593,12 +16593,7 @@ op_db: list[OpInfo] = [
|
||||||
toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ],
|
toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ],
|
||||||
skips=[
|
skips=[
|
||||||
# still want to test that first derivative works though second derivative isn't supported
|
# still want to test that first derivative works though second derivative isn't supported
|
||||||
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"),
|
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad")]
|
||||||
# produces 0 instead of nan on ROCM
|
|
||||||
DecorateInfo(unittest.expectedFailure,
|
|
||||||
'TestUnaryUfuncs', "test_reference_numerics_extremal",
|
|
||||||
device_type='cuda',
|
|
||||||
active_if=(TEST_WITH_ROCM)), ]
|
|
||||||
),
|
),
|
||||||
UnaryUfuncInfo(
|
UnaryUfuncInfo(
|
||||||
'nn.functional.logsigmoid',
|
'nn.functional.logsigmoid',
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user