[TF/XLA] Fix the CollectiveReduceV2Op lowering by specifying the number of replicas needed at compile-time

PiperOrigin-RevId: 365890700
Change-Id: I8c213c4ed469d7307d0a8be99a1a32bc6e06a1da
This commit is contained in:
George Karpenkov 2021-03-30 13:58:13 -07:00 committed by Geeta Chavan
parent fb37439d64
commit c4d24e402f
2 changed files with 6 additions and 7 deletions

View File

@@ -175,6 +175,9 @@ Status XlaCompilationCache::BuildExecutable(
argument_layouts[i] = &result.xla_input_shapes[i];
}
xla::ExecutableBuildOptions build_options;
if (result.collective_reduce_info) {
build_options.set_num_replicas(result.collective_reduce_info->group_size);
}
build_options.set_device_ordinal(options.device_ordinal != -1
? options.device_ordinal
: client_->default_device_ordinal());

View File

@@ -308,13 +308,9 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
sync_batchnorm=sync_batchnorm,
jit_compile=False)
error_margin = 1e-2 if jit_compile else 1e-3
loss_error_margin = 5e-2 if jit_compile else 1e-3
self.assertAllClose(wts, wts_with_ds, atol=error_margin, rtol=error_margin)
self.assertAllClose(
loss, loss_with_ds, atol=loss_error_margin, rtol=loss_error_margin)
self.assertAllClose(acc, acc_with_ds, atol=error_margin, rtol=error_margin)
self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3)
if __name__ == '__main__':