mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #49794 from geetachavan1/cherrypicks_9CW9I
Prevent division by 0 in OneHot implementation
This commit is contained in:
commit
28638f5b5d
|
|
@ -69,6 +69,11 @@ void OneHotComputeImpl(const OneHotContext& op_context) {
|
|||
for (int i = 0; i < op_context.axis; ++i) {
|
||||
prefix_dim_size *= op_context.indices->dims->data[i];
|
||||
}
|
||||
if (prefix_dim_size == 0) {
|
||||
// If indices tensor is degenerate, return a degenerate tensor, just like
|
||||
// TensorFlow does.
|
||||
return;
|
||||
}
|
||||
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
|
||||
const int depth = *op_context.depth->data.i32;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user