[Caffe2][Testing] Check for equality first in assertTensorEqualsWithType<float> (#61006)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61006

Test Plan: Modified existing unit test to test for eps = 0. It would fail without the equality test first.

Reviewed By: ajyu

Differential Revision: D29423770

fbshipit-source-id: 168e7de00d8522c4b646a8335d0120700915f260
This commit is contained in:
Hao Lu 2021-06-29 23:30:12 -07:00 committed by Facebook GitHub Bot
parent 287c0ab170
commit 4adc5eb6c5
2 changed files with 12 additions and 9 deletions

View File

@ -651,8 +651,8 @@ TEST(TensorTest, CopyAndAssignment) {
Tensor y(x);
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
Tensor z = x;
testing::assertTensorEquals(x, y);
testing::assertTensorEquals(x, z);
testing::assertTensorEquals(x, y, 0);
testing::assertTensorEquals(x, z, 0);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -26,13 +26,16 @@ void assertTensorEqualsWithType<float>(
float eps) {
CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
for (auto idx = 0; idx < tensor1.numel(); ++idx) {
CAFFE_ENFORCE_LT(
fabs(tensor1.data<float>()[idx] - tensor2.data<float>()[idx]),
eps,
"Mismatch at index ",
idx,
" exceeds threshold of ",
eps);
// When a == b, a - b may not be equal to 0
if (tensor1.data<float>()[idx] != tensor2.data<float>()[idx]) {
CAFFE_ENFORCE_LT(
fabs(tensor1.data<float>()[idx] - tensor2.data<float>()[idx]),
eps,
"Mismatch at index ",
idx,
" exceeds threshold of ",
eps);
}
}
}
} // namespace