mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #49794 from geetachavan1/cherrypicks_9CW9I
Prevent division by 0 in OneHot implementation
This commit is contained in:
commit
28638f5b5d
|
|
@ -69,6 +69,11 @@ void OneHotComputeImpl(const OneHotContext& op_context) {
|
||||||
for (int i = 0; i < op_context.axis; ++i) {
|
for (int i = 0; i < op_context.axis; ++i) {
|
||||||
prefix_dim_size *= op_context.indices->dims->data[i];
|
prefix_dim_size *= op_context.indices->dims->data[i];
|
||||||
}
|
}
|
||||||
|
if (prefix_dim_size == 0) {
|
||||||
|
// If indices tensor is degenerate, return a degenerate tensor, just like
|
||||||
|
// TensorFlow does.
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
|
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
|
||||||
const int depth = *op_context.depth->data.i32;
|
const int depth = *op_context.depth->data.i32;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user