Add sync_devices function.

There was an RFC for this API: https://github.com/tensorflow/community/pull/434

PiperOrigin-RevId: 504062646
This commit is contained in:
Reed Wanderman-Milne 2023-01-23 13:09:37 -08:00 committed by TensorFlower Gardener
parent 3fce1fc72c
commit 267c63aa09
22 changed files with 362 additions and 3 deletions

View File

@@ -120,6 +120,9 @@
`rerandomize_each_iteration=True`, the `sample_from_datasets()`
operation will use a different (deterministic) sequence of numbers every
epoch.
* `tf.test`:
* Added `tf.test.experimental.sync_devices`, which is useful for
accurately measuring performance in benchmarks.
# Bug Fixes and Other Changes

View File

@@ -605,6 +605,7 @@ cc_library(
"//tensorflow/core/kernels:random_index_shuffle_ops",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:stateful_random_ops",
"//tensorflow/core/kernels:sync_ops",
"//tensorflow/core/kernels:random_binomial_op",
"//tensorflow/core/kernels:random_poisson_op",
"//tensorflow/core/kernels:required",
@@ -971,6 +972,7 @@ filegroup(
"stateless_random_ops_v2_op_lib",
"string_ops_op_lib",
"summary_ops_op_lib",
"sync_ops_op_lib",
"tpu_configuration_ops_op_lib",
"tpu_cross_replica_ops_op_lib",
"tpu_embedding_ops_op_lib",

View File

@@ -0,0 +1,9 @@
op {
  graph_op_name: "SyncDevice"
  visibility: HIDDEN
  summary: "Synchronizes the device this op is run on."
  description: <<END
Only GPU ops are asynchronous in TensorFlow, and so this only has an effect when
run on GPUs. On GPUs, this op synchronizes the GPU's compute stream.
END
}
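Note: because the op's visibility is HIDDEN, it is not exposed under the main `tf.*` namespace; it remains reachable through `tf.raw_ops` (see the API goldens below). A minimal eager-mode sketch of invoking the raw op directly, assuming a GPU build:

import tensorflow as tf

x = tf.random.normal((1024, 1024))
y = tf.linalg.matmul(x, x)  # On GPU, may still be executing when this returns.
tf.raw_ops.SyncDevice()     # Blocks until the device's pending work is done.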

View File

@@ -7614,6 +7614,14 @@ tf_kernel_library(
    ],
)

tf_kernel_library(
    name = "sync_ops",
    prefix = "sync_ops",
    deps = [
        "//tensorflow/core:framework",
    ],
)
# Library to link with when compiling the cwise_op kernels directly,
# e.g. for selective registration.
# should not be linked by projects that also link the cwise_op library.

View File

@@ -0,0 +1,59 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
namespace {
class SyncDeviceOp : public OpKernel {
 public:
  explicit SyncDeviceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SyncDeviceOp);
};

REGISTER_KERNEL_BUILDER(Name("SyncDevice").Device(DEVICE_DEFAULT),
                        SyncDeviceOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
class SyncDeviceGpuOp : public OpKernel {
 public:
  explicit SyncDeviceGpuOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const DeviceBase::AcceleratorDeviceInfo* info =
        context->device()->tensorflow_accelerator_device_info();
    if (info && info->stream) {
      OP_REQUIRES_OK(context, info->stream->BlockHostUntilDone());
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SyncDeviceGpuOp);
};

REGISTER_KERNEL_BUILDER(Name("SyncDevice").Device(DEVICE_GPU), SyncDeviceGpuOp);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace
}  // namespace tensorflow
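The two registrations above make the behavior placement-dependent: the DEVICE_DEFAULT kernel is an empty no-op (CPU kernels have already finished their work by the time Compute() returns), while the DEVICE_GPU kernel blocks the host until the compute stream drains. A hedged sketch of exercising both kernels from Python via explicit placement:

import tensorflow as tf

with tf.device("/CPU:0"):
  tf.raw_ops.SyncDevice()  # Dispatches the DEVICE_DEFAULT kernel: empty Compute().

with tf.device("/GPU:0"):  # Assumes a GPU is present in this build.
  tf.raw_ops.SyncDevice()  # Dispatches the GPU kernel: BlockHostUntilDone().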

View File

@@ -90,6 +90,7 @@ tf_gen_op_libs(
"state_ops",
"stateless_random_ops",
"stateless_random_ops_v2",
"sync_ops",
"summary_ops",
"training_ops",
],
@@ -284,6 +285,7 @@ cc_library(
":image_ops_op_lib",
":io_ops_op_lib",
":linalg_ops_op_lib",
":sync_ops_op_lib",
":list_ops_op_lib",
":map_ops_op_lib",
":logging_ops_op_lib",

View File

@@ -0,0 +1,27 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// SyncDevice is stateful because it has a side effect: it synchronizes the GPU
// stream. If it weren't stateful, optimization passes like dead code elimination
// might incorrectly remove it.
REGISTER_OP("SyncDevice")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
} // namespace tensorflow
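Since `SyncDevice` produces no outputs, nothing in a graph ever consumes it; `SetIsStateful()` is what keeps optimizers from pruning it. An illustrative sketch (the raw op may appear inside a traced function, unlike the `sync_devices()` wrapper below, which rejects graph mode):

import tensorflow as tf

@tf.function
def matmul_then_sync(x):
  y = tf.linalg.matmul(x, x)
  # No outputs, no consumers: this node survives graph optimization only
  # because the op is registered as stateful.
  tf.raw_ops.SyncDevice()
  return y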

View File

@@ -998,6 +998,16 @@ tf_gen_op_wrapper_private_py(
    ],
)

tf_gen_op_wrapper_private_py(
    name = "sync_ops_gen",
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/core:sync_ops_op_lib",
    ],
)

py_library(
    name = "array_grad",
    srcs = ["ops/array_grad.py"],

View File

@@ -10,6 +10,7 @@ load(
    "tf_cc_shared_object",
    "tf_cc_test",
    "tf_gen_op_wrapper_py",
    "tf_kernel_library",
)
load("//tensorflow:tensorflow.default.bzl", "cuda_py_test", "tf_py_test", "tf_python_pybind_extension")
load("//tensorflow:pytype.default.bzl", "pytype_library", "pytype_strict_library")
@@ -1521,6 +1522,7 @@ py_library(
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:session",
"//tensorflow/python:sync_ops_gen",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
@@ -1860,10 +1862,11 @@ tf_gen_op_wrapper_py(
    deps = [":test_ops_kernels"],
)

-cc_library(
+tf_kernel_library(
    name = "test_ops_kernels",
    srcs = ["test_ops.cc"],
-    linkstatic = 1,
+    hdrs = ["test_ops.h"],
+    gpu_srcs = ["test_ops.cu.cc"],
    deps = [
        "@com_google_absl//absl/time",
        "//tensorflow/core:framework",
@@ -2115,7 +2118,7 @@ tf_py_test(
    ],
)

-tf_py_test(
+cuda_py_test(
    name = "test_util_test",
    size = "small",
    srcs = ["test_util_test.py"],

View File

@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines ops and op kernels that are only used by Python tests.
#include "tensorflow/python/framework/test_ops.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/framework/common_shape_fns.h"
@ -66,6 +70,7 @@ REGISTER_OP("GetDeadline")
REGISTER_OP("SleepOp")
.Input("sleep_seconds: int32")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("SleepIdentityOp")
@ -73,6 +78,7 @@ REGISTER_OP("SleepIdentityOp")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_RESOURCE_HANDLE_OP(StubResource);
@@ -222,6 +228,20 @@ class SleepOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("SleepOp").Device(DEVICE_CPU), SleepOp);
#if GOOGLE_CUDA
class SleepGpuOp : public OpKernel {
 public:
  explicit SleepGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    GpuSleep(ctx, ctx->input(0).scalar<int>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("SleepOp").Device(DEVICE_GPU).HostMemory("sleep_seconds"), SleepGpuOp);
#endif  // GOOGLE_CUDA

class SleepIdentityOp : public OpKernel {
 public:
  explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

View File

@@ -0,0 +1,47 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/util/gpu_kernel_helper.h"
typedef Eigen::GpuDevice GPUDevice;
namespace tensorflow {
namespace {
__global__ void sleep_kernel(int seconds) {
#if __CUDA_ARCH__ >= 700  // __nanosleep requires compute capability 7.0
  int64_t nanoseconds = int64_t{seconds} * 1'000'000'000;
  // Passing too high a number to __nanosleep makes it sleep for much less time
  // than the passed-in number. So only pass 1,000,000 and keep calling
  // __nanosleep in a loop.
  for (int64_t i = 0; i < nanoseconds; i += 1'000'000) {
    __nanosleep(1'000'000);
  }
#endif
}

}  // namespace

void GpuSleep(OpKernelContext* ctx, int seconds) {
  auto* cu_stream = ctx->eigen_device<GPUDevice>().stream();
  CHECK(cu_stream);  // Crash OK
  TF_CHECK_OK(GpuLaunchKernel(sleep_kernel, 1, 1, 0, cu_stream, seconds));
}
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@@ -0,0 +1,26 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
// Runs a kernel on the GPU that sleeps for the given number of seconds.
void GpuSleep(OpKernelContext* ctx, int seconds);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_

View File

@@ -68,6 +68,7 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import gen_sync_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -3991,3 +3992,59 @@ class TestDelta:
  def Get(self):
    value = _test_metrics_util.test_counter_value(self.name, self.label)
    return value - self.last_value
@tf_export("test.experimental.sync_devices")
def sync_devices():
"""Synchronizes all devices.
By default, GPUs run asynchronously. This means that when you run an op on the
GPU, like `tf.linalg.matmul`, the op may still be running on the GPU when the
function returns. Non-GPU devices can also be made to run asynchronously by
calling `tf.config.experimental.set_synchronous_execution(False)`. Calling
`sync_devices()` blocks until pending ops have finished executing. This is
primarily useful for measuring performance during a benchmark.
For example, here is how you can measure how long `tf.linalg.matmul` runs:
>>> import time
>>> x = tf.random.normal((4096, 4096))
>>> tf.linalg.matmul(x, x) # Warmup.
>>> tf.test.experimental.sync_devices() # Block until warmup has completed.
>>>
>>> start = time.time()
>>> y = tf.linalg.matmul(x, x)
>>> tf.test.experimental.sync_devices() # Block until matmul has completed.
>>> end = time.time()
>>> print(f'Time taken: {end - start}')
If the call to `sync_devices()` was omitted, the time printed could be too
small. This is because the op could still be running asynchronously when
the line `end = time.time()` is executed.
Raises:
RuntimeError: If run outside Eager mode. This must be called in Eager mode,
outside any `tf.function`s.
"""
if not context.executing_eagerly():
raise RuntimeError(
"sync_devices() must only be called in Eager mode, outside tf.functions"
)
# There are two sources of asynchrony in TensorFlow:
#
# 1. On GPUs, kernels are run on a CUDA stream, which is inherently
# asynchronous.
# 2. Calling `tf.config.experimental.set_synchronous_execution(False)` makes
# all ops asynchronous, in which case TensorFlow maintains internal queues
# of pending ops.
#
# Calling SyncDevice addresses source (1). Calling async_await addresses
# source (2). It is important that SyncDevice() is called before async_wait(),
# otherwise the SyncDevice op itself may still be pending on an internal
# TensorFlow queue when the sync_devices() Python function returns.
devices = config.list_logical_devices()
for dev in devices:
with ops.device(dev.name):
gen_sync_ops.SyncDevice()
context.async_wait()
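For reference, a hypothetical benchmarking helper built on this API (the `benchmark` name and its parameters are illustrative, not part of this commit):

import time
import tensorflow as tf

def benchmark(fn, iters=10):
  fn()  # Warmup.
  tf.test.experimental.sync_devices()  # Drain warmup work before timing.
  start = time.time()
  for _ in range(iters):
    fn()
  tf.test.experimental.sync_devices()  # Drain pending ops before reading the clock.
  return (time.time() - start) / iters

x = tf.random.normal((4096, 4096))
print(f"matmul: {benchmark(lambda: tf.linalg.matmul(x, x)):.4f}s/iter")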

View File

@@ -18,6 +18,7 @@ import collections
import copy
import random
import threading
import time
import unittest
import weakref
@@ -33,6 +34,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -1118,5 +1120,57 @@ class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
    self.assertTrue(isinstance(t, ops.Tensor) for t in results)
class SyncDevicesTest(test_util.TensorFlowTestCase):

  def tearDown(self):
    super().tearDown()
    config.set_synchronous_execution(True)

  def test_sync_device_cpu(self):
    with context.eager_mode(), ops.device("/CPU:0"):
      config.set_synchronous_execution(False)
      start = time.time()
      test_ops.sleep_op(sleep_seconds=1)
      self.assertLess(time.time() - start, 1.0)
      test_util.sync_devices()
      self.assertGreater(time.time() - start, 1.0)

      config.set_synchronous_execution(True)
      start = time.time()
      test_ops.sleep_op(sleep_seconds=1)
      self.assertGreaterEqual(time.time() - start, 1.0)
      start = time.time()
      test_util.sync_devices()
      self.assertLess(time.time() - start, 1.0)

  def test_sync_device_gpu(self):
    if not test_util.is_gpu_available(min_cuda_compute_capability=(7, 0)):
      # sleep_op requires compute capability 7.0
      self.skipTest("Requires GPU with compute capability 7.0")
    with context.eager_mode(), ops.device("/GPU:0"):
      config.set_synchronous_execution(False)
      start = time.time()
      test_ops.sleep_op(sleep_seconds=1)
      self.assertLess(time.time() - start, 1.0)
      test_util.sync_devices()
      self.assertGreater(time.time() - start, 1.0)

      config.set_synchronous_execution(True)
      start = time.time()
      test_ops.sleep_op(sleep_seconds=1)
      self.assertLess(time.time() - start, 1.0)
      start = time.time()
      test_util.sync_devices()
      self.assertGreaterEqual(time.time() - start, 1.0)

  def test_sync_devices_graph_mode_error(self):
    with context.graph_mode():
      with self.assertRaisesRegex(
          RuntimeError, r"sync_devices\(\) must only be called in Eager mode"
      ):
        test_util.sync_devices()


if __name__ == "__main__":
  googletest.main()

View File

@@ -104,6 +104,7 @@ TENSORFLOW_API_INIT_FILES = [
"summary/experimental/__init__.py",
"sysconfig/__init__.py",
"test/__init__.py",
"test/experimental/__init__.py",
"tpu/experimental/embedding/__init__.py",
"tpu/experimental/__init__.py",
"tpu/__init__.py",

View File

@@ -89,6 +89,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"summary/__init__.py",
"sysconfig/__init__.py",
"test/__init__.py",
"test/experimental/__init__.py",
"tpu/experimental/embedding/__init__.py",
"tpu/experimental/__init__.py",
"tpu/__init__.py",

View File

@@ -4884,6 +4884,10 @@ tf_module {
    name: "SymbolicGradient"
    argspec: "args=[\'input\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "SyncDevice"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "TFRecordDataset"
    argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "

View File

@@ -0,0 +1,7 @@
path: "tensorflow.test.experimental"
tf_module {
  member_method {
    name: "sync_devices"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
}

View File

@@ -12,6 +12,10 @@ tf_module {
    name: "TestCase"
    mtype: "<type \'type\'>"
  }
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
  member {
    name: "mock"
    mtype: "<type \'module\'>"

View File

@@ -4884,6 +4884,10 @@ tf_module {
    name: "SymbolicGradient"
    argspec: "args=[\'input\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "SyncDevice"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "TFRecordDataset"
    argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "

View File

@@ -0,0 +1,7 @@
path: "tensorflow.test.experimental"
tf_module {
  member_method {
    name: "sync_devices"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
}

View File

@@ -8,6 +8,10 @@ tf_module {
    name: "TestCase"
    mtype: "<type \'type\'>"
  }
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
  member_method {
    name: "assert_equal_graph_def"
    argspec: "args=[\'expected\', \'actual\'], varargs=None, keywords=None, defaults=None"