Merge pull request #49794 from geetachavan1/cherrypicks_9CW9I

Prevent division by 0 in OneHot implementation
This commit is contained in:
Mihai Maruseac 2021-05-30 17:20:05 -07:00 committed by GitHub
commit 28638f5b5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -69,6 +69,11 @@ void OneHotComputeImpl(const OneHotContext& op_context) {
for (int i = 0; i < op_context.axis; ++i) { for (int i = 0; i < op_context.axis; ++i) {
prefix_dim_size *= op_context.indices->dims->data[i]; prefix_dim_size *= op_context.indices->dims->data[i];
} }
if (prefix_dim_size == 0) {
// If indices tensor is degenerate, return a degenerate tensor, just like
// TensorFlow does.
return;
}
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size; const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
const int depth = *op_context.depth->data.i32; const int depth = *op_context.depth->data.i32;