mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[TF/XLA] Fix the CollectiveReduceV2Op lowering by specifying the number of replicas needed at compile-time
PiperOrigin-RevId: 365890700 Change-Id: I8c213c4ed469d7307d0a8be99a1a32bc6e06a1da
This commit is contained in:
parent
fb37439d64
commit
c4d24e402f
|
|
@ -175,6 +175,9 @@ Status XlaCompilationCache::BuildExecutable(
|
|||
argument_layouts[i] = &result.xla_input_shapes[i];
|
||||
}
|
||||
xla::ExecutableBuildOptions build_options;
|
||||
if (result.collective_reduce_info) {
|
||||
build_options.set_num_replicas(result.collective_reduce_info->group_size);
|
||||
}
|
||||
build_options.set_device_ordinal(options.device_ordinal != -1
|
||||
? options.device_ordinal
|
||||
: client_->default_device_ordinal());
|
||||
|
|
|
|||
|
|
@ -308,13 +308,9 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
|
|||
sync_batchnorm=sync_batchnorm,
|
||||
jit_compile=False)
|
||||
|
||||
error_margin = 1e-2 if jit_compile else 1e-3
|
||||
loss_error_margin = 5e-2 if jit_compile else 1e-3
|
||||
|
||||
self.assertAllClose(wts, wts_with_ds, atol=error_margin, rtol=error_margin)
|
||||
self.assertAllClose(
|
||||
loss, loss_with_ds, atol=loss_error_margin, rtol=loss_error_margin)
|
||||
self.assertAllClose(acc, acc_with_ds, atol=error_margin, rtol=error_margin)
|
||||
self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3)
|
||||
self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3)
|
||||
self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user