mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[TF:XLA] Improve numerical stability of SoftPlus.
PiperOrigin-RevId: 171003559
This commit is contained in:
parent
727d6270f9
commit
2114fd51e9
|
|
@ -309,11 +309,6 @@ class UnaryOpsTest(XLATestCase):
|
|||
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softplus,
|
||||
np.array([[-2, 0, 8]], dtype=dtype),
|
||||
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softsign,
|
||||
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
|
||||
|
|
@ -543,6 +538,25 @@ class UnaryOpsTest(XLATestCase):
|
|||
[[9, 10, 11, 12],
|
||||
[13, 14, 15, 16]]]], dtype=dtype))
|
||||
|
||||
def _assertSoftplusMatchesExpected(self, features, dtype):
  """Checks nn_ops.softplus against a NumPy reference for `features`.

  Args:
    features: array-like input values; converted to an ndarray of `dtype`.
    dtype: NumPy floating-point type used for both the input and the
      reference result.
  """
  inputs = np.array(features, dtype=dtype)
  # Reference value: logaddexp(0, x) = log(exp(0) + exp(x)), evaluated with
  # a zero of the same dtype as the inputs.
  reference = np.logaddexp(np.asarray(0).astype(dtype), inputs)
  self._assertOpOutputMatchesExpected(
      nn_ops.softplus, inputs, expected=reference)
|
||||
|
||||
def testSoftplus(self):
|
||||
for dtype in self.float_types:
|
||||
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
|
||||
self._assertSoftplusMatchesExpected(
|
||||
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
|
||||
log_eps = np.log(np.finfo(dtype).eps)
|
||||
one = dtype(1)
|
||||
ten = dtype(10)
|
||||
self._assertSoftplusMatchesExpected([
|
||||
log_eps, log_eps - one, log_eps + one, log_eps - ten,
|
||||
log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
|
||||
-log_eps - ten, -log_eps + ten], dtype)
|
||||
|
||||
# Run the tests through googletest when this file is executed as a script.
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
|
|
|||
|
|
@ -129,8 +129,28 @@ XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
|
|||
// sinh(x) = (exp(x) - exp(-x)) * 0.5
XLAJIT_MAKE_UNARY(Sinh,
|
||||
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
|
||||
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
|
||||
// softplus(x) = log(exp(x) + 1), evaluated directly.
// NOTE(review): a second XLAJIT_MAKE_UNARY(Softplus, ...) that calls the
// static Softplus() helper appears further down in this listing; this direct
// form looks like the pre-change line of the diff -- confirm that only one
// registration remains in the real file.
XLAJIT_MAKE_UNARY(Softplus,
|
||||
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
|
||||
|
||||
static xla::ComputationDataHandle Softplus(
|
||||
xla::ComputationBuilder* b, DataType dtype,
|
||||
const xla::ComputationDataHandle& features) {
|
||||
xla::ComputationDataHandle threshold =
|
||||
b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
|
||||
XlaHelpers::FloatLiteral(b, dtype, 2.0));
|
||||
// Value above which exp(x) may overflow, but softplus(x) == x
|
||||
// is within machine epsilon.
|
||||
xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold));
|
||||
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
|
||||
// is within machine epsilon.
|
||||
xla::ComputationDataHandle too_small = b->Lt(features, threshold);
|
||||
xla::ComputationDataHandle features_exp = b->Exp(features);
|
||||
xla::ComputationDataHandle output = b->Select(
|
||||
too_large, features,
|
||||
b->Select(too_small, features_exp,
|
||||
b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype)))));
|
||||
return output;
|
||||
}
|
||||
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x));
|
||||
|
||||
// softsign(x) = x / (abs(x) + 1)
|
||||
XLAJIT_MAKE_UNARY(Softsign,
|
||||
b->Div(x,
|
||||
|
|
|
|||
|
|
@ -54,6 +54,19 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
|
|||
return b->ConstantLiteral(xla::Literal::One(type));
|
||||
}
|
||||
|
||||
// Returns a scalar constant holding std::numeric_limits<T>::epsilon() for the
// C++ type matching `data_type`. Only DT_FLOAT and DT_DOUBLE are handled; any
// other type is a fatal error.
xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
                                               DataType data_type) {
  if (data_type == DT_FLOAT) {
    return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
  }
  if (data_type == DT_DOUBLE) {
    return b->ConstantR0<double>(std::numeric_limits<double>::epsilon());
  }
  LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
             << DataTypeString(data_type);
}
|
||||
|
||||
xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
|
||||
xla::ComputationBuilder* b, DataType data_type, int64 value) {
|
||||
xla::Literal literal;
|
||||
|
|
|
|||
|
|
@ -48,6 +48,11 @@ class XlaHelpers {
|
|||
static xla::ComputationDataHandle One(xla::ComputationBuilder* b,
|
||||
DataType data_type);
|
||||
|
||||
// Returns the machine epsilon for floating-point type `data_type`, i.e.,
|
||||
// the difference between 1.0 and the next representable value.
// Only DT_FLOAT and DT_DOUBLE are handled by the definition; any other
// `data_type` reaches LOG(FATAL) there.
|
||||
static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b,
|
||||
DataType data_type);
|
||||
|
||||
// Returns a handle representing the given value of an integer scalar
|
||||
// element of data_type.
|
||||
// Note that unlike One and Zero, does not work on boolean types.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user