mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add sync_devices function.
There was an RFC for this API: https://github.com/tensorflow/community/pull/434 PiperOrigin-RevId: 504062646
This commit is contained in:
parent
3fce1fc72c
commit
267c63aa09
|
|
@ -120,6 +120,9 @@
|
|||
`rerandomize_each_iteration=True`, the `sample_from_datasets()`
|
||||
operation will use a different (deterministic) sequence of numbers every
|
||||
epoch.
|
||||
* `tf.test`:
|
||||
* Added `tf.test.experimental.sync_devices`, which is useful for
|
||||
accurately measuring performance in benchmarks.
|
||||
|
||||
# Bug Fixes and Other Changes
|
||||
|
||||
|
|
|
|||
|
|
@ -605,6 +605,7 @@ cc_library(
|
|||
"//tensorflow/core/kernels:random_index_shuffle_ops",
|
||||
"//tensorflow/core/kernels:random_ops",
|
||||
"//tensorflow/core/kernels:stateful_random_ops",
|
||||
"//tensorflow/core/kernels:sync_ops",
|
||||
"//tensorflow/core/kernels:random_binomial_op",
|
||||
"//tensorflow/core/kernels:random_poisson_op",
|
||||
"//tensorflow/core/kernels:required",
|
||||
|
|
@ -971,6 +972,7 @@ filegroup(
|
|||
"stateless_random_ops_v2_op_lib",
|
||||
"string_ops_op_lib",
|
||||
"summary_ops_op_lib",
|
||||
"sync_ops_op_lib",
|
||||
"tpu_configuration_ops_op_lib",
|
||||
"tpu_cross_replica_ops_op_lib",
|
||||
"tpu_embedding_ops_op_lib",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
op {
|
||||
graph_op_name: "SyncDevice"
|
||||
visibility: HIDDEN
|
||||
summary: "Synchronizes the device this op is run on."
|
||||
description: <<END
|
||||
Only GPU ops are asynchrous in TensorFlow, and so this only has an effect when
|
||||
run on GPUs. On GPUs, this op synchronizes the GPU's compute stream.
|
||||
END
|
||||
}
|
||||
|
|
@ -7614,6 +7614,14 @@ tf_kernel_library(
|
|||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "sync_ops",
|
||||
prefix = "sync_ops",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
# Library to link with when compiling the cwise_op kernels directly,
|
||||
# e.g. for selective registration.
|
||||
# should not be linked by projects that also link the cwise_op library.
|
||||
|
|
|
|||
59
tensorflow/core/kernels/sync_ops.cc
Normal file
59
tensorflow/core/kernels/sync_ops.cc
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class SyncDeviceOp : public OpKernel {
|
||||
public:
|
||||
explicit SyncDeviceOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SyncDeviceOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SyncDevice").Device(DEVICE_DEFAULT),
|
||||
SyncDeviceOp);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
class SyncDeviceGpuOp : public OpKernel {
|
||||
public:
|
||||
explicit SyncDeviceGpuOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const DeviceBase::AcceleratorDeviceInfo* info =
|
||||
context->device()->tensorflow_accelerator_device_info();
|
||||
if (info && info->stream) {
|
||||
OP_REQUIRES_OK(context, info->stream->BlockHostUntilDone());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SyncDeviceGpuOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SyncDevice").Device(DEVICE_GPU), SyncDeviceGpuOp);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
@ -90,6 +90,7 @@ tf_gen_op_libs(
|
|||
"state_ops",
|
||||
"stateless_random_ops",
|
||||
"stateless_random_ops_v2",
|
||||
"sync_ops",
|
||||
"summary_ops",
|
||||
"training_ops",
|
||||
],
|
||||
|
|
@ -284,6 +285,7 @@ cc_library(
|
|||
":image_ops_op_lib",
|
||||
":io_ops_op_lib",
|
||||
":linalg_ops_op_lib",
|
||||
":sync_ops_op_lib",
|
||||
":list_ops_op_lib",
|
||||
":map_ops_op_lib",
|
||||
":logging_ops_op_lib",
|
||||
|
|
|
|||
27
tensorflow/core/ops/sync_ops.cc
Normal file
27
tensorflow/core/ops/sync_ops.cc
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// SyncDevice is stateful because it has a side effect: it synchronizes the GPU
|
||||
// steam. If it weren't stateful, optimization passes like dead code elimination
|
||||
// might incorrectly remove it.
|
||||
REGISTER_OP("SyncDevice")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
@ -998,6 +998,16 @@ tf_gen_op_wrapper_private_py(
|
|||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "sync_ops_gen",
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:sync_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "array_grad",
|
||||
srcs = ["ops/array_grad.py"],
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ load(
|
|||
"tf_cc_shared_object",
|
||||
"tf_cc_test",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.default.bzl", "cuda_py_test", "tf_py_test", "tf_python_pybind_extension")
|
||||
load("//tensorflow:pytype.default.bzl", "pytype_library", "pytype_strict_library")
|
||||
|
|
@ -1521,6 +1522,7 @@ py_library(
|
|||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sync_ops_gen",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
|
|
@ -1860,10 +1862,11 @@ tf_gen_op_wrapper_py(
|
|||
deps = [":test_ops_kernels"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
tf_kernel_library(
|
||||
name = "test_ops_kernels",
|
||||
srcs = ["test_ops.cc"],
|
||||
linkstatic = 1,
|
||||
hdrs = ["test_ops.h"],
|
||||
gpu_srcs = ["test_ops.cu.cc"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/time",
|
||||
"//tensorflow/core:framework",
|
||||
|
|
@ -2115,7 +2118,7 @@ tf_py_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
cuda_py_test(
|
||||
name = "test_util_test",
|
||||
size = "small",
|
||||
srcs = ["test_util_test.py"],
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines ops and op kernels that are only used by Python tests.
|
||||
|
||||
#include "tensorflow/python/framework/test_ops.h"
|
||||
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
|
|
@ -66,6 +70,7 @@ REGISTER_OP("GetDeadline")
|
|||
|
||||
REGISTER_OP("SleepOp")
|
||||
.Input("sleep_seconds: int32")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("SleepIdentityOp")
|
||||
|
|
@ -73,6 +78,7 @@ REGISTER_OP("SleepIdentityOp")
|
|||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_RESOURCE_HANDLE_OP(StubResource);
|
||||
|
|
@ -222,6 +228,20 @@ class SleepOp : public OpKernel {
|
|||
|
||||
REGISTER_KERNEL_BUILDER(Name("SleepOp").Device(DEVICE_CPU), SleepOp);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
class SleepGpuOp : public OpKernel {
|
||||
public:
|
||||
explicit SleepGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
GpuSleep(ctx, ctx->input(0).scalar<int>()());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("SleepOp").Device(DEVICE_GPU).HostMemory("sleep_seconds"), SleepGpuOp);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
class SleepIdentityOp : public OpKernel {
|
||||
public:
|
||||
explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
|
|
|||
47
tensorflow/python/framework/test_ops.cu.cc
Normal file
47
tensorflow/python/framework/test_ops.cu.cc
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
__global__ void sleep_kernel(int seconds) {
|
||||
#if __CUDA_ARCH__ >= 700 // __nanosleep requires compute capability 7.0
|
||||
int64_t nanoseconds = int64_t{seconds} * 1'000'000'000;
|
||||
// Passing too high a number to __nanosleep makes it sleep for much less time
|
||||
// than the passed-in number. So only pass 1,000,000 and keep calling
|
||||
// __nanosleep in a loop.
|
||||
for (int64_t i = 0; i < nanoseconds; i += 1'000'000) {
|
||||
__nanosleep(1'000'000);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void GpuSleep(OpKernelContext* ctx, int seconds) {
|
||||
auto* cu_stream = ctx->eigen_device<GPUDevice>().stream();
|
||||
CHECK(cu_stream); // Crash OK
|
||||
TF_CHECK_OK(GpuLaunchKernel(sleep_kernel, 1, 1, 0, cu_stream, seconds));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // GOOGLE_CUDA
|
||||
26
tensorflow/python/framework/test_ops.h
Normal file
26
tensorflow/python/framework/test_ops.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Run a kernel on the GPU that sleeps for the given time
|
||||
void GpuSleep(OpKernelContext* ctx, int seconds);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_TEST_OPS_H_
|
||||
|
|
@ -68,6 +68,7 @@ from tensorflow.python.framework import versions
|
|||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import gen_sync_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
|
|
@ -3991,3 +3992,59 @@ class TestDelta:
|
|||
def Get(self):
|
||||
value = _test_metrics_util.test_counter_value(self.name, self.label)
|
||||
return value - self.last_value
|
||||
|
||||
|
||||
@tf_export("test.experimental.sync_devices")
|
||||
def sync_devices():
|
||||
"""Synchronizes all devices.
|
||||
|
||||
By default, GPUs run asynchronously. This means that when you run an op on the
|
||||
GPU, like `tf.linalg.matmul`, the op may still be running on the GPU when the
|
||||
function returns. Non-GPU devices can also be made to run asynchronously by
|
||||
calling `tf.config.experimental.set_synchronous_execution(False)`. Calling
|
||||
`sync_devices()` blocks until pending ops have finished executing. This is
|
||||
primarily useful for measuring performance during a benchmark.
|
||||
|
||||
For example, here is how you can measure how long `tf.linalg.matmul` runs:
|
||||
|
||||
>>> import time
|
||||
>>> x = tf.random.normal((4096, 4096))
|
||||
>>> tf.linalg.matmul(x, x) # Warmup.
|
||||
>>> tf.test.experimental.sync_devices() # Block until warmup has completed.
|
||||
>>>
|
||||
>>> start = time.time()
|
||||
>>> y = tf.linalg.matmul(x, x)
|
||||
>>> tf.test.experimental.sync_devices() # Block until matmul has completed.
|
||||
>>> end = time.time()
|
||||
>>> print(f'Time taken: {end - start}')
|
||||
|
||||
If the call to `sync_devices()` was omitted, the time printed could be too
|
||||
small. This is because the op could still be running asynchronously when
|
||||
the line `end = time.time()` is executed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If run outside Eager mode. This must be called in Eager mode,
|
||||
outside any `tf.function`s.
|
||||
"""
|
||||
if not context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"sync_devices() must only be called in Eager mode, outside tf.functions"
|
||||
)
|
||||
|
||||
# There are two sources of asynchrony in TensorFlow:
|
||||
#
|
||||
# 1. On GPUs, kernels are run on a CUDA stream, which is inherently
|
||||
# asynchronous.
|
||||
# 2. Calling `tf.config.experimental.set_synchronous_execution(False)` makes
|
||||
# all ops asynchronous, in which case TensorFlow maintains internal queues
|
||||
# of pending ops.
|
||||
#
|
||||
# Calling SyncDevice addresses source (1). Calling async_await addresses
|
||||
# source (2). It is important that SyncDevice() is called before async_wait(),
|
||||
# otherwise the SyncDevice op itself may still be pending on an internal
|
||||
# TensorFlow queue when the sync_devices() Python function returns.
|
||||
devices = config.list_logical_devices()
|
||||
for dev in devices:
|
||||
with ops.device(dev.name):
|
||||
gen_sync_ops.SyncDevice()
|
||||
context.async_wait()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import collections
|
|||
import copy
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import weakref
|
||||
|
||||
|
|
@ -33,6 +34,7 @@ from tensorflow.python.compat import compat
|
|||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
|
|
@ -1118,5 +1120,57 @@ class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
|
|||
self.assertTrue(isinstance(t, ops.Tensor) for t in results)
|
||||
|
||||
|
||||
class SyncDevicesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
config.set_synchronous_execution(True)
|
||||
|
||||
def test_sync_device_cpu(self):
|
||||
with context.eager_mode(), ops.device("/CPU:0"):
|
||||
config.set_synchronous_execution(False)
|
||||
start = time.time()
|
||||
test_ops.sleep_op(sleep_seconds=1)
|
||||
self.assertLess(time.time() - start, 1.0)
|
||||
test_util.sync_devices()
|
||||
self.assertGreater(time.time() - start, 1.0)
|
||||
|
||||
config.set_synchronous_execution(True)
|
||||
start = time.time()
|
||||
test_ops.sleep_op(sleep_seconds=1)
|
||||
self.assertGreaterEqual(time.time() - start, 1.0)
|
||||
start = time.time()
|
||||
test_util.sync_devices()
|
||||
self.assertLess(time.time() - start, 1.0)
|
||||
|
||||
def test_sync_device_gpu(self):
|
||||
if not test_util.is_gpu_available(min_cuda_compute_capability=(7, 0)):
|
||||
# sleep_op requires compute capability 7.0
|
||||
self.skipTest("Requires GPU with compute capability 7.0")
|
||||
|
||||
with context.eager_mode(), ops.device("/GPU:0"):
|
||||
config.set_synchronous_execution(False)
|
||||
start = time.time()
|
||||
test_ops.sleep_op(sleep_seconds=1)
|
||||
self.assertLess(time.time() - start, 1.0)
|
||||
test_util.sync_devices()
|
||||
self.assertGreater(time.time() - start, 1.0)
|
||||
|
||||
config.set_synchronous_execution(True)
|
||||
start = time.time()
|
||||
test_ops.sleep_op(sleep_seconds=1)
|
||||
self.assertLess(time.time() - start, 1.0)
|
||||
start = time.time()
|
||||
test_util.sync_devices()
|
||||
self.assertGreaterEqual(time.time() - start, 1.0)
|
||||
|
||||
def test_sync_devices_graph_mode_error(self):
|
||||
with context.graph_mode():
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"sync_devices\(\) must only be called in Eager mode"
|
||||
):
|
||||
test_util.sync_devices()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ TENSORFLOW_API_INIT_FILES = [
|
|||
"summary/experimental/__init__.py",
|
||||
"sysconfig/__init__.py",
|
||||
"test/__init__.py",
|
||||
"test/experimental/__init__.py",
|
||||
"tpu/experimental/embedding/__init__.py",
|
||||
"tpu/experimental/__init__.py",
|
||||
"tpu/__init__.py",
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
|||
"summary/__init__.py",
|
||||
"sysconfig/__init__.py",
|
||||
"test/__init__.py",
|
||||
"test/experimental/__init__.py",
|
||||
"tpu/experimental/embedding/__init__.py",
|
||||
"tpu/experimental/__init__.py",
|
||||
"tpu/__init__.py",
|
||||
|
|
|
|||
|
|
@ -4884,6 +4884,10 @@ tf_module {
|
|||
name: "SymbolicGradient"
|
||||
argspec: "args=[\'input\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SyncDevice"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TFRecordDataset"
|
||||
argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
path: "tensorflow.test.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "sync_devices"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,10 @@ tf_module {
|
|||
name: "TestCase"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "mock"
|
||||
mtype: "<type \'module\'>"
|
||||
|
|
|
|||
|
|
@ -4884,6 +4884,10 @@ tf_module {
|
|||
name: "SymbolicGradient"
|
||||
argspec: "args=[\'input\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SyncDevice"
|
||||
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TFRecordDataset"
|
||||
argspec: "args=[\'filenames\', \'compression_type\', \'buffer_size\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
path: "tensorflow.test.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "sync_devices"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,10 @@ tf_module {
|
|||
name: "TestCase"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "assert_equal_graph_def"
|
||||
argspec: "args=[\'expected\', \'actual\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user