Move contrib/nccl to core/nccl.

PiperOrigin-RevId: 218908694
A. Unique TensorFlower 2018-10-26 13:54:44 -07:00 committed by TensorFlower Gardener
parent 2c164ed32f
commit fc6cd33c33
28 changed files with 207 additions and 394 deletions

View File

@ -1,6 +1,7 @@
# Where component owners are known, add them here.
/tensorflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks @csigg
/tensorflow/core/platform/windows/ @mrry
/tensorflow/core/platform/s3 @yongtang
/tensorflow/go @asimshankar
@ -46,7 +47,6 @@
/tensorflow/contrib/losses/ @alextp @ispirmustafa
/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq
/tensorflow/contrib/opt/ @strategist333 @alextp
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden

View File

@ -72,7 +72,6 @@ py_library(
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/mixed_precision:mixed_precision",
"//tensorflow/contrib/model_pruning",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
@ -179,9 +178,7 @@ cc_library(
"//tensorflow/contrib/tensor_forest:stats_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
"//tensorflow/contrib/nccl:nccl_kernels",
]) + select({
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
"//tensorflow:linux_s390x": [],
@ -215,7 +212,6 @@ cc_library(
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
"//tensorflow/contrib/rnn:all_ops",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",

View File

@ -62,7 +62,6 @@ from tensorflow.contrib import memory_stats
from tensorflow.contrib import metrics
from tensorflow.contrib import mixed_precision
from tensorflow.contrib import model_pruning
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import periodic_resample

View File

@ -29,10 +29,10 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
],
)

View File

@ -21,11 +21,11 @@ from __future__ import print_function
import collections
import math
from tensorflow.contrib import nccl
from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
def _flatten_tensors(tensors):
@ -693,7 +693,7 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
ValueError: red_op not supported.
"""
if red_op == math_ops.add:
output_tensors = nccl.all_sum(input_tensors)
output_tensors = nccl_ops.all_sum(input_tensors)
else:
raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
if un_op:
@ -745,7 +745,7 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
for w in range(0, num_workers):
dst_tensors = []
with ops.device(per_worker_devices[w][0]):
broadcast_src = nccl.broadcast(array_ops.identity(level_2_output[w]))
broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
for d in per_worker_devices[w]:
with ops.device(d):
dst_tensors.append(array_ops.identity(broadcast_src))

View File

@ -308,11 +308,6 @@ tensorflow/contrib/model_pruning/examples
tensorflow/contrib/model_pruning/examples/cifar10
tensorflow/contrib/model_pruning/python
tensorflow/contrib/model_pruning/python/layers
tensorflow/contrib/nccl
tensorflow/contrib/nccl/kernels
tensorflow/contrib/nccl/ops
tensorflow/contrib/nccl/python
tensorflow/contrib/nccl/python/ops
tensorflow/contrib/nearest_neighbor
tensorflow/contrib/nearest_neighbor/kernels
tensorflow/contrib/nearest_neighbor/ops

View File

@ -97,9 +97,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/libsvm/ops/libsvm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/resampler/kernels/resampler_ops.cc"

View File

@ -99,7 +99,6 @@ GENERATE_CONTRIB_OP_LIBRARY(image_distort_image "${tensorflow_source_dir}/tensor
GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc")

View File

@ -594,7 +594,6 @@ py_library(
deps = [
":values",
"//tensorflow/contrib/all_reduce:all_reduce_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:device",
@ -602,6 +601,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
],
)

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import collections as pycoll
import threading
from tensorflow.contrib import nccl
from tensorflow.contrib.all_reduce.python import all_reduce
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.framework import device as pydev
@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
def aggregate_gradients_using_nccl(replica_grads):
@ -38,7 +38,7 @@ def aggregate_gradients_using_nccl(replica_grads):
agg_all_g_and_v = []
for single_g_and_v in zip(*replica_grads):
single_grads = [g for g, _ in single_g_and_v]
agg_grads = nccl.all_sum(single_grads)
agg_grads = nccl_ops.all_sum(single_grads)
agg_all_g_and_v.append(
[(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])
@ -376,7 +376,7 @@ def sum_grad_and_var_all_reduce(grad_and_vars,
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
scaled_grads = [g for g, _ in grad_and_vars]
if alg == 'nccl':
summed_grads = nccl.all_sum(scaled_grads)
summed_grads = nccl_ops.all_sum(scaled_grads)
elif alg == 'xring':
summed_grads = all_reduce.build_ring_all_reduce(
scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)

View File

@ -1,177 +0,0 @@
# Description:
# Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops.
# APIs are meant to change over time.
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda")
tf_custom_op_library(
name = "python/ops/_nccl_ops.so",
srcs = [
"ops/nccl_ops.cc",
],
gpu_srcs = if_not_windows_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
]),
deps = [] + if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:protos_all_proto_text",
]),
)
tf_cuda_cc_test(
name = "nccl_manager_test",
size = "medium",
srcs = if_cuda(
[
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_manager_test.cc",
],
[],
),
# Disabled on jenkins until errors finding nvmlShutdown are resolved.
tags = [
"manual",
"multi_gpu",
"no_oss",
"noguitar",
"notap",
],
deps =
if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
]),
)
tf_kernel_library(
name = "nccl_kernels",
srcs = if_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
"kernels/nccl_rewrite.cc",
]),
deps = if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
]),
alwayslink = 1,
)
tf_gen_op_libs(
op_lib_names = ["nccl_ops"],
deps = [
"//tensorflow/core:lib",
],
)
tf_gen_op_wrapper_py(
name = "nccl_ops",
deps = [":nccl_ops_op_lib"],
)
# Test-only nccl ops lib without the dso, to test behavior when the NCCL lib is
# not installed. See nccl_dependency_test for more details.
#
# Users should use the public nccl_py lib that also adds the dso.
tf_custom_op_py_library(
name = "nccl_ops_lib_without_dso",
srcs = [
"__init__.py",
"python/ops/nccl_ops.py",
],
kernels = if_cuda([":nccl_kernels"]) + [
":nccl_ops_op_lib",
],
deps = [
":nccl_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:device",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
],
)
tf_custom_op_py_library(
name = "nccl_py",
dso = [":python/ops/_nccl_ops.so"],
visibility = ["//visibility:public"],
deps = [
":nccl_ops_lib_without_dso",
],
)
cuda_py_test(
name = "nccl_ops_test",
size = "small",
srcs = ["python/ops/nccl_ops_test.py"],
additional_deps = [
":nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
# Disabled on jenkins until errors finding nvmlShutdown are resolved.
tags = [
"manual",
"multi_gpu",
"no_oss",
"noguitar",
"notap",
],
)
cuda_py_test(
name = "nccl_dependency_test",
size = "small",
srcs = ["python/ops/nccl_dependency_test.py"],
additional_deps = [
":nccl_ops_lib_without_dso",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
# Disable this test internally, since NCCL is statically linked internally; only
# run it for OSS to verify that NCCL is an optional dynamic dependency.
tags = [
"manual",
"noguitar",
"notap",
],
)

View File

@ -1,38 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for using NVIDIA nccl collective ops.
@@all_max
@@all_min
@@all_prod
@@all_sum
@@reduce_sum
@@broadcast
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@ -1,59 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency test for nccl to test behavior when NCCL is not installed."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib import nccl
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
class NcclDependencyTest(test.TestCase):
"""Verifies that importing nccl ops lib does not fail even if NCCL is not
installed but nccl ops throws an exception on use if NCCL is not installed.
"""
def test_nccl_ops(self):
"""Tests behavior of nccl ops when NCCL is not installed."""
public_methods = [
m[0]
for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction)
if not m[0].startswith('_')
]
for method_name in public_methods:
with ops.device('/device:CPU:0'):
tensor = constant_op.constant(1)
if method_name == 'broadcast':
arg = tensor
else:
arg = [tensor]
nccl_op = getattr(nccl, method_name)
with ops.device('/device:CPU:0'):
with self.assertRaisesRegexp(errors_impl.NotFoundError,
r'cannot open shared object file'):
nccl_op(arg)
if __name__ == '__main__':
test.main()

View File

@ -1068,6 +1068,7 @@ tf_gen_op_libs(
"logging_ops",
"manip_ops",
"math_ops",
"nccl_ops",
"nn_ops",
"no_op",
"parsing_ops",
@ -1216,6 +1217,7 @@ cc_library(
":lookup_ops_op_lib",
":manip_ops_op_lib",
":math_ops_op_lib",
":nccl_ops_op_lib",
":nn_ops_op_lib",
":no_op_op_lib",
":parsing_ops_op_lib",
@ -1395,6 +1397,7 @@ cc_library(
"//tensorflow/core/kernels:fact_op",
"//tensorflow/core/kernels:array_not_windows",
"//tensorflow/core/kernels:math_not_windows",
"//tensorflow/core/kernels:nccl_kernels",
"//tensorflow/core/kernels:quantized_ops",
"//tensorflow/core/kernels/neon:neon_depthwise_conv_op",
]) + if_mkl([

View File

@ -0,0 +1,19 @@
op {
graph_op_name: "NcclAllReduce"
summary: "Outputs a tensor containing the reduction across all input tensors."
description: <<END
Outputs a tensor containing the reduction across all input tensors passed to ops
within the same `shared_name`.
The graph should be constructed so that if one op runs with shared_name value `c`,
then `num_devices` ops will run with shared_name value `c`. Failure to do so
will cause the graph execution to fail to complete.
input: The input to the reduction.
data: The value of the reduction across all `num_devices` devices.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
END
visibility: HIDDEN
}
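For orientation, a minimal usage sketch (not part of this change) of the relocated Python wrapper, tensorflow.python.ops.nccl_ops, which drives NcclAllReduce: all_sum builds one NcclAllReduce op per input tensor under a common shared_name, so exactly `num_devices` ops participate. The helper name and the two-GPU device list are assumptions for illustration; running it requires a CUDA build with NCCL available.

# Hypothetical sketch; device names are placeholders.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nccl_ops

def _all_sum_on_two_gpus():
  # Each input must carry an explicit device assignment; all_sum then
  # places one NcclAllReduce op on each of those devices with a shared
  # shared_name so the collective sees num_devices participants.
  inputs = []
  for dev in ['/device:GPU:0', '/device:GPU:1']:
    with ops.device(dev):
      inputs.append(array_ops.ones([4]))
  return nccl_ops.all_sum(inputs)  # one summed tensor per input device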

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "NcclBroadcast"
summary: "Sends `input` to all devices that are connected to the output."
description: <<END
Sends `input` to all devices that are connected to the output.
The graph should be constructed so that all ops connected to the output have a
valid device assignment, and the op itself is assigned one of these devices.
input: The input to the broadcast.
output: The same as input.
shape: The shape of the input tensor.
END
visibility: HIDDEN
}
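A hedged sketch of how this op is consumed, mirroring _build_nccl_hybrid in all_reduce.py above: the broadcast source is pinned to one device and every receiver reads it through an identity op on its own device, so all ops connected to the output have a valid device assignment. The helper name is an illustrative assumption.

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nccl_ops

def _broadcast_to_devices(tensor, devices):
  # The source tensor is assumed to already have a device assignment;
  # the send side of the broadcast runs there.
  with ops.device(tensor.device):
    send = nccl_ops.broadcast(array_ops.identity(tensor))
  # One receive per destination device, each pinned explicitly.
  received = []
  for d in devices:
    with ops.device(d):
      received.append(array_ops.identity(send))
  return received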

View File

@ -0,0 +1,15 @@
op {
graph_op_name: "NcclReduce"
summary: "Reduces `input` from `num_devices` using `reduction` to a single device."
description: <<END
Reduces `input` from `num_devices` using `reduction` to a single device.
The graph should be constructed so that all inputs have a valid device
assignment, and the op itself is assigned one of these devices.
input: The input to the reduction.
data: The value of the reduction across all `num_devices` devices.
reduction: The reduction operation to perform.
END
visibility: HIDDEN
}
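In contrast to NcclAllReduce, this op yields a single output. A short sketch (helper name assumed; per-GPU inputs built as in the all_sum example above) using the reduce_sum wrapper exercised by nccl_ops_test.py below:

from tensorflow.python.framework import ops
from tensorflow.python.ops import nccl_ops

def _reduce_sum_to_first_device(per_gpu_tensors):
  # Pinning the call to one of the input devices gives the NcclReduce op
  # a valid placement among its inputs' devices.
  with ops.device(per_gpu_tensors[0].device):
    return nccl_ops.reduce_sum(per_gpu_tensors)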

View File

@ -270,6 +270,20 @@ cc_library(
],
)
tf_kernel_library(
name = "nccl_kernels",
srcs = if_cuda([
"nccl_ops.cc",
]),
deps = if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core/nccl:nccl_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:nccl_ops_op_lib",
]),
)
tf_cuda_library(
name = "ops_testutil",
testonly = 1,

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
#include "third_party/nccl/nccl.h"
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/nccl/nccl_manager.h"
namespace tensorflow {
namespace {

View File

@ -0,0 +1,60 @@
# Description:
# Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops.
# APIs are meant to change over time.
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
cc_library(
name = "nccl_lib",
srcs = if_cuda([
"nccl_manager.cc",
"nccl_manager.h",
"nccl_rewrite.cc",
]),
copts = tf_copts(),
deps = if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
]),
alwayslink = 1,
)
tf_cuda_cc_test(
name = "nccl_manager_test",
size = "medium",
srcs = if_cuda(
[
"nccl_manager_test.cc",
],
[],
),
# Disabled on jenkins until errors finding nvmlShutdown are resolved.
tags = [
"manual",
"multi_gpu",
"no_oss",
"noguitar",
"notap",
],
deps =
if_cuda([
":nccl_lib",
"@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
]),
)

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include <utility>

View File

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
#define TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#ifdef GOOGLE_CUDA
@ -135,4 +135,4 @@ class NcclManager {
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
#endif // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_

View File

@ -19,11 +19,11 @@ limitations under the License.
#include <random>
#include <vector>
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {

View File

@ -29,21 +29,7 @@ REGISTER_OP("NcclAllReduce")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Outputs a tensor containing the reduction across all input tensors passed to ops
within the same `shared_name.
The graph should be constructed so if one op runs with shared_name value `c`,
then `num_devices` ops will run with shared_name value `c`. Failure to do so
will cause the graph execution to fail to complete.
input: the input to the reduction
data: the value of the reduction across all `num_devices` devices.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that shared between ops of the same reduction.
)doc");
.SetShapeFn(shape_inference::UnchangedShape);
// Note: This op has no kernel implementation, but is replaced by
// _NcclReduceSend and _NcclReduceRecv during graph optimization stage.
@ -54,17 +40,7 @@ REGISTER_OP("NcclReduce")
.Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Reduces `input` from `num_devices` using `reduction` to a single device.
The graph should be constructed so that all inputs have a valid device
assignment, and the op itself is assigned one of these devices.
input: The input to the reduction.
data: the value of the reduction across all `num_devices` devices.
reduction: the reduction operation to perform.
)doc");
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("_NcclReduceSend")
.Input("input: T")
@ -121,17 +97,7 @@ REGISTER_OP("NcclBroadcast")
.Attr("T: {half, float, float64, int32, int64}")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Sends `input` to all devices that are connected to the output.
The graph should be constructed so that all ops connected to the output have a
valid device assignment, and the op itself is assigned one of these devices.
input: The input to the broadcast.
output: The same as input.
shape: The shape of the input tensor.
)doc");
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("_NcclBroadcastSend")
.Input("input: T")

View File

@ -109,6 +109,7 @@ py_library(
":manip_ops",
":math_ops",
":metrics",
":nccl_ops",
":nn",
":ops",
":platform",
@ -5757,6 +5758,48 @@ py_test(
],
)
tf_gen_op_wrapper_private_py(
name = "nccl_ops_gen",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:nccl_ops_op_lib",
],
)
py_library(
name = "nccl_ops",
srcs = ["ops/nccl_ops.py"],
srcs_version = "PY2AND3",
visibility = visibility + [
"//learning/deepmind/tensorflow:__subpackages__",
],
deps = [
":framework_for_generated_wrappers",
":nccl_ops_gen",
],
)
cuda_py_test(
name = "nccl_ops_test",
size = "small",
srcs = ["ops/nccl_ops_test.py"],
additional_deps = [
":nccl_ops",
":array_ops",
":client_testlib",
":framework_test_lib",
":platform_test",
],
# Disabled on jenkins until errors finding nvmlShutdown are resolved.
tags = [
"manual",
"multi_gpu",
"no_oss",
"noguitar",
"notap",
],
)
py_binary(
name = "graph_analyzer",
srcs = [

View File

@ -19,15 +19,11 @@ from __future__ import print_function
import threading
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.ops import gen_nccl_ops
_nccl_ops_so = None
_module_lock = threading.Lock()
_shared_name_counter = 0
@ -182,7 +178,6 @@ def broadcast(tensor):
A tensor with the value of `src_tensor`, which can be used as input to
ops on other GPU devices.
"""
_validate_and_load_nccl_so()
_check_device(tensor)
with ops.device(tensor.device):
@ -214,7 +209,6 @@ def _apply_all_reduce(reduction, tensors):
"""Helper function for all_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
_validate_and_load_nccl_so()
shared_name = _get_shared_name()
res = []
@ -236,7 +230,6 @@ def _apply_reduce(reduction, tensors):
"""Helper function for reduce_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to reduce operations')
_validate_and_load_nccl_so()
for t in tensors:
_check_device(t)
@ -262,27 +255,3 @@ def _check_device(tensor, expected=None):
raise ValueError('Device assignment required for nccl collective ops')
if expected and expected != tensor.device:
raise ValueError('Expected device %s, got %s' % (expected, tensor.device))
def _maybe_load_nccl_ops_so():
"""Loads nccl ops so if it hasn't been loaded already."""
with _module_lock:
global _nccl_ops_so
if not _nccl_ops_so:
_nccl_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile('_nccl_ops.so'))
def _validate_and_load_nccl_so():
"""Validates calling context and loads nccl ops so file.
Raises:
ValueError: Ops are not supported.
errors_impl.NotFoundError: nccl library is not installed.
"""
if context.executing_eagerly():
raise ValueError('Nccl ops are not supported in eager mode')
_maybe_load_nccl_ops_so()

View File

@ -19,14 +19,13 @@ from __future__ import division
from __future__ import print_function
from functools import partial
import os
import numpy as np
from tensorflow.contrib import nccl
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import test
@ -52,7 +51,7 @@ def _NcclBroadcast(tensors, devices):
sender = np.random.randint(0, len(devices))
with ops.device(devices[sender]):
tensor = array_ops.identity(tensors[0])
broadcast = nccl.broadcast(tensor)
broadcast = nccl_ops.broadcast(tensor)
return _DeviceTensors([broadcast] * len(devices), devices)
@ -61,7 +60,6 @@ class NcclTestCase(test.TestCase):
def _Test(self,
nccl_reduce,
numpy_fn,
dtypes=[np.float16, np.float32, np.int32, np.int64, np.float64],
device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:1', '/device:GPU:0'])):
"""Tests that nccl_reduce does the same as reduction with numpy_fn.
@ -74,10 +72,7 @@ class NcclTestCase(test.TestCase):
two.
device_sets: Tuple of virtual devices to run test on.
"""
# Enable NCCL printouts.
os.environ["NCCL_DEBUG"] = "INFO"
for dtype in dtypes:
for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
@ -129,36 +124,36 @@ class NcclTestCase(test.TestCase):
reduce_tensors, inputs, losses, colocate_gradients_with_ops=True)
return [g for g in grads if g is not None]
# int types are not considered 'trainable', so no gradients are generated.
self._Test(_Gradient, numpy_fn, dtypes=[np.float16, np.float32, np.float64])
self._Test(_Gradient, numpy_fn)
class AllReduceTest(NcclTestCase):
def testAllReduce(self):
self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
self._Test(partial(_NcclAllReduce, nccl_ops.all_sum), lambda x, y: x + y)
self._Test(partial(_NcclAllReduce, nccl_ops.all_prod), lambda x, y: x * y)
self._Test(partial(_NcclAllReduce, nccl_ops.all_min), np.minimum)
self._Test(partial(_NcclAllReduce, nccl_ops.all_max), np.maximum)
def testAllSumGrad(self):
self._TestGradient(
partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
partial(_NcclAllReduce, nccl_ops.all_sum), lambda x, y: x + y)
def testErrors(self):
with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
nccl_ops.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
nccl.all_sum([])
nccl_ops.all_sum([])
class SingleReduceTest(NcclTestCase):
def testSum(self):
self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
self._Test(partial(_NcclReduce, nccl_ops.reduce_sum), lambda x, y: x + y)
def testSumGrad(self):
self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x)
self._TestGradient(partial(_NcclReduce, nccl_ops.reduce_sum),
lambda x, y: x)
class BroadcastTest(NcclTestCase):
@ -189,8 +184,8 @@ class CombinedTest(NcclTestCase):
"""Test all-reduce vs. single-reduce plus broadcast in one session.run."""
def _Combined(self, tensors, devices):
all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)
single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices)
all_reduce_tensors = _NcclAllReduce(nccl_ops.all_sum, tensors, devices)
single_reduce_tensors = _NcclReduce(nccl_ops.reduce_sum, tensors, devices)
broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices)
return all_reduce_tensors + broadcast_tensors