Merge pull request #48335 from geetachavan1/cherrypicks_OHAKM

[CherryPick:r2.5] [TF/XLA] Fix the CollectiveReduceV2Op lowering by specifying the number of replicas needed at compile-time
This commit is contained in:
Mihai Maruseac 2021-04-09 13:08:36 -07:00 committed by GitHub
commit 43abad5bb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 7 deletions

View File

@@ -175,6 +175,9 @@ Status XlaCompilationCache::BuildExecutable(
argument_layouts[i] = &result.xla_input_shapes[i];
}
xla::ExecutableBuildOptions build_options;
if (result.collective_reduce_info) {
build_options.set_num_replicas(result.collective_reduce_info->group_size);
}
build_options.set_device_ordinal(options.device_ordinal != -1
? options.device_ordinal
: client_->default_device_ordinal());

View File

@@ -308,13 +308,9 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
sync_batchnorm=sync_batchnorm,
jit_compile=False)
error_margin = 1e-2 if jit_compile else 1e-3
loss_error_margin = 5e-2 if jit_compile else 1e-3
self.assertAllClose(wts, wts_with_ds, atol=error_margin, rtol=error_margin)
self.assertAllClose(
loss, loss_with_ds, atol=loss_error_margin, rtol=loss_error_margin)
self.assertAllClose(acc, acc_with_ds, atol=error_margin, rtol=error_margin)
self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3)
if __name__ == '__main__':