Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 12:20:11 +01:00)
Merge changes from github.
END_PUBLIC
---
Commit 9f81374c3 authored by raymondxyang <zihao.yang@microsoft.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Add option for building more Python tests in CMake (#11853)
* Ignore Windows built project
* Fix deprecated methods in tf.contrib.python
* Fix regex match for Windows build in contrib.keras
* Fix regex match for Windows build in session_bundle
* Fix deprecated methods
* Fix regex match for Windows
* Fix compatibility issue with Python 3.x
* Add missing ops into Windows build for test
* Enable more test cases for the Windows build
* Clean code and fix typo
* Add conditional CMake mode for enabling more unit test cases
* Add CMake mode for major contrib packages
* Add supplementary info in README for the new CMake option
* Update tf_tests after testing with TF 1.3
* Clean code and resolve conflicts
* Fix unsafe regex matches and format code
* Update exclude list after testing with latest master branch
* Fix missing module
---
Commit 98f0e1efe authored by Yong Tang <yong.tang.github@outlook.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Dynamic ksize and strides with MaxPool (#11875)
* Dynamic ksize with max_pool: this fix addresses the issue raised in #4746, where ksize is static (an attr) with max_pool. It changes ksize to an input tensor so that it is now dynamic. Fixes #4746.
* Add dynamic ksize to MaxPoolGrad and MaxPoolGradGrad
* Add test cases for max_pool_v2
* Fix GPU Jenkins issue
* Enable MaxPoolV2 on GPU
* Hide MaxPoolV2 and other fixes
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
Commit 02d6bc185 authored by Bairen Yi <byronyi@users.noreply.github.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Remove useless variable (#12212)
---
Commit ed6b0d905 authored by namrata-ibm <bhavenamrata@gmail.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Adding support for s390x in calculation of cpu_frequency (#12201)
---
Commit 627dfc9dd authored by Taehoon Lee <taehoonlee@snu.ac.kr>, committed by Taehoon Lee <taehoonlee@snu.ac.kr>:
Fix typos
---
Commit c0f9b0a91 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
In fast-math mode, emit a tanh that has a faster min/max.
PiperOrigin-RevId: 164943597
---
Commit 87605f3d6 authored by Kay Zhu <kayzhu@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
[TF:XLA] Use HloEvaluator for ComputeConstant, removing the need for a dedicated compute-constant backend.
PiperOrigin-RevId: 164940970
---
Commit 881de45c2 authored by Taehoon Lee <me@taehoonlee.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Add bool type support for GPU kernels (#11927)
* Add bool type support for GPU kernels
* Add bool type test code for GPU kernels
---
Commit eeacdcdb1 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Add missing "CPU" suffix in registrations.
PiperOrigin-RevId: 164939527
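To make the dynamic-ksize change from commit 98f0e1efe concrete, here is a minimal graph-mode sketch. The op was deliberately hidden from the public API per the last bullet of that commit, so the exact Python path of the generated wrapper (`gen_nn_ops.max_pool_v2` below) is an assumption:

```
# Sketch only: MaxPoolV2 takes ksize/strides as input tensors instead of attrs.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops  # assumed location of the wrapper

x = tf.placeholder(tf.float32, [1, 8, 8, 1])
# Before: ksize/strides are attrs, frozen at graph-construction time.
y_static = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                          padding='SAME')
# After: MaxPoolV2 accepts them as tensors, so they can be fed at run time.
k = tf.placeholder(tf.int32, [4])
s = tf.placeholder(tf.int32, [4])
y_dynamic = gen_nn_ops.max_pool_v2(x, ksize=k, strides=s, padding='SAME')

with tf.Session() as sess:
    out = sess.run(y_dynamic, {x: np.ones([1, 8, 8, 1], np.float32),
                               k: [1, 2, 2, 1], s: [1, 2, 2, 1]})
    print(out.shape)  # (1, 4, 4, 1)
```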
---
Commit de01be952 authored by namrata-ibm <bhavenamrata@gmail.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Adding support for Big Endian in graph_constructor_test and wav_io (#12179)
---
Commit 26719d29f authored by QingYing Chen <pkudysj@126.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Implement CRF decode (Viterbi decode) for tensor (#12056)
* Implement CRF decoding for tensors
* Add test code for the tensor version of CRF decoding
* Make modifications according to pylint
* Add some comments for crf_decode
* Remove useless code
* Add comments at the top of the crf module and more comments in crf_test
* Capitalize the first word of comments
* Replace crf_decode test code with a deterministic example
---
Commit f9a81ca2f authored by Pete Warden <pete@petewarden.com>, committed by gunan <gunan@google.com>:
Create CI build script for Raspberry Pi (#12190)
* Create CI build script for Raspberry Pi
* Move location of Pi build script
---
Commit e2a163a90 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.
PiperOrigin-RevId: 164929133
---
Commit 08bbfa187 authored by Taehoon Lee <me@taehoonlee.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Fix typos (#12195)
---
Commit ab96f41fb authored by Luke Iwanski <luke@codeplay.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
[OpenCL] Extends matmul_benchmark.py to cover SYCL (#11697)
* [OpenCL] Extends matmul_benchmark.py to cover SYCL
* Fixed typo
* /gpu:0 -> /device:GPU:0
* Fixes control_flow_ops_py_test
* /gpu: -> /device:GPU:
* Fixes //tensorflow/python/profiler/internal:run_metadata_test
* gpu: -> GPU:
* Fixes tfprof_node
* [OpenCL] Fixes device path to name with many colons (#123): the device path is constructed from a device name by replacing all colons with underscores. Some device names contain more than one colon, for example 'device:SYCL:0', which gives the path 'device_SYCL_0'. The previous code would not convert this back to the original device name, but rather to 'device:SYCL_0'. An alternative fix would be to convert all underscores to colons in the device name (i.e. remove the restriction inside `replace("_", ":", 1)`), but it is unclear whether any device names contain underscores.
* If no GPU device is available, fake one
* gpu: -> device:GPU
* Fixes profiler test
* /gpu:x -> /device:GPU:x
* Fixes debug_io_utils_test.cc test
* Fixes device_name_utils_test.cc
---
Commit 35e7a3665 authored by Yong Tang <yong.tang.github@outlook.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
Remove unneeded casting of int64 for reverse_sequence (#12192)
This fix removes an unneeded cast to int64 for reverse_sequence:
```
lengths = math_ops.to_int64(lengths)
```
as int32 has already been enabled for reverse_sequence.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
Commit 9fba8c185 authored by Anna R <annarev@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Add benchmark dashboard link to benchmarks doc. Also, add a link and description for the Benchmarks page to the Community index page.
PiperOrigin-RevId: 164924906
---
Commit bb6f32fa7 authored by Mark Heffernan <meheff@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Make HloAliasAnalysis updatable after changes to the HLO graph. As part of this change, make HloAliasAnalysis a thinner layer which basically only holds a map from HloValue to HloBuffer and vice versa.
PiperOrigin-RevId: 164923041
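The CRF change in commit 26719d29f makes Viterbi decoding available as in-graph ops. A short sketch, assuming the `tf.contrib.crf` surface that PR describes (`crf_decode` returning the best tag sequence and its score):

```
import tensorflow as tf
from tensorflow.contrib import crf

batch, max_len, num_tags = 2, 5, 3
unary = tf.random_normal([batch, max_len, num_tags])    # per-step tag scores
transitions = tf.random_normal([num_tags, num_tags])    # tag-transition scores
seq_lens = tf.constant([5, 3])

# Viterbi decode entirely inside the graph: no round-trip through NumPy.
tags, best_score = crf.crf_decode(unary, transitions, seq_lens)
```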
---
Commit 9103096c1 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by Thomas Köppe <tkoeppe@google.com>:
Merged commit includes the following changes:
  164923041 by meheff:
    Make HloAliasAnalysis updatable after changes to the HLO graph. As part of this change, make HloAliasAnalysis a thinner layer which basically only holds a map from HloValue to HloBuffer and vice versa.
PiperOrigin-RevId: 164923041
---
Commit 822603aed authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Merge sibling fusion instructions using multi_output_fusion.
PiperOrigin-RevId: 164920220
---
Commit c035aa2a8 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 164917891
---
Commit e1e81d9ba authored by Luke Iwanski <luke@codeplay.com>, committed by Rasmus Munk Larsen <rmlarsen@google.com>:
[OpenCL] Fixes double memcpy bug (#151) (#12173)
* [OpenCL] Fixes double memcpy bug (#151): as the debug CopyOp is called on a Tensor without a type, we need to use the DataType enum to get type information, and use this to pass the type on to Eigen. This is a workaround for Eigen's need to have a type when calling memcpy. If the Eigen memcpy can be provided without a type requirement, then the memcpy in sycl_util is unnecessary.
* Acts on feedback from #12173/files/32cb12a9001b672425867b5a3110fd98e737a20b#r132496277
---
Commit d9ca2d86d authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Internal change.
PiperOrigin-RevId: 164916465
---
Commit b8d13d218 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Remove more parts of DCASGD missed in the first pass. (47949b)
PiperOrigin-RevId: 164914552
---
Commit 73b3d52c7 authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
CMake fix.
PiperOrigin-RevId: 164911656
---
Commit 2173b5b0a authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Allow TFE_TensorHandleCopyToDevice to have the same device as src and destination; it will reuse the same underlying buffer in those cases.
PiperOrigin-RevId: 164909906
---
Commit 13eb3b90e authored by Alexandre Passos <apassos@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.
PiperOrigin-RevId: 164902588
---
Commit 7dfabcc01 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Initialize ExecutionOptions in ComputeConstant to default values.
PiperOrigin-RevId: 164894867
---
Commit c8897e9bc authored by Benoit Steiner <bsteiner@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Static required time computation.
PiperOrigin-RevId: 164894645
---
Commit 076158f9b authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Enable implicit->explicit conversion by default.
PiperOrigin-RevId: 164890915
---
Commit 58c4a4cb1 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Bugfix: the number of input channels is not necessarily in the last dimension, after introduction of the data_format param.
PiperOrigin-RevId: 164889729
---
Commit 8f9b1af8a authored by Igor Saprykin <isaprykin@google.com>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Recover MonitoredSession when the Coordinator is requested to stop with one of the _PREEMPTION_ERRORS.

When SyncReplicasOptimizer is used, a preemption in the Coordinator may result in two cases:
Case 1) the session gets silently marked as complete.
Case 2) the session gets stuck.
This CL aims to solve and verify solutions for both of these problems. Fix 1 changes the should_stop logic; Fix 2 changes the CoordinatedSession.run() logic.

SyncReplicasOptimizer runs a separate set of threads using a Coordinator instance. Those threads do FIFOQueue.enqueue; the main thread does a blocking FIFOQueue.dequeue. The `sync_token_q` FIFOQueue is on parameter servers. When one of the PS instances gets preempted, an AbortedError causes the Coordinator to stop via request_stop(ex). That by itself changes the state of MonitoredSession.should_stop() to True (Fix 1).

Results of the blocking Dequeue operation are sent to the chief worker via Recv. What happens next depends on the number of tokens in `sync_token_q`. If there are enough for the next call to Dequeue to return, then the low-level "tf session run() call" returns. The next iteration of the `while not MonitoredSession.should_stop()` loop decides that the training is complete (Case 1). If there are not enough tokens in `sync_token_q`, then the blocking Dequeue keeps waiting for them. This results in the graph execution getting stuck and the whole session getting garbage collected after 10 minutes (Case 2).

We decided to fix that by re-creating a session after it gets garbage collected (Fix 2). An alternative was to try to cancel the pending Dequeue operation, but it is not clear that that is the right thing to do, and it is also not easy.
PiperOrigin-RevId: 164888390
---
Commit 46e4de6e5 authored by A. Unique TensorFlower <gardener@tensorflow.org>, committed by TensorFlower Gardener <gardener@tensorflow.org>:
Undo loop fusion changes for now as they seem to be altering a few results.
END_PUBLIC
RELNOTES: n/a

BEGIN_PUBLIC
Automated g4 rollback of changelist 164825735

PiperOrigin-RevId: 165340331
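The loop that the MonitoredSession fixes above protect looks roughly like this sketch (Fix 1 makes `should_stop()` turn True once the Coordinator is asked to stop on a preemption; Fix 2 re-creates the underlying session instead of hanging); the training op here is a trivial stand-in:

```
import tensorflow as tf

step = tf.Variable(0)
train_op = tf.assign_add(step, 1)  # stand-in for a real training step

with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():       # flips to True after request_stop(ex)
        if sess.run(train_op) >= 5:     # run() recovers from a lost session
            break
```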
This commit is contained in:
parent 03a33c08dd
commit 28ce1d163e
.gitignore (vendored): 2 changes

@@ -13,3 +13,5 @@ node_modules
 __pycache__
 *.swp
 .vscode/
+cmake_build/
+.idea/**
CODEOWNERS: 89 changes

@@ -1,52 +1,53 @@
+# NOTE: Disabled temporarily because it's too noisy on pushes.
 # Where component owners are known, add them here.
 
-tensorflow/core/platform/windows/* @mrry
-tensorflow/java/* @asimshankar
-tensorflow/tensorboard/* @jart @dandelionmane
-tensorflow/tools/docs/* @markdaoust
+#tensorflow/core/platform/windows/* @mrry
+#tensorflow/java/* @asimshankar
+#tensorflow/tensorboard/* @jart @dandelionmane
+#tensorflow/tools/docs/* @markdaoust
 
 # contrib
 
 # NEED OWNER: tensorflow/contrib/avro/*
-tensorflow/contrib/batching/* @alextp @chrisolston
-tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon
-tensorflow/contrib/cmake/* @mrry @benoitsteiner
-tensorflow/contrib/copy_graph/* @tucker @poxvoculi
-tensorflow/contrib/crf/* @kentonl
-tensorflow/contrib/data/* @mrry
-tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi
-tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo
-tensorflow/contrib/ffmpeg/* @fredbertsch
+#tensorflow/contrib/batching/* @alextp @chrisolston
+#tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon
+#tensorflow/contrib/cmake/* @mrry @benoitsteiner
+#tensorflow/contrib/copy_graph/* @tucker @poxvoculi
+#tensorflow/contrib/crf/* @kentonl
+#tensorflow/contrib/data/* @mrry
+#tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi
+#tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo
+#tensorflow/contrib/ffmpeg/* @fredbertsch
 # NEED OWNER: tensorflow/contrib/framework/*
-tensorflow/contrib/graph_editor/* @purpledog
+#tensorflow/contrib/graph_editor/* @purpledog
 # NEED OWNER: tensorflow/contrib/grid_rnn/*
-tensorflow/contrib/hvx/* @satok16
-tensorflow/contrib/imperative/* @keveman
-tensorflow/contrib/integrate/* @shoyer
-tensorflow/contrib/kernel_methods/* @petrosmol
-tensorflow/contrib/ios_examples/* @petewarden
-tensorflow/contrib/labeled_tensor/* @shoyer
-tensorflow/contrib/layers/* @fchollet @martinwicke
-tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp
-tensorflow/contrib/linalg/* @langmore
-tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis
-tensorflow/contrib/lookup/* @ysuematsu @andreasst
-tensorflow/contrib/losses/* @alextp @ispirmustafa
-tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg
-tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa
-tensorflow/contrib/nccl/* @cwhipkey @zheng-xq
-tensorflow/contrib/opt/* @strategist333
-tensorflow/contrib/pi_examples/* @maciekcc
-tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman
-tensorflow/contrib/rnn/* @ebrevdo
-tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh
-tensorflow/contrib/seq2seq/* @lukaszkaiser
-tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh
-tensorflow/contrib/slim/* @sguada @thenbasilmanran
-tensorflow/contrib/stateless/* @girving
-tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst
-tensorflow/contrib/testing/* @dandelionmane
-tensorflow/contrib/timeseries/* @allenlavoie
-tensorflow/contrib/tpu/* @frankchn @saeta @jhseu
-tensorflow/contrib/training/* @joel-shor @ebrevdo
-tensorflow/contrib/util/* @sherrym
+#tensorflow/contrib/hvx/* @satok16
+#tensorflow/contrib/imperative/* @keveman
+#tensorflow/contrib/integrate/* @shoyer
+#tensorflow/contrib/kernel_methods/* @petrosmol
+#tensorflow/contrib/ios_examples/* @petewarden
+#tensorflow/contrib/labeled_tensor/* @shoyer
+#tensorflow/contrib/layers/* @fchollet @martinwicke
+#tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp
+#tensorflow/contrib/linalg/* @langmore
+#tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis
+#tensorflow/contrib/lookup/* @ysuematsu @andreasst
+#tensorflow/contrib/losses/* @alextp @ispirmustafa
+#tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg
+#tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa
+#tensorflow/contrib/nccl/* @cwhipkey @zheng-xq
+#tensorflow/contrib/opt/* @strategist333
+#tensorflow/contrib/pi_examples/* @maciekcc
+#tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman
+#tensorflow/contrib/rnn/* @ebrevdo
+#tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh
+#tensorflow/contrib/seq2seq/* @lukaszkaiser
+#tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh
+#tensorflow/contrib/slim/* @sguada @thenbasilmanran
+#tensorflow/contrib/stateless/* @girving
+#tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst
+#tensorflow/contrib/testing/* @dandelionmane
+#tensorflow/contrib/timeseries/* @allenlavoie
+#tensorflow/contrib/tpu/* @frankchn @saeta @jhseu
+#tensorflow/contrib/training/* @joel-shor @ebrevdo
+#tensorflow/contrib/util/* @sherrym
README.md: 12 changes

@@ -30,16 +30,16 @@ tracking requests and bugs. So please see
 and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
 
 ## Installation
-*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.*
+*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
 
 People who are a little more adventurous can also try our nightly binaries:
 
 
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp35-cp35m-win_amd64.whl is not this line; see below) 
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
 * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
 ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
 
RELEASE.md

@@ -9,6 +9,7 @@
   * `DNNLinearCombinedClassifier`
   * `DNNLinearCombinedRegressor`.
 * All our prebuilt binaries have been built with cuDNN 6.
 * `import tensorflow` now goes much faster.
+* Adds a file cache to the GCS filesystem with configurable max staleness for file contents. This permits caching of file contents across close/open boundaries.
 * Added an axis parameter to `tf.gather`.
 * Added a `constant_values` keyword argument to `tf.pad`.
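A quick sketch of the two new keyword arguments called out in the notes above (shapes chosen arbitrarily for illustration):

```
import tensorflow as tf

params = tf.reshape(tf.range(12), [3, 4])
# New: gather along a chosen axis rather than only the first one.
cols = tf.gather(params, indices=[0, 2], axis=1)              # shape [3, 2]
# New: pad with a constant other than zero.
padded = tf.pad(params, paddings=[[1, 1], [0, 0]], constant_values=-1)
```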
@@ -31,6 +32,7 @@
 * GPU kernels and speed improvements for unary `tf.where` and `tf.nn.top_k`.
 * Monotonic Attention wrappers added to `tf.contrib.seq2seq`.
 * Added `tf.contrib.signal`, a library for signal processing primitives.
+* Added `tf.contrib.resampler`, containing CPU and GPU ops for differentiable resampling of images.
 
 ## Breaking Changes to the API
 * `tf.RewriterConfig` was removed from the Python API after being available in 1.2 release candidates (it was never in an actual release). Graph rewriting is still available, just not as `tf.RewriterConfig`. Instead add an explicit import.
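For reference, the unary `tf.where` and `tf.nn.top_k` forms that the GPU-kernel note above refers to:

```
import tensorflow as tf

x = tf.constant([3.0, 1.0, 4.0, 1.0, 5.0])
idx = tf.where(x > 2.0)                 # unary form: indices where True
values, indices = tf.nn.top_k(x, k=2)   # two largest values and positions
```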
@@ -64,7 +66,7 @@
 * Exported model signatures using the 'predict' method will no longer have their input and output keys silently ignored and rewritten to 'inputs' and 'outputs'. If a model was exported with different names before 1.2, and is now served with tensorflow/serving, it will accept requests using 'inputs' and 'outputs'. Starting at 1.2, such a model will accept the keys specified during export. Therefore, inference requests using 'inputs' and 'outputs' may start to fail. To fix this, either update any inference clients to send requests with the actual input and output keys used by the trainer code, or conversely, update the trainer code to name the input and output Tensors 'inputs' and 'outputs', respectively. Signatures using the 'classify' and 'regress' methods are not affected by this change; they will continue to standardize their input and output keys as before.
 * Add in-memory caching to the Dataset API.
 * Set default end_of_sequence variable in datasets iterators to false.
-* [Performance] Increase performance of `tf.layers.con2d` when setting use_bias=True by 2x by using nn.bias_add.
+* [Performance] Increase performance of `tf.layers.conv2d` when setting use_bias=True by 2x by using nn.bias_add.
 * Update iOS examples to use CocoaPods, and moved to tensorflow/examples/ios.
 * Adds a family= attribute in `tf.summary` ops to allow controlling the tab name used in Tensorboard for organizing summaries.
 * When GPU is configured, do not require --config=cuda, instead, automatically build for GPU if this is requested in the configure script.
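Two of the items above in sketch form; the `tf.contrib.data.Dataset` spelling is an assumption about the 1.3-era location of the Dataset API:

```
import tensorflow as tf

# In-memory caching for the Dataset API: compute once, replay from cache.
dataset = tf.contrib.data.Dataset.range(100).map(lambda v: v * 2).cache()

# family= groups summaries under a chosen Tensorboard tab name.
loss = tf.constant(0.5)
tf.summary.scalar('loss', loss, family='training')
```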
configure.py: 10 changes

@@ -384,12 +384,16 @@ def set_action_env_var(environ_cp,
 def convert_version_to_int(version):
   """Convert a version number to a integer that can be used to compare.
 
+  Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The
+  'xxxxx' part, for instance 'homebrew' on OS/X, is ignored.
+
   Args:
-    version: a version to be covnerted
+    version: a version to be converted
 
   Returns:
     An integer if converted successfully, otherwise return None.
   """
+  version = version.split('-')[0]
   version_segments = version.split('.')
   for seg in version_segments:
     if not seg.isdigit():

@@ -428,6 +432,8 @@ def check_bazel_version(min_version):
     print('Make sure you are running at least bazel %s' % min_version)
     return curr_version
 
+  print("You have bazel %s installed." % curr_version)
+
   if curr_version_int < min_version_int:
     print('Please upgrade your bazel installation to version %s or higher to '
           'build TensorFlow!' % min_version)

@@ -938,6 +944,8 @@ def main():
                 'with_hdfs_support', False)
   set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                 False)
+  set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
+                False)
   set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
                 False)
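A runnable sketch of the conversion that the `convert_version_to_int` docstring above describes; the fixed-width zero-padding scheme here is an assumption, chosen so multi-digit segments compare correctly:

```
def convert_version_to_int(version):
    """Convert 'X.Y.Z' or 'X.Y.Z-xxxxx' to a comparable int, else None."""
    version = version.split('-')[0]            # drop e.g. '-homebrew'
    segments = version.split('.')
    for seg in segments:
        if not seg.isdigit():
            return None
    # Zero-pad each segment so '0.5.4' < '0.10.0' holds as integers.
    return int(''.join(seg.zfill(6) for seg in segments))

assert convert_version_to_int('0.5.4') < convert_version_to_int('0.10.0')
assert convert_version_to_int('0.5.4-homebrew') == convert_version_to_int('0.5.4')
assert convert_version_to_int('not.a.version') is None
```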
tensorflow/BUILD

@@ -182,6 +182,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "with_gdr_support",
+    values = {"define": "with_gdr_support=true"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "with_verbs_support",
     values = {"define": "with_verbs_support=true"},
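These `config_setting`s key off Bazel `--define` flags (per the `values` attribute, a build opts in with `--define=with_gdr_support=true`). A minimal sketch of how a rule could consume the new setting; the target and dep names are illustrative assumptions, not from the diff:

```
cc_library(
    name = "transport",                                  # hypothetical target
    srcs = ["transport.cc"],
    deps = select({
        "//tensorflow:with_gdr_support": [":gdr_impl"],  # assumed label path
        "//conditions:default": [],
    }),
)
```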
tensorflow/c/c_api.cc

@@ -146,7 +146,7 @@ class TF_ManagedBuffer : public TensorBuffer {
 void* allocate_tensor(const char* operation, size_t len) {
   void* data =
       tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
-  if (tensorflow::LogMemory::IsEnabled()) {
+  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
     tensorflow::LogMemory::RecordRawAllocation(
         operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
         len, data, tensorflow::cpu_allocator());

@@ -155,7 +155,7 @@ void* allocate_tensor(const char* operation, size_t len) {
 }
 
 void deallocate_buffer(void* data, size_t len, void* arg) {
-  if (tensorflow::LogMemory::IsEnabled()) {
+  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
     tensorflow::LogMemory::RecordRawDeallocation(
         "TensorFlow C Api",
         tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
tensorflow/cc/tutorials/example_trainer.cc

@@ -101,7 +101,7 @@ void ConcurrentSteps(const Options* opts, int session_index) {
   std::unique_ptr<Session> session(NewSession(options));
   GraphDef def = CreateGraphDef();
   if (options.target.empty()) {
-    graph::SetDefaultDevice(opts->use_gpu ? "/gpu:0" : "/cpu:0", &def);
+    graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
   }
 
   TF_CHECK_OK(session->Create(def));
@@ -222,7 +222,7 @@ class MatcherBase {
   TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
 };
 
-// WhileConditionComputationMatcher attempst to match a target computation
+// WhileConditionComputationMatcher attempts to match a target computation
 // pattern in the while condition sub-computation.
 // If the target pattern is matched, two pieces of information are extracted
 // from 'tagged' instructions returned by the matcher:
tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -626,7 +626,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
   CHECK_EQ(opcode_, HloOpcode::kFusion);
   CHECK(instruction_to_fuse->IsFusable());
   if (GetModule()) {
-    XLA_VLOG_LINES(1, GetModule()->ToString());
+    XLA_VLOG_LINES(3, GetModule()->ToString());
   }
   HloInstruction* clone = nullptr;
   if (called_computations_.empty()) {

@@ -1909,9 +1909,10 @@ bool HloInstruction::IsFusable() const {
     case HloOpcode::kRecv:
       return false;
     // Only fuse Rng if it is used once, otherwise the random numbers generated
-    // will be different in each fusion.
+    // will be different in each fusion. If it is the root (user count = 0)
+    // then it is the equivalent of having one user.
     case HloOpcode::kRng:
-      return users_.size() == 1;
+      return users_.size() <= 1;
     default:
       return true;
   }
tensorflow/compiler/xla/service/hlo_instruction_test.cc

@@ -1077,6 +1077,48 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
                               root2->operand(1)->operand(0)->shape()));
 }
 
+TEST_F(HloInstructionTest, IsRandomFusable) {
+  auto shape = ShapeUtil::MakeShape(F32, {2, 2});
+  {
+    auto builder = HloComputation::Builder(TestName());
+    auto hlo_module = CreateNewModule();
+    auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+        Literal::CreateR0<float>(0.0)));
+    auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+        Literal::CreateR0<float>(1.0)));
+    auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+        shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
+
+    auto* computation = hlo_module->AddEntryComputation(builder.Build());
+    computation->CreateFusionInstruction({rng, const0, const1},
+                                         HloInstruction::FusionKind::kLoop);
+
+    auto* root = computation->root_instruction();
+
+    EXPECT_EQ(HloOpcode::kFusion, root->opcode());
+  }
+  {
+    auto builder = HloComputation::Builder(TestName());
+    auto hlo_module = CreateNewModule();
+    auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+        Literal::CreateR0<float>(0.0)));
+    auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+        Literal::CreateR0<float>(1.0)));
+    auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+        shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
+    builder.AddInstruction(HloInstruction::CreateUnary(
+        shape, HloOpcode::kNegate, rng));
+    auto* computation = hlo_module->AddEntryComputation(builder.Build());
+    computation->CreateFusionInstruction({rng, const0, const1},
+                                         HloInstruction::FusionKind::kLoop);
+
+    auto* root = computation->root_instruction();
+
+    EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode());
+  }
+}
+
+
 TEST_F(HloInstructionTest, CloneSuffixNames) {
   // Test that the suffix string added to cloned instructions is not
   // duplicated. Rather a numeric incrementing value should be appended. That
tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc

@@ -57,7 +57,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
   auto result = builder.Neg(a);

@@ -66,7 +66,7 @@ TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
                                       std::numeric_limits<int32>::min(),

@@ -126,7 +126,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
                             {});
 }
 
-TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
   auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});

@@ -185,7 +185,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
   auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});

@@ -204,7 +204,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
   auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});

@@ -222,7 +222,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
   ComputeAndCompareR1<int32>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
   auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});

@@ -241,7 +241,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, DivS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
   // clang-format off
   // Some interesting values to test.
   std::vector<int32> vals = {

@@ -316,7 +316,7 @@ TEST_F(ArrayElementwiseOpTest, DivS32s) {
   }
 }
 
-TEST_F(ArrayElementwiseOpTest, DivU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
   // clang-format off
   // Some interesting values to test.
   std::vector<uint32> vals = {

@@ -420,7 +420,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
   auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});

@@ -439,7 +439,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
   std::vector<int32> data = {0,
                              1,
                              -1,

@@ -474,7 +474,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
   ComputeAndCompareR1<int32>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
   std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
                               0x1a243514, 0xFFFFFFFF, 0x80808080};
 

@@ -496,7 +496,7 @@ TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
   ComputeAndCompareR1<uint32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<bool>({false, false, true, true});
   auto b = builder.ConstantR1<bool>({false, true, false, true});

@@ -514,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, LogicalOr) {
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalOr) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<bool>({false, false, true, true});
   auto b = builder.ConstantR1<bool>({false, true, false, true});

@@ -532,7 +532,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, LogicalNot) {
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalNot) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<bool>({false, true, true, false});
   auto out = builder.LogicalNot(a);

@@ -548,7 +548,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});

@@ -567,7 +567,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});

@@ -577,7 +577,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
   ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});

@@ -587,7 +587,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
   ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});

@@ -597,7 +597,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
   ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});

@@ -607,7 +607,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
   ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -629,7 +629,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
   // Disable fast-math because we're operating on NaNs.
   SetFastMathDisabled(true);
 

@@ -641,7 +641,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
   ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -653,7 +653,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
       &builder, {false, true, true, true, false, true, true, true, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -665,7 +665,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
      &builder, {true, false, false, true, true, false, true, true, true}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -678,7 +678,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
       {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -690,7 +690,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
      &builder, {true, true, true, false, true, true, false, false, true}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -703,7 +703,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
      {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -715,7 +715,7 @@ TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
      {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -726,7 +726,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
      &builder, {false, true, true, true, false, true, true, true, false}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -737,7 +737,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
      &builder, {true, false, false, true, true, false, true, true, true}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -749,7 +749,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
      {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -760,7 +760,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
      &builder, {true, true, true, false, true, true, false, false, true}, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});

@@ -772,7 +772,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
      {});
 }
 
-TEST_F(ArrayElementwiseOpTest, PowF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
   auto lhs =

@@ -795,7 +795,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
 }
 
 // Some Pow cases that can be implemented more efficiently.
-TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};

@@ -823,7 +823,7 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
   ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};

@@ -848,7 +848,7 @@ TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};

@@ -873,7 +873,7 @@ TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};

@@ -898,7 +898,7 @@ TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};

@@ -923,7 +923,7 @@ TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
                              error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
+XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};

@@ -955,7 +955,7 @@ TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
       &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
+XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};

@@ -988,7 +988,7 @@ TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
+XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};

@@ -1021,7 +1021,7 @@ TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, Div4F32) {
+XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
   ComputationBuilder b(client_, TestName());
 
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};

@@ -1081,7 +1081,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
+XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
   ComputationBuilder builder(client_, TestName());
   Array4D<float> values(2, 2, 2, 2);
 

@@ -1120,7 +1120,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
 //
 // TODO(b/28180546): Make this compile in a way that is consistent
 // among backends.
-TEST_F(ArrayElementwiseOpTest, MinF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
   ComputationBuilder builder(client_, TestName());
 #if !defined(XLA_TEST_BACKEND_CPU)
   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});

@@ -1174,7 +1174,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
 
 // TODO(b/28180546): Make this compile in a way that is consistent
 // among backends. See comment on MinF32s test above.
-TEST_F(ArrayElementwiseOpTest, MaxF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
   ComputationBuilder builder(client_, TestName());
 #if !defined(XLA_TEST_BACKEND_CPU)
   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});

@@ -1226,7 +1226,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
                              {}, error_spec_);
 }
 
-TEST_F(ArrayElementwiseOpTest, MaxS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -1241,7 +1241,7 @@ TEST_F(ArrayElementwiseOpTest, MaxS32s) {
   ComputeAndCompareR1<int32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, MinS32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
   ComputationBuilder builder(client_, TestName());

@@ -1256,7 +1256,7 @@ TEST_F(ArrayElementwiseOpTest, MinS32s) {
   ComputeAndCompareR1<int32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, MaxU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});

@@ -1267,7 +1267,7 @@ TEST_F(ArrayElementwiseOpTest, MaxU32s) {
   ComputeAndCompareR1<uint32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, MinU32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
   const uint32 max = std::numeric_limits<uint32>::max();
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});

@@ -1278,7 +1278,7 @@ TEST_F(ArrayElementwiseOpTest, MinU32s) {
   ComputeAndCompareR1<uint32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});

@@ -1311,7 +1311,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
   }
 }
 
-TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
   ComputationBuilder builder(client_, TestName());
   auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
   auto m =

@@ -1354,7 +1354,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
   ComputeAndCompareR3<int32>(&builder, expected, {});
 }
 
-TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
   ComputationBuilder builder(client_, TestName());
   auto m =
       builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});

@@ -1431,7 +1431,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
|
|||
ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
|
||||
auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
|
||||
|
|
@ -1442,7 +1442,7 @@ TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto minimum = builder.ConstantR0<float>(0.0f);
|
||||
auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
|
||||
|
|
@ -1453,7 +1453,7 @@ TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto min_scalar = builder.ConstantR0<float>(0.0f);
|
||||
auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
|
||||
|
|
@ -1472,7 +1472,7 @@ TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
|
|
@ -1516,7 +1516,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
|
|||
&builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
std::unique_ptr<Literal> param0_literal =
|
||||
|
|
@ -1550,7 +1550,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, TanhF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
|
||||
auto result = builder.Tanh(a);
|
||||
|
|
@ -1559,7 +1559,7 @@ TEST_F(ArrayElementwiseOpTest, TanhF32s) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
|
||||
// This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
|
||||
// the input tensor is large enough to exercise the vectorized tanh
|
||||
// implementation.
|
||||
|
|
@ -1603,7 +1603,7 @@ TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
|
|||
ErrorSpec(0.004, 0.004));
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
|
||||
// a ------ (add) --------- (add)
|
||||
// / /
|
||||
// b -----/ /
|
||||
|
|
@ -1621,7 +1621,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
|
||||
// b ------ (add) --------- (add)
|
||||
// / /
|
||||
// c -----/ /
|
||||
|
|
@ -1639,7 +1639,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
|
||||
// a ----- (neg) ----- (add)
|
||||
// /
|
||||
// b ----- (neg) ----/
|
||||
|
|
@ -1656,7 +1656,7 @@ TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
|
||||
// a ------ (add) ------------\
|
||||
// / \
|
||||
// b -----/ (add)
|
||||
|
|
@ -1679,7 +1679,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
|
|||
error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto a =
|
||||
builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
|
||||
|
|
@ -1704,7 +1704,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
|
|||
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
|
||||
// Add a matrix + scalar.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto a =
|
||||
|
|
@ -1820,7 +1820,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
|
|||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
|
||||
// Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
|
||||
// arguments is reversed.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
|
@ -1831,7 +1831,7 @@ TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
|
|||
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
|
||||
// Tests broadcasting for arrays with degenerate (size == 1) dimensions.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
// m's shape in XLA notation is {3, 2}
|
||||
|
|
@ -1891,7 +1891,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
|
|||
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
|
||||
// Add together a (2,2) array and a (2) array, using dimension 1 for
|
||||
// broadcasting (though there are two ways to broadcast these shapes).
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
|
@ -1902,7 +1902,7 @@ TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
|
|||
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
|
||||
// Binary add of two R3s together
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
|
||||
|
|
@ -2033,7 +2033,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
|
|||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
|
||||
|
|
@ -2060,7 +2060,7 @@ TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
|
|||
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
|
||||
|
|
@ -2088,7 +2088,7 @@ TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
|
|||
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
|
||||
constexpr int d0 = 16;
|
||||
constexpr int d1 = 16;
|
||||
constexpr int d2 = 2;
|
||||
|
|
@ -2119,7 +2119,7 @@ TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
|
|||
}
|
||||
|
||||
// Show that we can't add two opaques.
|
||||
TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto shape = ShapeUtil::MakeOpaqueShape();
|
||||
auto x = builder.Parameter(0, shape, "x");
|
||||
|
|
@ -2133,7 +2133,7 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
|
|||
|
||||
// Regression test for b/31927799. "slice - y" is fused and requires implicit
|
||||
// broadcast.
|
||||
TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto x_literal = Literal::CreateR1<float>({1, 2, 3});
|
||||
auto y_literal = Literal::CreateR1<float>({4, 5});
|
||||
|
|
|
|||
|
|
@@ -31,6 +31,7 @@ def xla_test(name,
              args=[],
              tags=[],
              copts=[],
+             data=[],
              backend_tags={},
              backend_args={},
              **kwargs):
@@ -114,6 +115,7 @@ def xla_test(name,
     this_backend_tags = ["xla_%s" % backend]
     this_backend_copts = []
     this_backend_args = backend_args.get(backend, [])
+    this_backend_data = []
     if backend == "cpu":
       backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
       backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
@@ -130,6 +132,7 @@ def xla_test(name,
       this_backend_copts += plugins[backend]["copts"]
      this_backend_tags += plugins[backend]["tags"]
       this_backend_args += plugins[backend]["args"]
+      this_backend_data += plugins[backend]["data"]
     else:
       fail("Unknown backend %s" % backend)

@@ -145,6 +148,7 @@ def xla_test(name,
         this_backend_copts,
         args=args + this_backend_args,
         deps=deps + backend_deps,
+        data=data + this_backend_data,
         **kwargs)

     test_names.append(test_name)
@@ -227,14 +231,18 @@ def generate_backend_test_macros(backends=[]):
   if not backends:
     backends = all_backends
   for backend in filter_backends(backends):
+    manifest = ""
+    if backend in plugins:
+      manifest = plugins[backend]["disabled_manifest"]
+
     native.cc_library(
         name="test_macros_%s" % backend,
+        testonly = True,
         srcs = ["test_macros.cc"],
         hdrs = ["test_macros.h"],
         copts = [
-            "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
-            "-DXLA_DISABLED_MANIFEST=\\\"\\\""
+            "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+            "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
         ],
         deps = [
             "//tensorflow/compiler/xla:types",

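For illustration only, a test target consuming the new per-backend plumbing above might be declared roughly like the following sketch; the target name, source file, tag, and flag are placeholders, not targets from this change:

```python
# Hypothetical usage sketch for the xla_test macro above; names are made up.
xla_test(
    name = "sample_ops_test",
    srcs = ["sample_ops_test.cc"],
    # Common data, merged with each backend's plugin-provided data
    # via data = data + this_backend_data in the macro.
    data = ["//tensorflow/compiler/xla/tests:sample_testdata"],
    backend_tags = {"gpu": ["manual"]},
    backend_args = {"cpu": ["--sample_flag=true"]},
)
```
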
@@ -22,9 +22,13 @@
 #        "//tensorflow/compiler/plugin/foo:foo_lib",
 #        "//tensorflow/compiler/plugin/foo:test_macros",
 #    ],
+#    "disabled_manifest": "tensorflow/compiler/plugin/foo/disabled_test_manifest.txt",
 #    "copts": [],
 #    "tags": [],
 #    "args": []
+#    "data": [
+#        "//tensorflow/compiler/plugin/foo:disabled_test_manifest.txt",
+#    ],
 #  },
 # }

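Assembled into real (uncommented) code, the template above corresponds to a hypothetical `plugins` entry like the sketch below. The `foo` backend and its paths come straight from the template and are placeholders, not a real plugin:

```python
# Hypothetical plugins entry, filled in from the commented template above.
plugins = {
    "foo": {
        "deps": [
            "//tensorflow/compiler/plugin/foo:foo_lib",
            "//tensorflow/compiler/plugin/foo:test_macros",
        ],
        # Consumed by generate_backend_test_macros for XLA_DISABLED_MANIFEST.
        "disabled_manifest": "tensorflow/compiler/plugin/foo/disabled_test_manifest.txt",
        "copts": [],
        "tags": [],
        "args": [],
        # Merged into each test's runtime data by xla_test.
        "data": [
            "//tensorflow/compiler/plugin/foo:disabled_test_manifest.txt",
        ],
    },
}
```
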
@@ -69,35 +69,35 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
   }
 };

-TEST_F(ScalarComputationsTest, NegateScalarF32) {
+XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Neg(builder.ConstantR0<float>(2.1f));

   ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, NegateScalarS32) {
+XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
   ComputationBuilder builder(client_, TestName());
   builder.Neg(builder.ConstantR0<int32>(2));

   ComputeAndCompareR0<int32>(&builder, -2, {});
 }

-TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));

   ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
   ComputationBuilder builder(client_, TestName());
   builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));

   ComputeAndCompareR0<int32>(&builder, 7, {});
 }

-TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
   ComputationBuilder builder(client_, TestName());
   builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));

@@ -137,21 +137,21 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
   ComputeAndCompareR0<double>(&builder, 3.75, {});
 }

-TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
+XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));

   ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
+XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
   ComputationBuilder builder(client_, TestName());
   builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));

   ComputeAndCompareR0<int32>(&builder, -3, {});
 }

-TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
+XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
                           builder.ConstantR0<float>(5.5f)),
@@ -160,7 +160,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
   ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
+XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
   std::vector<int32> data = {0,
                              1,
                              -1,
@@ -184,7 +184,7 @@ TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
   }
 }

-TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
+XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
   std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
                               0x1a243514, 0xFFFFFFFF, 0x80808080};

@@ -199,7 +199,7 @@ TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
   }
 }

-TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
+XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
   ComputationBuilder builder(client_, TestName());
   builder.Mul(
       builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
@@ -208,7 +208,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
   ComputeAndCompareR0<int32>(&builder, 10, {});
 }

-TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
+XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
   ComputationBuilder builder(client_, TestName());
   std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f);
   std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f);
@@ -231,7 +231,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
       error_spec_);
 }

-TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));

@@ -337,7 +337,7 @@ INSTANTIATE_TEST_CASE_P(
         DivS32Params{INT32_MIN, -0x40000000, 2, 0},  //
         DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));

-TEST_F(ScalarComputationsTest, DivU32s) {
+XLA_TEST_F(ScalarComputationsTest, DivU32s) {
   // clang-format off
   // Some interesting values to test.
   std::vector<uint32> vals = {
@@ -378,7 +378,7 @@ TEST_F(ScalarComputationsTest, DivU32s) {
   }
 }

-TEST_F(ScalarComputationsTest, RemU32s) {
+XLA_TEST_F(ScalarComputationsTest, RemU32s) {
   // clang-format off
   // Some interesting values to test.
   std::vector<uint32> vals = {
@@ -419,7 +419,7 @@ TEST_F(ScalarComputationsTest, RemU32s) {
   }
 }

-TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
+XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
   builder.Rem(x, builder.ConstantR0<int32>(80000));
@@ -446,7 +446,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
   ComputeAndCompareR0<uint32>(&builder, 2, {});
 }

-TEST_F(ScalarComputationsTest, LogicalAnd) {
+XLA_TEST_F(ScalarComputationsTest, LogicalAnd) {
   for (bool x : {false, true}) {
     for (bool y : {false, true}) {
       ComputationBuilder builder(client_, TestName());
@@ -458,7 +458,7 @@ TEST_F(ScalarComputationsTest, LogicalAnd) {
   }
 }

-TEST_F(ScalarComputationsTest, LogicalOr) {
+XLA_TEST_F(ScalarComputationsTest, LogicalOr) {
   for (bool x : {false, true}) {
     for (bool y : {false, true}) {
       ComputationBuilder builder(client_, TestName());
@@ -470,7 +470,7 @@ TEST_F(ScalarComputationsTest, LogicalOr) {
   }
 }

-TEST_F(ScalarComputationsTest, LogicalNot) {
+XLA_TEST_F(ScalarComputationsTest, LogicalNot) {
   for (bool x : {false, true}) {
     ComputationBuilder builder(client_, TestName());
     builder.LogicalNot(builder.ConstantR0<bool>(x));
@@ -479,7 +479,7 @@ TEST_F(ScalarComputationsTest, LogicalNot) {
   }
 }

-TEST_F(ScalarComputationsTest, SelectScalarTrue) {
+XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
   ComputationBuilder builder(client_, TestName());
   builder.Select(builder.ConstantR0<bool>(true),  // The predicate.
                  builder.ConstantR0<float>(123.0f),  // The value on true.
@@ -488,7 +488,7 @@ TEST_F(ScalarComputationsTest, SelectScalarTrue) {
   ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, SelectScalarFalse) {
+XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
   ComputationBuilder builder(client_, TestName());
   builder.Select(builder.ConstantR0<bool>(false),  // The predicate.
                  builder.ConstantR0<float>(123.0f),  // The value on true.
@@ -499,7 +499,7 @@ TEST_F(ScalarComputationsTest, SelectScalarFalse) {

 // This test is an explicit version of what is happening in the following
 // templatized comparison tests.
-TEST_F(ScalarComputationsTest, CompareGtScalar) {
+XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
   ComputationBuilder builder(client_, TestName());
   builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));

@@ -507,30 +507,30 @@ TEST_F(ScalarComputationsTest, CompareGtScalar) {
 }

 // S32 comparisons.
-TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
+XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
   TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq);
 }
-TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
+XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
   TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq);
 }

-TEST_F(ScalarComputationsTest, CompareNeS32) {
+XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
   TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne);
 }

-TEST_F(ScalarComputationsTest, CompareGeS32) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
   TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge);
 }

-TEST_F(ScalarComputationsTest, CompareGtS32) {
+XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
   TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt);
 }

-TEST_F(ScalarComputationsTest, CompareLeS32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
   TestCompare<int32>(2, 1, false, &ComputationBuilder::Le);
 }

-TEST_F(ScalarComputationsTest, CompareLtS32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
   TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt);
   TestCompare<int32>(std::numeric_limits<int32>::min(),
                      std::numeric_limits<int32>::max(), true,
@@ -538,105 +538,105 @@ TEST_F(ScalarComputationsTest, CompareLtS32) {
 }

 // U32 comparisons.
-TEST_F(ScalarComputationsTest, CompareEqU32False) {
+XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
   TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq);
 }

-TEST_F(ScalarComputationsTest, CompareNeU32) {
+XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
   TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne);
 }

-TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
   TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge);
 }

-TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
   TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge);
 }

-TEST_F(ScalarComputationsTest, CompareGtU32) {
+XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
   TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt);
   TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt);
   TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt);
 }

-TEST_F(ScalarComputationsTest, CompareLeU32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
   TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le);
 }

-TEST_F(ScalarComputationsTest, CompareLtU32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
   TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt);
   TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
                       &ComputationBuilder::Lt);
 }

 // F32 comparisons.
-TEST_F(ScalarComputationsTest, CompareEqF32False) {
+XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
   TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq);
 }

-TEST_F(ScalarComputationsTest, CompareNeF32) {
+XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
   TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne);
 }

-TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
   TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge);
 }
-TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
   TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge);
 }

-TEST_F(ScalarComputationsTest, CompareGtF32) {
+XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
   TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt);
 }

-TEST_F(ScalarComputationsTest, CompareLeF32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
   TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le);
 }

-TEST_F(ScalarComputationsTest, CompareLtF32) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
   TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt);
 }

 // F32 comparisons with exceptional values. The test names encode the
 // left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
-TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
   TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt);
 }
-TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
   // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
   TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt);
 }
-TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
+XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
   TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt);
 }

-TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
   TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge);
 }
-TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
   // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
   TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
 }
-TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
+XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
   TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
 }

-TEST_F(ScalarComputationsTest, ExpScalar) {
+XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
   ComputationBuilder builder(client_, TestName());
   builder.Exp(builder.ConstantR0<float>(2.0f));

   ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, LogScalar) {
+XLA_TEST_F(ScalarComputationsTest, LogScalar) {
   ComputationBuilder builder(client_, "log");
   builder.Log(builder.ConstantR0<float>(2.0f));

   ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, TanhScalar) {
+XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
   ComputationBuilder builder(client_, TestName());
   builder.Tanh(builder.ConstantR0<float>(2.0f));

@@ -650,14 +650,14 @@ XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
   ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, PowScalar) {
+XLA_TEST_F(ScalarComputationsTest, PowScalar) {
   ComputationBuilder builder(client_, TestName());
   builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));

   ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, ClampScalarHigh) {
+XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) {
   ComputationBuilder builder(client_, TestName());
   builder.Clamp(builder.ConstantR0<float>(2.0f),  // The lower bound.
                 builder.ConstantR0<float>(5.0f),  // The operand to be clamped.
@@ -666,7 +666,7 @@ TEST_F(ScalarComputationsTest, ClampScalarHigh) {
   ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
+XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
   ComputationBuilder builder(client_, TestName());
   builder.Clamp(builder.ConstantR0<float>(2.0f),  // The lower bound.
                 builder.ConstantR0<float>(2.5f),  // The operand to be clamped.
@@ -675,7 +675,7 @@ TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
   ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, ClampScalarLow) {
+XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) {
   ComputationBuilder builder(client_, TestName());
   builder.Clamp(builder.ConstantR0<float>(2.0f),  // The lower bound.
                 builder.ConstantR0<float>(-5.0f),  // The operand to be clamped.
@@ -684,57 +684,57 @@ TEST_F(ScalarComputationsTest, ClampScalarLow) {
   ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, MinS32Above) {
+XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
   TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MinS32Below) {
+XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
   TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MaxS32Above) {
+XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
   TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, MaxS32Below) {
+XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
   TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, MinU32Above) {
+XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
   const uint32 large = std::numeric_limits<int32>::max();
   TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MinU32Below) {
+XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
   TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MaxU32Above) {
+XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
   const uint32 large = std::numeric_limits<int32>::max();
   TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, MaxU32Below) {
+XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
   TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, MinF32Above) {
+XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
   TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MinF32Below) {
+XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
   TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min);
 }

-TEST_F(ScalarComputationsTest, MaxF32Above) {
+XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
   TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, MaxF32Below) {
+XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
   TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max);
 }

-TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
+XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
   // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
   ComputationBuilder b(client_, TestName());
   b.Div(
@@ -747,7 +747,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
   ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
 }

-TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
+XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
   // Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
   ComputationBuilder b(client_, TestName());
   b.Sub(b.Mul(b.ConstantR0<int32>(1),
@@ -758,7 +758,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
   ComputeAndCompareR0<int32>(&b, 10, {});
 }

-TEST_F(ScalarComputationsTest, SqrtF320) {
+XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
   ComputationBuilder builder(client_, TestName());
   Literal zero_literal = Literal::Zero(PrimitiveType::F32);

@@ -85,12 +85,12 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
   AbsSize0TestHelper<float>();
 }

-TEST_F(UnaryOpTest, AbsTestR1) {
+XLA_TEST_F(UnaryOpTest, AbsTestR1) {
   AbsTestHelper<int>();
   AbsTestHelper<float>();
 }

-TEST_F(UnaryOpTest, AbsTestR0) {
+XLA_TEST_F(UnaryOpTest, AbsTestR0) {
   ComputationBuilder builder(client_, TestName());
   auto argi = builder.ConstantR0<int>(-5);
   auto absi = builder.Abs(argi);
@@ -104,7 +104,7 @@ TEST_F(UnaryOpTest, AbsTestR0) {
   ComputeAndCompareR0<float>(&builder, 8.0f, {});
 }

-TEST_F(UnaryOpTest, SignTestR0) {
+XLA_TEST_F(UnaryOpTest, SignTestR0) {
   ComputationBuilder builder(client_, TestName());
   auto argi = builder.ConstantR0<int>(-5);
   auto absi = builder.Sign(argi);
@@ -118,17 +118,17 @@ TEST_F(UnaryOpTest, SignTestR0) {
   ComputeAndCompareR0<float>(&builder, -2.0f, {});
 }

-TEST_F(UnaryOpTest, SignTestR1) {
+XLA_TEST_F(UnaryOpTest, SignTestR1) {
   SignTestHelper<int>();
   SignTestHelper<float>();
 }

-TEST_F(UnaryOpTest, SignAbsTestR1) {
+XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
   SignAbsTestHelper<int>();
   SignAbsTestHelper<float>();
 }

-TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
+XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
   ComputationBuilder builder(client_, TestName());
   auto arg = builder.ConstantR1<unsigned int>(
       {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
@@ -138,7 +138,7 @@ TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
       &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
 }

-TEST_F(UnaryOpTest, UnsignedSignTestR1) {
+XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
   ComputationBuilder builder(client_, TestName());
   auto arg = builder.ConstantR1<unsigned int>(
       {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
@@ -147,7 +147,7 @@ TEST_F(UnaryOpTest, UnsignedSignTestR1) {
   ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
 }

-TEST_F(UnaryOpTest, SignAbsTestR2) {
+XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
   ComputationBuilder builder(client_, TestName());
   auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
   auto sign = builder.Sign(arg);

@@ -48,7 +48,7 @@ class VecOpsSimpleTest : public ClientLibraryTestBase {
   ErrorSpec error_spec_{0.0001};
 };

-TEST_F(VecOpsSimpleTest, ExpTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -61,7 +61,7 @@ TEST_F(VecOpsSimpleTest, ExpTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

-TEST_F(VecOpsSimpleTest, ExpManyValues) {
+XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
   for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
     ComputationBuilder builder(client_, TestName());
     std::vector<float> exponents;
@@ -83,7 +83,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
   }
 }

-TEST_F(VecOpsSimpleTest, ExpIn4D) {
+XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
   ComputationBuilder builder(client_, TestName());
   Array4D<float> exponents(2, 2, 2, 2);

@@ -105,7 +105,7 @@ TEST_F(VecOpsSimpleTest, ExpIn4D) {
       ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
 }

-TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
+XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -116,7 +116,7 @@ TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

-TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
+XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
   builder.Neg(x);
@@ -125,7 +125,7 @@ TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
   ComputeAndCompareR1<int32>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, NegateUint32Values) {
+XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<uint32>(
       {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
@@ -135,7 +135,7 @@ TEST_F(VecOpsSimpleTest, NegateUint32Values) {
   ComputeAndCompareR1<uint32>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, SquareTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -146,7 +146,7 @@ TEST_F(VecOpsSimpleTest, SquareTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

-TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -187,7 +187,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

-TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
+XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
   ComputationBuilder builder(client_, TestName());
   auto add = CreateScalarAddComputation(F32, &builder);

@@ -202,7 +202,7 @@ TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }

-TEST_F(VecOpsSimpleTest, MaxTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -215,7 +215,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
+XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
   // Similar to MaxTenValues, except that the inputs come from params rather
   // than constants.
   ComputationBuilder builder(client_, TestName());
@@ -233,7 +233,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
       error_spec_);
 }

-TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
+XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
   // Similar to MaxTenValuesFromParams, except that the data size passed in and
   // out is large.
   ComputationBuilder builder(client_, TestName());
@@ -273,7 +273,7 @@ TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
       error_spec_);
 }

-TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
+XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -285,7 +285,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, MinTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto x = builder.ConstantR1<float>(
       {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -298,7 +298,7 @@ TEST_F(VecOpsSimpleTest, MinTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
   ComputationBuilder builder(client_, TestName());
   auto zero = builder.ConstantR0<float>(0);
   auto one = builder.ConstantR0<float>(1);
@@ -311,7 +311,7 @@ TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
+XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
   ComputationBuilder builder(client_, TestName());
   auto zero = builder.ConstantR0<float>(0);
   auto one = builder.ConstantR0<float>(1);
@@ -324,7 +324,7 @@ TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
+XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
   ComputationBuilder builder(client_, TestName());
   auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
   auto one = builder.ConstantR1<float>({1.0f, 1.0f});
@@ -335,7 +335,7 @@ TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
+XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
   ComputationBuilder builder(client_, TestName());
   auto one = builder.ConstantR0<float>(1);
   auto two = builder.ConstantR0<float>(2);
@@ -348,7 +348,7 @@ TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
   ComputeAndCompareR1<float>(&builder, expected, {});
 }

-TEST_F(VecOpsSimpleTest, MapTenValues) {
+XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
   Computation add_half;
   {
     // add_half(x) = x + 0.5

@@ -67,7 +67,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros,

   So, for example, in the following code

-  ```
+  ```python
   @batch_function(1, 2, 3)
   def layer(a):
     return tf.matmul(a, a)

@@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
 option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
 option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
 option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
+option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF)
 option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
 option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
 option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")

@@ -241,6 +241,13 @@ Step-by-step Windows build
    ```
    ctest -C RelWithDebInfo
    ```
+   * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on
+     several major packages. This option is only valid if `tensorflow_BUILD_PYTHON_TESTS` is also set to `ON`.
+     After building the python wheel, you need to install the new wheel before running the tests.
+     To execute the tests, use
+     ```
+     ctest -C RelWithDebInfo
+     ```

 4. Invoke MSBuild to build TensorFlow.

@@ -76,7 +76,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
     #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"

@@ -156,6 +156,21 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py"
   )

+  if (tensorflow_BUILD_MORE_PYTHON_TESTS)
+    # Adding other major packages
+    file(GLOB_RECURSE tf_test_src_py
+      ${tf_test_src_py}
+      "${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py"
+    )
+  endif()
+
   # exclude the ones we don't want
   set(tf_test_src_py_exclude
     # Python source line inspection tests are flaky on Windows (b/36375074).
@@ -183,6 +198,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     # Loading resources in contrib doesn't seem to work on Windows
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py"
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py"
+    # dask need fix
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py"
     # Test is flaky on Windows GPU builds (b/38283730).
     "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
   )
@@ -215,11 +233,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
     # training tests
-    "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix.
-    "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # Needs tf.contrib fix.
     "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker.
-    "${tensorflow_source_dir}/tensorflow/python/training/monitored_session_test.py" # Needs tf.contrib fix.
     "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
     "${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # Overflow error.
     "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
     "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
     "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
@@ -233,6 +248,45 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support
     # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug.
     "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
+    # Dask.Dataframe bugs on Window Build
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/io_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/graph_actions_test.py"
+    # Need extra build
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py"
+    # Windows Path
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py"
+    # Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py"
+    # Scipy needed
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py"
+    # Failing with TF 1.3 (TODO)
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
   )
   endif()
   list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})

@@ -23,6 +23,7 @@ import itertools
 import numpy as np

 from tensorflow.contrib.crf.python.ops import crf
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
       self.assertEqual(actual_max_sequence,
                        expected_max_sequence[:sequence_lengths])

+  def testCrfDecode(self):
+    inputs = np.array(
+        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+    transition_params = np.array(
+        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+    sequence_lengths = np.array(3, dtype=np.int32)
+    num_words = inputs.shape[0]
+    num_tags = inputs.shape[1]
+
+    with self.test_session() as sess:
+      all_sequence_scores = []
+      all_sequences = []
+
+      # Compare the dynamic program with brute force computation.
+      for tag_indices in itertools.product(
+          range(num_tags), repeat=sequence_lengths):
+        tag_indices = list(tag_indices)
+        tag_indices.extend([0] * (num_words - sequence_lengths))
+        all_sequences.append(tag_indices)
+        sequence_score = crf.crf_sequence_score(
+            inputs=array_ops.expand_dims(inputs, 0),
+            tag_indices=array_ops.expand_dims(tag_indices, 0),
+            sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+            transition_params=constant_op.constant(transition_params))
+        sequence_score = array_ops.squeeze(sequence_score, [0])
+        all_sequence_scores.append(sequence_score)
+
+      tf_all_sequence_scores = sess.run(all_sequence_scores)
+
+      expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
+      expected_max_sequence = all_sequences[expected_max_sequence_index]
+      expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
+
+      actual_max_sequence, actual_max_score = crf.crf_decode(
+          array_ops.expand_dims(inputs, 0),
+          constant_op.constant(transition_params),
+          array_ops.expand_dims(sequence_lengths, 0))
+      actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
+      actual_max_score = array_ops.squeeze(actual_max_score, [0])
+      tf_actual_max_sequence, tf_actual_max_score = sess.run(
+          [actual_max_sequence, actual_max_score])
+
+      self.assertAllClose(tf_actual_max_score, expected_max_score)
+      self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
+                       expected_max_sequence[:sequence_lengths])
+

 if __name__ == "__main__":
   test.main()

@@ -16,13 +16,24 @@

The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
-likely sequence at test time:
+likely sequence at test time. There are two ways to do the decoding: one is
+to use crf_decode, which decodes within TensorFlow; the other is to use
+viterbi_decode, which decodes in NumPy.

log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    unary_scores, gold_tags, sequence_lengths)

loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# Decoding in TensorFlow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
    unary_scores, transition_params, sequence_lengths)

tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
    [viterbi_sequence, viterbi_score, train_op])

# Decoding in NumPy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
    [unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,

@@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
  tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]

  # Compute the highest score and its tag sequence.
-  viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
+  tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
      tf_unary_scores_, tf_transition_params)
"""

@@ -43,6 +54,7 @@ import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell

@@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs

__all__ = [
    "crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
-    "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode"
+    "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
+    "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
+    "CrfDecodeBackwardRnnCell"
]

@@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score


class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """Computes the forward decoding in a linear-chain CRF."""

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. This matrix is expanded into a
        [1, num_tags, num_tags] in preparation for the broadcast
        summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = transition_params.get_shape()[0].value

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
        score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: [batch_size, num_tags], containing backpointers.
      new_state: [batch_size, num_tags], containing new score values.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
    state = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension.
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    transition_scores = state + self._transition_params  # [B, O, O]
    new_state = inputs + math_ops.reduce_max(transition_scores, [1])  # [B, O]
    backpointers = math_ops.argmax(transition_scores, 1)
    backpointers = math_ops.cast(backpointers, dtype=dtypes.int32)  # [B, O]
    return backpointers, new_state


class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """Computes backward decoding in a linear-chain CRF."""

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: An integer, the number of tags.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: [batch_size, num_tags], backpointer of next step (in time order).
      state: [batch_size, 1], next position's tag index.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: A pair of [batch_size, 1] tensors containing the
        new tag indices.
    """
    state = array_ops.squeeze(state, axis=[1])  # [B]
    batch_size = array_ops.shape(inputs)[0]
    b_indices = math_ops.range(batch_size)  # [B]
    indices = array_ops.stack([b_indices, state], axis=1)  # [B, 2]
    new_tags = array_ops.expand_dims(
        gen_array_ops.gather_nd(inputs, indices),  # [B]
        axis=-1)  # [B, 1]

    return new_tags, new_tags


def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is the tensor-based counterpart of viterbi_decode, operating on
  tensors inside the TensorFlow graph rather than on NumPy arrays.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
      unary potentials.
    transition_params: A [num_tags, num_tags] tensor, matrix of
      binary potentials.
    sequence_length: A [batch_size] tensor, containing sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
      Contains the highest scoring tag indices.
    best_score: A [batch_size] tensor, containing the score of decode_tags.
  """
  # For simplicity, in shape comments, denote:
  # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
  num_tags = potentials.get_shape()[2].value

  # Computes forward decoding. Get last score and backpointers.
  crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
  initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
  inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
  backpointers, last_score = rnn.dynamic_rnn(
      crf_fwd_cell,
      inputs=inputs,
      sequence_length=sequence_length - 1,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T-1, O], [B, O]
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length - 1, seq_dim=1)  # [B, T-1, O]

  # Computes backward decoding. Extract tag indices from backpointers.
  crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
  initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                dtype=dtypes.int32)  # [B]
  initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      crf_bwd_cell,
      inputs=backpointers,
      sequence_length=sequence_length - 1,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T-1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T-1]
  decode_tags = array_ops.concat([initial_state, decode_tags], axis=1)  # [B, T]
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)  # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
  return decode_tags, best_score

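For context, here is a minimal sketch of how the newly added `crf_decode` can be called from user code. The shapes and random inputs below are illustrative assumptions, not part of this commit:

```python
# Illustrative sketch only: exercises the new tf.contrib.crf.crf_decode.
import numpy as np
import tensorflow as tf

unary_scores = tf.constant(
    np.random.randn(2, 5, 3).astype(np.float32))  # [batch, max_seq_len, num_tags]
transition_params = tf.constant(
    np.random.randn(3, 3).astype(np.float32))     # [num_tags, num_tags]
sequence_lengths = tf.constant([5, 3], dtype=tf.int32)  # [batch]

# Decode entirely inside the TensorFlow graph (no NumPy round-trip).
decode_tags, best_score = tf.contrib.crf.crf_decode(
    unary_scores, transition_params, sequence_lengths)

with tf.Session() as sess:
  tags, score = sess.run([decode_tags, best_score])
  print(tags.shape)   # (2, 5) int32 tag indices
  print(score.shape)  # (2,) float32 scores
```
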
@@ -93,7 +93,7 @@ class CudnnRNNBenchmark(test.Benchmark):
    batch_size = config["batch_size"]
    seq_length = config["seq_length"]

-    with ops.Graph().as_default(), ops.device("/gpu:0"):
+    with ops.Graph().as_default(), ops.device("/device:GPU:0"):
      model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, num_units)
      params_size_t = model.params_size()
      input_data = variables.Variable(

@@ -125,7 +125,7 @@ class CudnnRNNBenchmark(test.Benchmark):
    batch_size = config["batch_size"]
    seq_length = config["seq_length"]

-    with ops.Graph().as_default(), ops.device("/gpu:0"):
+    with ops.Graph().as_default(), ops.device("/device:GPU:0"):
      inputs = seq_length * [
          array_ops.zeros([batch_size, num_units], dtypes.float32)
      ]

@@ -153,7 +153,7 @@ class CudnnRNNBenchmark(test.Benchmark):
    batch_size = config["batch_size"]
    seq_length = config["seq_length"]

-    with ops.Graph().as_default(), ops.device("/gpu:0"):
+    with ops.Graph().as_default(), ops.device("/device:GPU:0"):
      inputs = seq_length * [
          array_ops.zeros([batch_size, num_units], dtypes.float32)
      ]

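The hunks above (and similar ones below) replace the legacy `/gpu:0` shorthand with the fully qualified `/device:GPU:0` device string. A small sketch of the spelling being adopted; the graph below is an illustration, not taken from the diff:

```python
import tensorflow as tf

# "/device:GPU:0" is the fully qualified form of the legacy "/gpu:0".
with tf.Graph().as_default(), tf.device("/device:GPU:0"):
  x = tf.zeros([4, 4])
  y = tf.matmul(x, x)  # placed on GPU 0 if one is available
```
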
@@ -286,14 +286,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
    save_path = os.path.join(self.get_temp_dir(),
                             "save-restore-variable-test")
    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      sess.run(variables.global_variables_initializer())
      params_v = sess.run(params)
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      reset_params = state_ops.assign(

@@ -328,14 +328,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
    save_path = os.path.join(self.get_temp_dir(),
                             "save-restore-variable-test")
    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      sess.run(variables.global_variables_initializer())
      params_v = sess.run(param_vars)
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      reset_params = [

@@ -398,14 +398,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
        params=params,
        is_training=False)
    total_sum = sum(map(math_ops.reduce_sum, outputs))
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      sess.run(variables.global_variables_initializer())
      total_sum_v = sess.run(total_sum)
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
-    # Passing graph explictly, otherwise an old sess would be reused.
+    # Passing graph explicitly, otherwise an old sess would be reused.
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      reset_params = state_ops.assign(

@@ -258,11 +258,12 @@ class Iterator(object):
      # initializers that simply reset their state to the beginning.
      raise ValueError("Iterator does not have an initializer.")

-  def make_initializer(self, dataset):
+  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given

@@ -272,22 +273,25 @@ class Iterator(object):
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
-    nest.assert_same_structure(self._output_types, dataset.output_types)
-    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
-    for iterator_dtype, dataset_dtype in zip(
-        nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
-      if iterator_dtype != dataset_dtype:
-        raise TypeError(
-            "Expected output types %r but got dataset with output types %r." %
-            (self._output_types, dataset.output_types))
-    for iterator_shape, dataset_shape in zip(
-        nest.flatten(self._output_shapes), nest.flatten(dataset.output_shapes)):
-      if not iterator_shape.is_compatible_with(dataset_shape):
-        raise TypeError("Expected output shapes compatible with %r but got "
-                        "dataset with output shapes %r." %
-                        (self._output_shapes, dataset.output_shapes))
-    return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
-                                         self._iterator_resource)
+    with ops.name_scope(name, "make_initializer") as name:
+      nest.assert_same_structure(self._output_types, dataset.output_types)
+      nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
+      for iterator_dtype, dataset_dtype in zip(
+          nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
+        if iterator_dtype != dataset_dtype:
+          raise TypeError(
+              "Expected output types %r but got dataset with output types %r." %
+              (self._output_types, dataset.output_types))
+      for iterator_shape, dataset_shape in zip(
+          nest.flatten(self._output_shapes),
+          nest.flatten(dataset.output_shapes)):
+        if not iterator_shape.is_compatible_with(dataset_shape):
+          raise TypeError("Expected output shapes compatible with %r but got "
+                          "dataset with output shapes %r." %
+                          (self._output_shapes, dataset.output_shapes))
+      return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
+                                           self._iterator_resource,
+                                           name=name)

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

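The two hunks above thread an optional `name` through `Iterator.make_initializer` and wrap its body in a `name_scope`. A hedged usage sketch follows; the dataset construction is an assumption for illustration, using the `tf.contrib.data` API of this era:

```python
import tensorflow as tf
from tensorflow.contrib.data import Dataset

train_data = Dataset.range(100)
eval_data = Dataset.range(10)

iterator = train_data.make_initializable_iterator()
next_element = iterator.get_next()

# The new optional name labels the initializer op in the graph.
eval_init_op = iterator.make_initializer(eval_data, name="eval_init")

with tf.Session() as sess:
  sess.run(iterator.initializer)  # initialize on the training data
  sess.run(eval_init_op)          # re-initialize on the eval data
  print(sess.run(next_element))   # 0
```
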
@@ -49,6 +49,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.test_util import *
from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
from tensorflow.contrib.distributions.python.ops.wishart import *

@@ -634,7 +634,7 @@ class MixtureBenchmark(test.Benchmark):
    np.random.seed(127)
    with session.Session(config=config, graph=ops.Graph()) as sess:
      random_seed.set_random_seed(0)
-      with ops.device("/gpu:0" if use_gpu else "/cpu:0"):
+      with ops.device("/device:GPU:0" if use_gpu else "/cpu:0"):
        mixture = create_distribution(
            num_components=num_components,
            batch_size=batch_size,

@@ -17,7 +17,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor

@@ -20,7 +20,9 @@ from __future__ import division
from __future__ import print_function

import re

import numpy as np

from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as variables_lib2
from tensorflow.python.framework import constant_op

@@ -37,6 +37,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
from tensorflow.python.util.deprecation import deprecated


__all__ = ['add_model_variable',

@@ -82,7 +83,7 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"):
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  return gen_variable_ops.zero_initializer(ref, name=name)


@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
  training_util.assert_global_step(global_step_tensor)

@@ -110,11 +111,11 @@ def assert_or_get_global_step(graph=None, global_step_tensor=None):
    assert_global_step(global_step_tensor)
  return global_step_tensor


@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None):
  return training_util.get_global_step(graph)


@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None):
  """Create global step tensor in graph.

@@ -132,7 +133,7 @@ def create_global_step(graph=None):
  """
  return training_util.create_global_step(graph)


@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None):
  """Returns and creates (if necessary) the global step tensor.

@@ -561,7 +562,7 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
        grouped_vars[ckpt_name].append(var)

  else:
-    for ckpt_name, value in var_list.iteritems():
+    for ckpt_name, value in var_list.items():
      if isinstance(value, (tuple, list)):
        grouped_vars[ckpt_name] = value
      else:

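The `iteritems` change above is a Python 3 compatibility fix: `dict.iteritems()` no longer exists there, while `dict.items()` works on both major versions. A one-line illustration with a hypothetical dict:

```python
var_list = {"weights": 1, "biases": 2}  # hypothetical example values
for ckpt_name, value in var_list.items():  # iteritems() is Python 2 only
  print(ckpt_name, value)
```
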
@@ -443,19 +443,19 @@ class VariablesTest(test.TestCase):
        e = variables_lib2.variable('e', initializer=e_init)
        # The values below highlight how the VariableDeviceChooser puts initial
        # values on the same device as the variable job.
-        self.assertDeviceEqual(a.device, '/gpu:0')
+        self.assertDeviceEqual(a.device, '/device:GPU:0')
        self.assertEqual(a.initial_value.op.colocation_groups(),
                         a.op.colocation_groups())
-        self.assertDeviceEqual(b.device, '/gpu:0')
+        self.assertDeviceEqual(b.device, '/device:GPU:0')
        self.assertEqual(b.initial_value.op.colocation_groups(),
                         b.op.colocation_groups())
        self.assertDeviceEqual(c.device, '/cpu:12')
        self.assertEqual(c.initial_value.op.colocation_groups(),
                         c.op.colocation_groups())
-        self.assertDeviceEqual(d.device, '/gpu:0')
+        self.assertDeviceEqual(d.device, '/device:GPU:0')
        self.assertEqual(d.initial_value.op.colocation_groups(),
                         d.op.colocation_groups())
-        self.assertDeviceEqual(e.device, '/gpu:0')
+        self.assertDeviceEqual(e.device, '/device:GPU:0')
        self.assertDeviceEqual(e.initial_value.device, '/cpu:99')

tensorflow/contrib/gdr/BUILD (new file)

@@ -0,0 +1,125 @@
# Description:
#   GPU Direct RDMA Out-of-Band Tensor transport for TensorFlow.

package(default_visibility = [
    "//tensorflow:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cuda_library",
)

# For platform specific build config
load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_proto_library_cc",
)

tf_proto_library_cc(
    name = "gdr_proto",
    srcs = ["gdr.proto"],
    cc_api_version = 2,
    visibility = [
        "//tensorflow:__subpackages__",
    ],
)

tf_cuda_library(
    name = "gdr_memory_manager",
    srcs = ["gdr_memory_manager.cc"],
    hdrs = ["gdr_memory_manager.h"],
    linkopts = select({
        "//tensorflow:with_gdr_support": [
            "-libverbs",
            "-lrdmacm",
        ],
        "//conditions:default": [],
    }),
    deps = [
        ":gdr_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_runtime",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_cuda_library(
    name = "gdr_worker",
    srcs = ["gdr_worker.cc"],
    hdrs = ["gdr_worker.h"],
    deps = [
        ":gdr_memory_manager",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_runtime",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/distributed_runtime:graph_mgr",
        "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
        "//tensorflow/core/distributed_runtime:worker",
        "//tensorflow/core/distributed_runtime:worker_cache",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/distributed_runtime:worker_session",
        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
        "//tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding",
        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
    ],
)

cc_library(
    name = "gdr_rendezvous_mgr",
    srcs = ["gdr_rendezvous_mgr.cc"],
    hdrs = ["gdr_rendezvous_mgr.h"],
    deps = [
        ":gdr_memory_manager",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
        "//tensorflow/core/distributed_runtime:worker_cache",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/distributed_runtime:worker_interface",
    ],
)

cc_library(
    name = "gdr_server_lib",
    srcs = ["gdr_server_lib.cc"],
    hdrs = ["gdr_server_lib.h"],
    linkstatic = 1,  # Seems to be needed since alwayslink is broken in bazel
    deps = [
        ":gdr_memory_manager",
        ":gdr_rendezvous_mgr",
        ":gdr_worker",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
    ],
    alwayslink = 1,
)

tensorflow/contrib/gdr/README.md (new file)

@@ -0,0 +1,122 @@
Introduction
===

This is an implementation of GDR out-of-band transport for the TensorFlow distributed runtime, complementary to the current gRPC transport. It uses gRPC as the control plane to set up a rendezvous for each tensor transmission, and utilizes [GPU Direct RDMA](https://developer.nvidia.com/gpudirect) whenever possible to transmit tensors in remote GPU memory through the network interface card (NIC), bypassing host memory and CPU entirely. It gracefully falls back to ordinary RDMA or even gRPC when GDR is not available.

Design
===

The GDR out-of-band transport is designed to avoid any unnecessary memory copies, especially for large tensors (>100MB). That typically requires registering tensor buffers with the NIC in an ad-hoc manner, which is rather slow, as described in the design trade-off of the verbs runtime. The verbs runtime thus chooses to manage its own NIC-registered buffers and copies the tensors from/to those buffers for every single tensor transfer.

We show, however, that this design trade-off is not always relevant. In this patch, we manage both computation and communication buffers in a unified manner. By pre-registering large buffers with the NIC and allocating small tensors from the buffer pool using a BFC allocator, it is possible to avoid both ad-hoc buffer registration and memory copies altogether.

For the actual tensor transport, we rely on gRPC to transmit the [remote buffer information](gdr.proto). This greatly simplifies our design, and there are only 2 types of RDMA messages: a single READ to retrieve the tensor data (bypassing the remote CPU), and another invalidate using WRITE with IMM to release the tensor buffer on the remote side. The remote side only polls for the invalidate message and `Unref`s the tensor buffers that were read by its peer.

Environment
===

To fully utilize GDR, the target environment has to meet 3 conditions:

1. There is an RDMA capable device with the corresponding [OFED package](https://www.openfabrics.org/index.php/overview.html) installed (detailed information is available from your [Infiniband/RoCE](http://www.mellanox.com/page/products_dyn?product_family=116)/[iWarp](http://www.chelsio.com/gpudirect-rdma/) vendor), which can be verified through `ibv_devinfo`, e.g.

```
$ ibv_devinfo
hca_id: mlx4_0
    transport:          InfiniBand (0)
    fw_ver:             2.40.7000
    node_guid:          248a:0703:00f6:3370
    sys_image_guid:     248a:0703:00f6:3370
    vendor_id:          0x02c9
    vendor_part_id:     4099
    hw_ver:             0x1
    board_id:           MT_1090110023
    phys_port_cnt:      2
    Device ports:
        port:   1
            state:      PORT_ACTIVE (4)
            max_mtu:    4096 (5)
            active_mtu: 1024 (3)
            sm_lid:     0
            port_lid:   0
            port_lmc:   0x00
            link_layer: Ethernet

        port:   2
            state:      PORT_ACTIVE (4)
            max_mtu:    4096 (5)
            active_mtu: 1024 (3)
            sm_lid:     0
            port_lid:   0
            port_lmc:   0x00
            link_layer: Ethernet
```

2. There is a GDR capable GPU, i.e. of Fermi, Kepler or a later architecture, with the [corresponding driver](http://docs.nvidia.com/cuda/gpudirect-rdma/index.html) installed. The PCI-e topology can be confirmed by `nvidia-smi topo -m`. For example, in the following topology, `GPU2` and `GPU3` are adjacent to `mlx4_0`, and tensors on these devices can benefit from GDR in the current implementation.

```
$ nvidia-smi topo -m
        GPU0    GPU1    GPU2    GPU3    mlx4_0  CPU Affinity
GPU0     X      PHB     SOC     SOC     SOC     0-5
GPU1    PHB      X      SOC     SOC     SOC     0-5
GPU2    SOC     SOC      X      PHB     PHB     6-11
GPU3    SOC     SOC     PHB      X      PHB     6-11
mlx4_0  SOC     SOC     PHB     PHB      X

Legend:

  X    = Self
  SOC  = Connection traversing PCIe as well as the SMP link between CPU sockets (e.g. QPI)
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe switches (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing a single PCIe switch
  NV#  = Connection traversing a bonded set of # NVLinks
```

3. The [`nv_peer_mem`](https://github.com/Mellanox/nv_peer_memory) kernel module is installed.

How to build and run in GDR mode
===

To try it out in a GDR capable environment, enable GDR in your configure script.

```
Do you wish to build TensorFlow with GDR support? [y/N]: y
GDR support will be enabled for TensorFlow.
```

Change your `protocol` to `grpc+gdr` to enable GDR in your deployment.

```
server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+gdr') # default protocol is 'grpc'
```

Currently the out-of-band transport service listens on the same IP and port address as specified in gRPC.

A successful initialization looks like this:

```
2017-08-05 19:10:38.601718: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40m, pci bus id: 0000:02:00.0)
2017-08-05 19:10:38.601728: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:1) -> (device: 1, name: Tesla K40m, pci bus id: 0000:03:00.0)
2017-08-05 19:10:38.601736: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:2) -> (device: 2, name: Tesla K40m, pci bus id: 0000:82:00.0)
2017-08-05 19:10:38.601742: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:3) -> (device: 3, name: Tesla K40m, pci bus id: 0000:83:00.0)
2017-08-05 19:10:39.591026: I tensorflow/contrib/gdr/gdr_memory_manager.cc:235] RDMA server is listening on 10.40.2.200:5001
2017-08-05 19:10:39.591071: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cuda_host_bfc
2017-08-05 19:10:39.591083: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cpu_pool
2017-08-05 19:10:39.591095: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cpu_rdma_bfc
2017-08-05 19:10:39.591278: I tensorflow/contrib/gdr/gdr_memory_manager.cc:78] NUMA node for device: mlx4_0 is 1
2017-08-05 19:10:39.740253: I tensorflow/contrib/gdr/gdr_memory_manager.cc:296] Instrumenting GPU allocator with bus_id 2
```

The last line indicates that the GPUs with bus id 2 (mapped to pci bus id prefixed 0000:8) will benefit from GDR and host memory bypass, which are `/gpu:2` and `/gpu:3` in this case.

Caveats
===

In the current implementation, only tensors that reside in host memory, or in GPU memory where the GPU is adjacent to an RDMA capable NIC, will use direct RDMA as their transport. When RDMA is available but not GDR, a temporary tensor copy in host memory is used as the RDMA source/destination (and copied from/to the target device). When no RDMA device is present, it can even fall back to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly set up so that GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary).

In the original design (as in the reference), tensor buffers are only registered with the NIC when we can determine that the tensor will be either a source of a Send or a sink of a Recv across a physical machine boundary. However, to implement such precise allocations we would need to change all devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocates all host tensors in GPU-registered buffers whether or not they will be transferred from/to GPUs.

Reference
===

Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos '17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3123907

tensorflow/contrib/gdr/gdr.proto (new file)

@@ -0,0 +1,13 @@
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;

message RemoteMemoryRegion {
  string host = 1;
  string port = 2;
  uint64 addr = 3;
  uint32 rkey = 4;
  uint32 tensor_key = 5;
  uint64 checksum = 6;
}

tensorflow/contrib/gdr/gdr_memory_manager.cc (new file)

@@ -0,0 +1,682 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_GDR

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>

#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>

#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

void EndpointDeleter(rdma_cm_id* id) {
  if (id) {
    rdma_destroy_ep(id);
  }
}

void MRDeleter(ibv_mr* mr) {
  if (mr) {
    rdma_dereg_mr(mr);
  }
}

using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;

using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;

class GdrMemoryManager : public RemoteMemoryManager {
 public:
  GdrMemoryManager(const string& host, const string& port);

  virtual ~GdrMemoryManager();

  virtual Status Init() override;

  virtual void Run() override;

  virtual void Stop() override;

  virtual Status TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host) override;

  virtual Status TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host) override;

 protected:
  Status CreateEndpoint(const string& host, const string& port,
                        RdmaEndpointPtr& endpoint);

  static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
    return ptr < reinterpret_cast<char*>(other->addr) + other->length;
  }

  ibv_mr* FindMemoryRegion(void* addr, size_t length);

  void InsertMemoryRegion(void* addr, size_t length);

#if GOOGLE_CUDA
  void InsertCUDAMemoryRegion(void* addr, size_t length);
#endif

  void EvictMemoryRegion(void* addr, size_t length);

 private:
  const string host_;
  const string port_;
  RdmaEndpointPtr listening_;
  std::atomic<bool> stopped_;
  int epfd_;

  // Server side endpoints
  // Accessed sequentially in Run() so not protected by lock
  std::list<RdmaEndpointPtr> server_clients_;

  using TensorKey = uint32_t;
  std::atomic<TensorKey> next_key_;

  // Server side on-the-fly tensor buffers
  mutex server_mu_;
  std::map<TensorKey, const TensorBuffer*> tensor_buffers_
      GUARDED_BY(server_mu_);

  // Client side endpoints
  mutex client_mu_;
  std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
      GUARDED_BY(client_mu_);

  // Managed memory regions
  mutex alloc_mu_;
  std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};

// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
    : host_(host),
      port_(port),
      listening_(nullptr, EndpointDeleter),
      stopped_(true),
      next_key_(0) {}

GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }

Status GdrMemoryManager::Init() {
  epfd_ = epoll_create1(0);
  if (epfd_ == -1) {
    return errors::Unavailable(strerror(errno), ": ", "epoll_create");
  }

  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  hints.ai_flags = RAI_PASSIVE;
  if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
                       const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
                               host_, ":", port_);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 32;
  init_attr.cap.max_send_wr = 1;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  // Create listening endpoint
  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
                               host_, ":", port_);
  }
  listening_.reset(id);
  rdma_freeaddrinfo(addrinfo);

  // Listen without backlog
  if (rdma_listen(listening_.get(), 0)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot listen on rdma://", host_, ":", port_);
  }
  LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;

  if (listening_->verbs == nullptr) {
    return errors::Unimplemented(
        "Unsupported address ", host_, ":", port_,
        " as it does not bind to a particular RDMA device");
  }

  int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
  if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot set server to non-blocking mode");
  }

  epoll_event event = {};
  event.events = EPOLLIN | EPOLLPRI;
  event.data.ptr = listening_.get();
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot add server to epoll");
  }

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;
  VisitableAllocator::Visitor alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  VisitableAllocator::Visitor free_visitor =
      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);

  std::set<Allocator*> instrumented_;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator) << allocator->Name()
                               << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented_.find(allocator) == std::end(instrumented_)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented_.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  VisitableAllocator::Visitor cuda_alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA

  return Status::OK();
}

void GdrMemoryManager::Run() {
  stopped_ = false;
  while (!stopped_) {
    epoll_event events[32];
    int ret = epoll_wait(epfd_, events, 32, 1);
    if (ret == -1) {
      LOG(ERROR) << "epoll_wait: " << strerror(errno);
      return;
    }
    for (int i = 0; i < ret; i++) {
      rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
      if (id == listening_.get()) {
        // Accept incoming connections
        if (!rdma_get_request(listening_.get(), &id)) {
          if (!rdma_accept(id, nullptr)) {
            LOG(INFO) << "Accepted new RDMA connection";
            if (ibv_req_notify_cq(id->recv_cq, 0)) {
              LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
              EndpointDeleter(id);
              continue;
            }
            for (int i = 0; i < 32; i++) {
              if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
                LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
                EndpointDeleter(id);
                continue;
              }
            }
            int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
            if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot set server_client to non-blocking mode";
              EndpointDeleter(id);
              continue;
            }
            epoll_event event = {};
            event.events = EPOLLIN | EPOLLPRI;
            event.data.ptr = id;
            if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
                          &event)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot add server client to epoll";
              EndpointDeleter(id);
              continue;
            }
            server_clients_.push_back({id, EndpointDeleter});
          }
        }
      } else {
        // Polling work completions
        ibv_cq* cq;
        void* context;
        if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
          ibv_ack_cq_events(id->recv_cq, 1);
          if (ibv_req_notify_cq(id->recv_cq, 0)) {
            LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
            continue;
          }
          ibv_wc wc[32];
          int ret = ibv_poll_cq(id->recv_cq, 32, wc);
          if (ret < 0) {
            LOG(ERROR) << "ibv_poll_cq failed";
            continue;
          }
          for (int i = 0; i < ret; i++) {
            if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
              LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
            }
            if (wc[i].status != 0) {
              LOG(ERROR) << ibv_wc_status_str(wc[i].status);
            }
            TensorKey tensor_key = ntohl(wc[i].imm_data);
            {
              mutex_lock l(server_mu_);
              auto iter = tensor_buffers_.find(tensor_key);
              if (iter == std::end(tensor_buffers_)) {
                LOG(ERROR) << "Cannot find tensor buffer for tensor key "
                           << tensor_key;
              } else {
                const TensorBuffer* buffer = iter->second;
                buffer->Unref();
                tensor_buffers_.erase(iter);
              }
            }
            if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
              perror("rdma_post_recvv");
              LOG(ERROR) << "rdma_post_recvv failed";
              continue;
            }
          }
        }
      }
    }
  }
}

void GdrMemoryManager::Stop() { stopped_ = true; }

Status GdrMemoryManager::TransportOptionsFromTensor(
    ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
    Device* device, DeviceContext* device_context, bool on_host) {
  auto buffer = DMAHelper::buffer(&tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  if (length == 0) {
    return errors::Unavailable("Cannot register tensor buffer of size 0");
  }

  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (!on_host && mr != nullptr) {
    TF_RETURN_IF_ERROR(GPUUtil::Sync(device));
  } else if (!on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor.dtype(), tensor.shape());
    Status s;
    Notification n;
    GPUUtil::CopyGPUTensorToCPU(device, device_context, &tensor, &host_copy,
                                [&s, &n](const Status& status) {
                                  s.Update(status);
                                  n.Notify();
                                });
    n.WaitForNotification();
    if (!s.ok()) {
      return s;
    }
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif

  if (mr == nullptr) {
    return errors::Unavailable("Cannot find pinned memory region");
  }

  buffer->Ref();
  TensorKey tensor_key = next_key_++;
  {
    mutex_lock l(server_mu_);
    tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
  }

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      if (host_copy.NumElements() > 0) {
        checksum = GPUUtil::Checksum(device, device_context, host_copy);
      } else {
        checksum = GPUUtil::Checksum(device, device_context, tensor);
      }
    } else {
      checksum = GPUUtil::Checksum(tensor);
    }
#endif
  }

  RemoteMemoryRegion remote_mr;
  remote_mr.set_host(host_);
  remote_mr.set_port(port_);
  remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
  remote_mr.set_rkey(mr->rkey);
  remote_mr.set_tensor_key(tensor_key);
  remote_mr.set_checksum(checksum);
  mutable_transport_options->PackFrom(remote_mr);

  return Status::OK();
}

Status GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    return errors::NotFound("No RDMA transport options found");
  }

  auto buffer = DMAHelper::buffer(tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (!on_host && mr != nullptr) {
    TF_RETURN_IF_ERROR(GPUUtil::Sync(device));
  } else if (!on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    return errors::Unavailable("Cannot find pinned memory region");
  }

  decltype(clients_)::iterator iter;
  bool success;
  {
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      TF_RETURN_IF_ERROR(
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second));
    }
  }
  rdma_cm_id* id = iter->second.get();

  uint64_t start = Env::Default()->NowMicros();

  if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
                     remote_mr.addr(), remote_mr.rkey())) {
    return errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed");
  }

  ibv_send_wr wr = {};
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.imm_data = htonl(remote_mr.tensor_key());
  wr.send_flags = IBV_SEND_FENCE | IBV_SEND_SIGNALED;
  ibv_send_wr* bad_wr;
  if (ibv_post_send(id->qp, &wr, &bad_wr)) {
    return errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed");
  }

  ibv_wc wc = {};
  int ret = rdma_get_send_comp(id, &wc);
  if (ret < 0 || wc.status) {
    return errors::Unavailable(ibv_wc_status_str(wc.status));
  }

#if GOOGLE_CUDA
  if (host_copy.NumElements() > 0) {
    Status s;
    Notification n;
    GPUUtil::CopyCPUTensorToGPU(&host_copy, device_context, device, tensor,
                                [&s, &n](const Status& status) {
                                  s.Update(status);
                                  n.Notify();
                                });
    n.WaitForNotification();
    if (!s.ok()) {
      return s;
    }
  }
#endif  // GOOGLE_CUDA

  uint64_t end = Env::Default()->NowMicros();

  VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
          << " of size " << buffer->size() << " with tensor key "
          << remote_mr.tensor_key() << " took " << (end - start) << " micros";

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      if (host_copy.NumElements() > 0) {
        checksum = GPUUtil::Checksum(device, device_context, host_copy);
      } else {
        checksum = GPUUtil::Checksum(device, device_context, *tensor);
      }
    } else {
      checksum = GPUUtil::Checksum(*tensor);
    }
    CHECK(checksum == remote_mr.checksum())
        << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
#endif
  }
  return Status::OK();
}

Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
                                        RdmaEndpointPtr& endpoint) {
  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
                       const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
    return errors::InvalidArgument(
        strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 1;
  init_attr.cap.max_send_wr = 32;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot create endpoint to rdma://", host, ":",
                               port);
  }
  rdma_freeaddrinfo(addrinfo);

  if (rdma_connect(id, nullptr)) {
    rdma_destroy_ep(id);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot connect to rdma://", host, ":", port);
  }

  LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
  endpoint = RdmaEndpointPtr(id, EndpointDeleter);
  return Status::OK();
}

ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
  if (length == 0) return nullptr;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
  if (mr != nullptr) {
    mutex_lock l(alloc_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

}  // namespace

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port) {
  return new GdrMemoryManager(host, port);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_GDR

tensorflow/contrib/gdr/gdr_memory_manager.h (new file, 63 lines)
@@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef GDR_MEMORY_MANAGER_H_
#define GDR_MEMORY_MANAGER_H_

#include "tensorflow/core/lib/core/status.h"

namespace google {
namespace protobuf {
class Any;
}
}

namespace tensorflow {

class Device;
class DeviceContext;
class Tensor;

// Abstract interface that handles out-of-band tensor transport.
//
// The transport options are encoded into a protocol buffer and transmitted
// via some other communication channel, such as RPC.
// See RecvTensorRequest in tensorflow/core/protobuf/worker.proto
class RemoteMemoryManager {
 public:
  virtual ~RemoteMemoryManager() {}
  virtual Status Init() = 0;
  virtual void Run() = 0;
  virtual void Stop() = 0;

  // Encodes the tensor information into an arbitrary protocol buffer.
  // The protocol buffer needs to be transmitted via some other channel.
  virtual Status TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host) = 0;

  // Retrieves the tensor from the encoded protocol buffer.
  // Note that the tensor has to be allocated, but not initialized.
  virtual Status TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host) = 0;
};

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port);

}  // namespace tensorflow

#endif  // GDR_MEMORY_MANAGER_H_
tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc (new file, 201 lines)
@@ -0,0 +1,201 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"

#include "google/protobuf/any.pb.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

class GdrRecvTensorCall : public BaseRecvTensorCall {
 public:
  GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device,
                    RemoteMemoryManager* remote_memory_manager,
                    const Rendezvous::Args& recv_args, int64 step_id,
                    StringPiece key)
      : wi_(wi),
        dst_device_(dst_device),
        remote_memory_manager_(remote_memory_manager),
        recv_args_(recv_args) {
    req_.set_step_id(step_id);
    req_.set_rendezvous_key(key.data(), key.size());
  }

  ~GdrRecvTensorCall() override {}

  void Start(std::function<void()> recv_done) override {
    req_.set_dma_ok(true);
    resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs);
    StatusCallback cb = [this, recv_done](const Status& s) {
      bool dma_ok = resp_.metadata().has_transport_options();
      if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) {
        auto transport_options = resp_.metadata().transport_options();
        const bool on_host =
            (dst_device_->tensorflow_gpu_device_info() == nullptr) ||
            recv_args_.alloc_attrs.on_host();
        Status s = remote_memory_manager_->TensorFromTransportOptions(
            const_cast<Tensor*>(&tensor()), transport_options, dst_device_,
            recv_args_.device_context, on_host);
        if (!s.ok()) {
          mutex_lock l(mu_);
          status_.Update(s);
          LOG(ERROR)
              << "Cannot find pinned memory region from allocator "
              << dst_device_->GetAllocator(recv_args_.alloc_attrs)->Name();
        }
      }
      if (!s.ok()) {
        mutex_lock l(mu_);
        status_.Update(s);
      }
      recv_done();
    };
    wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
  }

  void StartAbort(const Status& s) override {
    {
      mutex_lock l(mu_);
      status_.Update(s);
    }
    opts_.StartCancel();
  }

  Status status() const override {
    mutex_lock l(mu_);
    return status_;
  }

  const Tensor& tensor() const { return resp_.tensor(); }

  bool is_dead() const { return resp_.metadata().is_dead(); }

  Device* dst_device() const { return dst_device_; }

  const Rendezvous::Args& recv_args() const { return recv_args_; }

 private:
  WorkerInterface* wi_;
  Device* dst_device_;
  RemoteMemoryManager* remote_memory_manager_;
  CallOptions opts_;
  RecvTensorRequest req_;
  TensorResponse resp_;
  Rendezvous::Args recv_args_;

  mutable mutex mu_;
  Status status_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall);
};

class GdrRemoteRendezvous : public BaseRemoteRendezvous {
 public:
  GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id,
                      RemoteMemoryManager* remote_memory_manager)
      : BaseRemoteRendezvous(env, step_id),
        remote_memory_manager_(remote_memory_manager) {}

 protected:
  void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                           const Rendezvous::Args& recv_args,
                           DoneCallback done) override {
    CHECK(is_initialized());

    string src_worker;
    string src_rel_device;
    if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
                                          &src_rel_device)) {
      Status s = errors::Internal(parsed.src_device,
                                  " is invalid remote source device.");
      done(s, Args(), recv_args, Tensor{}, false);
      return;
    }

    WorkerSession* sess = session();
    WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
    if (rwi == nullptr) {
      Status s = errors::Internal("No worker known as ", src_worker);
      done(s, Args(), recv_args, Tensor{}, false);
      return;
    }

    Device* dst_device;
    Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
    if (!s.ok()) {
      sess->worker_cache->ReleaseWorker(src_worker, rwi);
      done(s, Args(), recv_args, Tensor{}, false);
      return;
    }

    // Prepare a RecvTensor call that can handle being aborted.
    GdrRecvTensorCall* call =
        new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_,
                              recv_args, step_id_, parsed.FullKey());

    // Record "call" in active_ so that it can be aborted cleanly.
    RegisterCall(call);

    // Start "call".
    Ref();
    call->Start([this, call, src_worker, rwi, done]() {
      // Removes "call" from active_. Prevent StartAbort().
      DeregisterCall(call);
      // If StartAbort was called prior to DeregisterCall, then the
      // current status should be bad.
      Status s = call->status();
      done(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
      session()->worker_cache->ReleaseWorker(src_worker, rwi);
      delete call;
      Unref();
    });
  }

 private:
  ~GdrRemoteRendezvous() override {}

  RemoteMemoryManager* remote_memory_manager_;

  TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous);
};

}  // namespace

GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env,
                                   RemoteMemoryManager* remote_memory_manager)
    : BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {}

BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id,
                                               const WorkerEnv* worker_env) {
  return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_);
}

}  // end namespace tensorflow
tensorflow/contrib/gdr/gdr_rendezvous_mgr.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef GDR_RENDEZVOUS_MGR_H_
#define GDR_RENDEZVOUS_MGR_H_

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class GdrRendezvousMgr : public BaseRendezvousMgr {
 public:
  explicit GdrRendezvousMgr(const WorkerEnv* env,
                            RemoteMemoryManager* remote_memory_manager);

 protected:
  BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);

 private:
  RemoteMemoryManager* remote_memory_manager_;  // Not owned

  TF_DISALLOW_COPY_AND_ASSIGN(GdrRendezvousMgr);
};

}  // end namespace tensorflow

#endif  // GDR_RENDEZVOUS_MGR_H_
tensorflow/contrib/gdr/gdr_server_lib.cc (new file, 127 lines)
@@ -0,0 +1,127 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/gdr/gdr_server_lib.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "tensorflow/contrib/gdr/gdr_worker.h"

#include "net/grpc/public/include/grpc/support/alloc.h"

namespace tensorflow {

GdrServer::GdrServer(const ServerDef& server_def, Env* env)
    : GrpcServer(server_def, env) {
  string host;
  string port;
  for (const auto& job : server_def.cluster().job()) {
    if (job.name() == server_def.job_name()) {
      auto iter = job.tasks().find(server_def.task_index());
      if (iter != job.tasks().end()) {
        const std::vector<string> hostname_port =
            str_util::Split(iter->second, ':');
        if (hostname_port.size() == 2) {
          host = hostname_port[0];
          port = hostname_port[1];
        }
      }
    }
  }
  remote_memory_manager_ = std::unique_ptr<RemoteMemoryManager>(
      CreateRemoteMemoryManager(host, port));
}

GdrServer::~GdrServer() {}

Status GdrServer::Init() {
  RendezvousMgrCreationFunction rendezvous_mgr_func =
      [this](const WorkerEnv* env) {
        return new GdrRendezvousMgr(env, remote_memory_manager_.get());
      };
  WorkerCreationFunction worker_func = [this](WorkerEnv* env) {
    return std::unique_ptr<GdrWorker>(
        new GdrWorker(env, remote_memory_manager_.get()));
  };
  TF_RETURN_IF_ERROR(
      GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func));

  return remote_memory_manager_->Init();
}

Status GdrServer::Start() {
  {
    mutex_lock l(mu_);
    gdr_thread_.reset(worker_env()->env->StartThread(
        ThreadOptions(), "TF_gdr_service",
        [this] { remote_memory_manager_->Run(); }));
  }
  return GrpcServer::Start();
}

Status GdrServer::Stop() {
  TF_RETURN_IF_ERROR(GrpcServer::Stop());
  remote_memory_manager_->Stop();
  return Status::OK();
}

Status GdrServer::Join() {
  {
    mutex_lock l(mu_);
    gdr_thread_.reset();
  }
  return GrpcServer::Join();
}

/* static */
Status GdrServer::Create(const ServerDef& server_def, Env* env,
                         std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<GdrServer> ret(
      new GdrServer(server_def, env == nullptr ? Env::Default() : env));
  TF_RETURN_IF_ERROR(ret->Init());
  *out_server = std::move(ret);
  return Status::OK();
}

namespace {

class GdrServerFactory : public ServerFactory {
 public:
  bool AcceptsOptions(const ServerDef& server_def) override {
    return server_def.protocol() == "grpc+gdr";
  }

  Status NewServer(const ServerDef& server_def,
                   std::unique_ptr<ServerInterface>* out_server) override {
    return GdrServer::Create(server_def, Env::Default(), out_server);
  }
};

// Registers a `ServerFactory` for `GdrServer` instances.
class GdrServerRegistrar {
 public:
  GdrServerRegistrar() {
    gpr_allocation_functions alloc_fns;
    memset(&alloc_fns, 0, sizeof(alloc_fns));
    alloc_fns.malloc_fn = port::Malloc;
    alloc_fns.realloc_fn = port::Realloc;
    alloc_fns.free_fn = port::Free;
    gpr_set_allocation_functions(alloc_fns);
    ServerFactory::Register("GDR_SERVER", new GdrServerFactory());
  }
};
static GdrServerRegistrar registrar;

}  // namespace
}  // namespace tensorflow
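For orientation, the factory registration above means the GDR transport is selected purely through the `ServerDef` protocol string. A minimal Python sketch of what that looks like on the user side (the cluster host names here are illustrative, not from this patch):

import tensorflow as tf

# Hypothetical two-worker cluster; any reachable host:port pairs would do.
cluster = tf.train.ClusterSpec({'worker': ['worker0:2222', 'worker1:2222']})

# 'grpc+gdr' selects GdrServer so tensor payloads can travel over RDMA,
# while control messages stay on gRPC; plain 'grpc' remains the default.
server = tf.train.Server(cluster, job_name='worker', task_index=0,
                         protocol='grpc+gdr')
server.join()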
tensorflow/contrib/gdr/gdr_server_lib.h (new file, 52 lines)
@@ -0,0 +1,52 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef GDR_SERVER_LIB_H_
#define GDR_SERVER_LIB_H_

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

namespace tensorflow {

class GdrServer : public GrpcServer {
 protected:
  GdrServer(const ServerDef& server_def, Env* env);

 public:
  static Status Create(const ServerDef& server_def, Env* env,
                       std::unique_ptr<ServerInterface>* out_server);

  virtual ~GdrServer() override;

  virtual Status Start() override;

  virtual Status Stop() override;

  virtual Status Join() override;

 protected:
  Status Init();

 private:
  mutex mu_;

  std::unique_ptr<RemoteMemoryManager> remote_memory_manager_;
  std::unique_ptr<Thread> gdr_thread_ GUARDED_BY(mu_);
};

}  // namespace tensorflow

#endif  // GDR_SERVER_LIB_H_
tensorflow/contrib/gdr/gdr_worker.cc (new file, 146 lines)
@@ -0,0 +1,146 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/gdr/gdr_worker.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {

GdrWorker::GdrWorker(WorkerEnv* worker_env,
                     RemoteMemoryManager* remote_memory_manager)
    : GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {}

void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                    const RecvTensorRequest* request,
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  Rendezvous::ParsedKey parsed;
  Status s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // Request the tensor associated with the rendezvous key. Any time
  // while waiting for the tensor to be produced, up until the start
  // of execution of the callback lambda body below, an RPC
  // cancellation should abort the rendezvous.
  opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
  const bool dma_ok = request->dma_ok();
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [this, opts, response, done, src_dev, dma_ok](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args&, const Tensor& val, const bool is_dead) {
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host =
              (src_dev->tensorflow_gpu_device_info() == nullptr) ||
              send_args.alloc_attrs.on_host();
          if (val.TotalBytes() > 0 && (!is_dead) &&
              DMAHelper::CanUseDMA(&val) && dma_ok) {
            // DMA cases.
            RecvTensorResponse proto;
            auto transport_options = proto.mutable_transport_options();
            Status s = remote_memory_manager_->TransportOptionsFromTensor(
                transport_options, val, src_dev, send_args.device_context,
                on_host);
            if (s.ok()) {
              proto.set_is_dead(is_dead);
              proto.set_send_start_micros(Env::Default()->NowMicros());
              TensorProto* tensor_proto = proto.mutable_tensor();
              tensor_proto->set_dtype(val.dtype());
              val.shape().AsProto(tensor_proto->mutable_tensor_shape());
              grpc::EncodeRecvTensorResponseToByteBuffer(proto, response);
              done(Status::OK());
              return;
            } else {
              done(s);
              return;
            }
          } else {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
              const DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name()
                  << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
              // "val" is on a GPU. Uses GPUUtil to fill the response proto.
              StatusCallback copy_ready = [response, done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
                done(s);
                delete copy;
              };

              GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val,
                                          copy, copy_ready);
#else
              done(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
            } else {
              grpc::EncodeTensorToByteBuffer(is_dead, val, response);
              done(Status::OK());
            }
          }
        } else {
          // !s.ok()
          done(status);
        }
      });
}

}  // namespace tensorflow
tensorflow/contrib/gdr/gdr_worker.h (new file, 45 lines)
@@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef GDR_WORKER_H_
#define GDR_WORKER_H_

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

namespace tensorflow {

class GdrWorker : public GrpcWorker {
 public:
  GdrWorker(WorkerEnv* env, RemoteMemoryManager* remote_memory_manager);

  // Serves the RecvTensorRequest but omits the tensor content, transmitting
  // it out-of-band using GPU Direct RDMA whenever possible.
  // If that is not possible, it falls back to gRPC's in-band tensor transport
  // by encoding the tensor content into the grpc::ByteBuffer.
  // The RecvTensorResponse will carry the necessary information for RDMA.
  virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                   const RecvTensorRequest* request,
                                   ::grpc::ByteBuffer* response,
                                   StatusCallback done) override;

 private:
  RemoteMemoryManager* remote_memory_manager_;  // Not owned
};

}  // namespace tensorflow

#endif  // GDR_WORKER_H_
@@ -3570,7 +3570,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):

   Returns:
       the tensor after 1d conv with un-shared weights, with shape (batch_size,
-      output_lenght, filters)
+      output_length, filters)

   Raises:
       ValueError: if `data_format` is neither `channels_last` or
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import marshal
+import os
 import sys
 import time
 import types as python_types

@@ -195,7 +196,10 @@ def func_dump(func):
   Returns:
       A tuple `(code, defaults, closure)`.
   """
-  code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
+  if os.name == 'nt':
+    code = marshal.dumps(func.__code__).replace(b'\\', b'/').decode('raw_unicode_escape')
+  else:
+    code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
   defaults = func.__defaults__
   if func.__closure__:
     closure = tuple(c.cell_contents for c in func.__closure__)
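The Windows branch above exists because `raw_unicode_escape` round-trips can mangle marshaled code objects whose embedded file paths contain backslashes. A small self-contained sketch of the transform, assuming nothing beyond the standard library:

import marshal


def f(x):
  return x + 1

raw = marshal.dumps(f.__code__)
# Mirror the os.name == 'nt' branch: normalize backslashes before decoding.
# For this path-free function the replace is a no-op, so the round-trip holds.
code = raw.replace(b'\\', b'/').decode('raw_unicode_escape')
restored = marshal.loads(code.encode('raw_unicode_escape'))
assert restored.co_name == 'f'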
@@ -1944,7 +1944,7 @@ def gdn(inputs,
   spatial dimensions. It is similar to local response normalization, but much
   more flexible, as `beta` and `gamma` are trainable parameters.

-  Arguments:
+  Args:
     inputs: Tensor input.
     inverse: If `False` (default), compute GDN response. If `True`, compute IGDN
       response (one step of fixed point iteration to invert GDN; the division
@@ -66,11 +66,11 @@ from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.summary import summary as core_summary
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
-from tensorflow.python.training import summary_io
 from tensorflow.python.training import training_util
 from tensorflow.python.util import compat
 from tensorflow.python.util import tf_decorator

@@ -337,7 +337,7 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step):
   """
   logging.info('Saving dict for global step %d: %s', current_global_step,
                _dict_to_str(dictionary))
-  summary_writer = summary_io.SummaryWriterCache.get(output_dir)
+  summary_writer = core_summary.FileWriterCache.get(output_dir)
   summary_proto = summary_pb2.Summary()
   for key in dictionary:
     if dictionary[key] is None:

@@ -1034,7 +1034,7 @@ class BaseEstimator(
         loss = None
         while not mon_sess.should_stop():
           _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
-      summary_io.SummaryWriterCache.clear()
+      core_summary.FileWriterCache.clear()
       return loss
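These hunks swap the deprecated `summary_io` writers for their `tensorflow.python.summary` equivalents. A brief sketch of the replacement API as it is used above (the log directory is illustrative):

from tensorflow.python.summary import summary as core_summary

# FileWriterCache hands back one shared FileWriter per directory...
writer = core_summary.FileWriterCache.get('/tmp/logdir')
assert writer is core_summary.FileWriterCache.get('/tmp/logdir')

# ...and clear() drops all cached writers, as BaseEstimator does after fitting.
core_summary.FileWriterCache.clear()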
@@ -506,7 +506,7 @@ class EstimatorModelFnTest(test.TestCase):
       return input_fn_utils.InputFnOps(
           features, labels, {'examples': serialized_tf_example})

-    est.export_savedmodel(est.model_dir + '/export', serving_input_fn)
+    est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
     self.assertTrue(self.mock_saver.restore.called)
@@ -988,10 +988,11 @@ class EstimatorTest(test.TestCase):
     self.assertTrue('input_example_tensor' in graph_ops)
     self.assertTrue('ParseExample/ParseExample' in graph_ops)
     self.assertTrue('linear/linear/feature/matmul' in graph_ops)
-    self.assertSameElements(
-        ['bogus_lookup', 'feature'],
-        graph.get_collection(
-            constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS))
+    self.assertItemsEqual(
+        ['bogus_lookup', 'feature'],
+        [compat.as_str_any(x) for x in graph.get_collection(
+            constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])

     # cleanup
     gfile.DeleteRecursively(tmpdir)
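The rewritten assertion compares str-normalized items because graph-collection entries may come back as bytes under Python 3. A one-line sketch of the normalization used above:

from tensorflow.python.util import compat

items = [b'bogus_lookup', u'feature']  # mixed bytes/unicode, as a collection may hold
assert sorted(compat.as_str_any(x) for x in items) == ['bogus_lookup', 'feature']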
@@ -44,15 +44,16 @@ import six

 from tensorflow.contrib.framework import deprecated
 from tensorflow.contrib.framework.python.ops import variables as contrib_variables
-from tensorflow.contrib.learn.python.learn import session_run_hook
 from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache
 from tensorflow.core.framework.summary_pb2 import Summary
 from tensorflow.core.util.event_pb2 import SessionLog
 from tensorflow.python.estimator import estimator as core_estimator
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary as core_summary
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import summary_io
+from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import tf_inspect

@@ -521,7 +522,7 @@ class SummarySaver(EveryN):
     self._summary_op = summary_op
     self._summary_writer = summary_writer
     if summary_writer is None and output_dir:
-      self._summary_writer = summary_io.SummaryWriter(output_dir)
+      self._summary_writer = core_summary.FileWriter(output_dir)
     self._scaffold = scaffold
     # TODO(mdan): Throw an error if output_dir and summary_writer are None.

@@ -529,7 +530,7 @@ class SummarySaver(EveryN):
     super(SummarySaver, self).set_estimator(estimator)
     # TODO(mdan): This line looks redundant.
     if self._summary_writer is None:
-      self._summary_writer = summary_io.SummaryWriter(estimator.model_dir)
+      self._summary_writer = core_summary.FileWriter(estimator.model_dir)

   def every_n_step_begin(self, step):
     super(SummarySaver, self).every_n_step_begin(step)

@@ -1029,7 +1030,7 @@ class CheckpointSaver(BaseMonitor):
     logging.info("Create CheckpointSaver.")
     super(CheckpointSaver, self).__init__()
     self._saver = saver
-    self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
+    self._summary_writer = core_summary.FileWriterCache.get(checkpoint_dir)
     self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
     self._scaffold = scaffold
     self._save_secs = save_secs

@@ -1098,12 +1099,12 @@ class StepCounter(EveryN):
     self._last_reported_time = None
     self._summary_writer = summary_writer
     if summary_writer is None and output_dir:
-      self._summary_writer = SummaryWriterCache.get(output_dir)
+      self._summary_writer = core_summary.FileWriterCache.get(output_dir)

   def set_estimator(self, estimator):
     super(StepCounter, self).set_estimator(estimator)
     if self._summary_writer is None:
-      self._summary_writer = SummaryWriterCache.get(estimator.model_dir)
+      self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir)

   def every_n_step_end(self, current_step, outputs):
     current_time = time.time()

@@ -1169,7 +1170,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):

   def begin(self):
     self._last_step = None
-    self._global_step_tensor = contrib_variables.get_global_step()
+    self._global_step_tensor = training_util.get_global_step()
     for m in self._monitors:
       m.begin(max_steps=None)
@@ -27,7 +27,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin

 from tensorflow.contrib import testing
-from tensorflow.contrib.framework.python.framework import checkpoint_utils
 from tensorflow.contrib.framework.python.ops import variables as variables_lib
 from tensorflow.contrib.learn.python import learn
 from tensorflow.contrib.learn.python.learn import estimators
 from tensorflow.python.client import session as session_lib

@@ -43,6 +42,7 @@ from tensorflow.python.summary import summary
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
+from tensorflow.python.training import training_util


 class _MyEveryN(learn.monitors.EveryN):

@@ -616,7 +616,7 @@ class CheckpointSaverTest(test.TestCase):
     self.graph = ops.Graph()
     with self.graph.as_default():
       self.scaffold = monitored_session.Scaffold()
-      self.global_step = variables_lib.get_or_create_global_step()
+      self.global_step = training_util.get_or_create_global_step()
       self.train_op = state_ops.assign_add(self.global_step, 1)

   def tearDown(self):

@@ -780,7 +780,7 @@ class RunHookAdapterForMonitorsTest(test.TestCase):

   def test_calls_and_steps(self):
     with ops.Graph().as_default(), session_lib.Session() as sess:
-      global_step_tensor = variables_lib.create_global_step()
+      global_step_tensor = training_util.create_global_step()
       inc_5 = state_ops.assign_add(global_step_tensor, 5)
       mock_mon = FakeMonitor()
       mock_mon2 = FakeMonitor()

@@ -821,7 +821,7 @@ class RunHookAdapterForMonitorsTest(test.TestCase):

   def test_requests(self):
     with ops.Graph().as_default(), session_lib.Session() as sess:
-      variables_lib.create_global_step()
+      training_util.create_global_step()
       mock_mon = FakeMonitor()
       mock_mon2 = FakeMonitor()
@@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter
 from tensorflow.contrib.session_bundle import manifest_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import gfile

@@ -49,9 +50,8 @@ def _training_input_fn():


 class ExportTest(test.TestCase):
-
   def _get_default_signature(self, export_meta_filename):
-    """ Gets the default signature from the export.meta file. """
+    """Gets the default signature from the export.meta file."""
     with session.Session():
       save = saver.import_meta_graph(export_meta_filename)
       meta_graph_def = save.export_meta_graph()

@@ -68,18 +68,19 @@ class ExportTest(test.TestCase):
     self.assertTrue(gfile.Exists(export_dir))
     # Only the written checkpoints are exported.
     self.assertTrue(
-        saver.checkpoint_exists(export_dir + '00000001/export'),
+        saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
         'Exported checkpoint expected but not found: %s' %
-        (export_dir + '00000001/export'))
+        os.path.join(export_dir, '00000001', 'export'))
     self.assertTrue(
-        saver.checkpoint_exists(export_dir + '00000010/export'),
+        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
         'Exported checkpoint expected but not found: %s' %
-        (export_dir + '00000010/export'))
+        os.path.join(export_dir, '00000010', 'export'))
     self.assertEquals(
         six.b(os.path.join(export_dir, '00000010')),
         export_monitor.last_export_dir)
     # Validate the signature
-    signature = self._get_default_signature(export_dir + '00000010/export.meta')
+    signature = self._get_default_signature(
+        os.path.join(export_dir, '00000010', 'export.meta'))
     self.assertTrue(signature.HasField(expected_signature))

   def testExportMonitor_EstimatorProvidesSignature(self):

@@ -88,7 +89,7 @@ class ExportTest(test.TestCase):
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
     regressor = learn.LinearRegressor(feature_columns=cont_features)
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
     regressor.fit(x, y, steps=10, monitors=[export_monitor])

@@ -99,7 +100,7 @@ class ExportTest(test.TestCase):
     x = np.random.rand(1000)
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,

@@ -122,7 +123,7 @@ class ExportTest(test.TestCase):
     input_feature_key = 'my_example_key'
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,

@@ -140,7 +141,7 @@ class ExportTest(test.TestCase):

     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,

@@ -165,7 +166,7 @@ class ExportTest(test.TestCase):

     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,

@@ -187,7 +188,7 @@ class ExportTest(test.TestCase):

     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,

@@ -210,7 +211,7 @@ class ExportTest(test.TestCase):
             shape=(1,), minval=0.0, maxval=1000.0)
     }, None

-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,

@@ -235,7 +236,7 @@ class ExportTest(test.TestCase):
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
     regressor = learn.LinearRegressor(feature_columns=cont_features)
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,

@@ -244,10 +245,13 @@ class ExportTest(test.TestCase):
     regressor.fit(x, y, steps=10, monitors=[export_monitor])

     self.assertTrue(gfile.Exists(export_dir))
-    self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export'))
-    self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export'))
+    with self.assertRaises(errors.NotFoundError):
+      saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
+    self.assertTrue(
+        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
     # Validate the signature
-    signature = self._get_default_signature(export_dir + '00000010/export.meta')
+    signature = self._get_default_signature(
+        os.path.join(export_dir, '00000010', 'export.meta'))
     self.assertTrue(signature.HasField('regression_signature'))
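The recurring change in these hunks replaces manual string concatenation (`tempfile.mkdtemp() + 'export/'`) with `os.path.join`, which inserts the platform separator that the concatenation silently dropped. A minimal illustration:

import os
import tempfile

base = tempfile.mkdtemp()
# Concatenation yields e.g. '/tmp/tmpXYZexport/' -- no separator before 'export'.
broken = base + 'export/'
# os.path.join yields '/tmp/tmpXYZ/export' on POSIX and uses '\\' on Windows.
fixed = os.path.join(base, 'export')
assert fixed != broken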
@@ -33,8 +33,13 @@ from tensorflow.python.util import compat
 def _create_parser(base_dir):
   # create a simple parser that pulls the export_version from the directory.
   def parser(path):
-    match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
-                     compat.as_str_any(path.path))
+    # Modify the path object for RegEx match for Windows Paths
+    if os.name == 'nt':
+      match = re.match("^" + compat.as_str_any(base_dir).replace('\\', '/') + "/(\\d+)$",
+                       compat.as_str_any(path.path).replace('\\', '/'))
+    else:
+      match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
+                       compat.as_str_any(path.path))
     if not match:
       return None
     return path._replace(export_version=int(match.group(1)))

@@ -48,13 +53,13 @@ class GcTest(test_util.TensorFlowTestCase):
     paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
-    self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
+    self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])

   def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
     paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
-    self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
+    self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])

   def testModExportVersion(self):
     paths = [

@@ -62,9 +67,9 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 9)
     ]
     mod = gc.mod_export_version(2)
-    self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
     mod = gc.mod_export_version(3)
-    self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])

   def testOneOfEveryNExportVersions(self):
     paths = [

@@ -73,7 +78,7 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 8), gc.Path("/foo", 33)
     ]
     one_of = gc.one_of_every_n_export_versions(3)
-    self.assertEquals(
+    self.assertEqual(
         one_of(paths), [
             gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
             gc.Path("/foo", 33)

@@ -84,14 +89,14 @@ class GcTest(test_util.TensorFlowTestCase):
     # Test that here.
     paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
     one_of = gc.one_of_every_n_export_versions(3)
-    self.assertEquals(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
+    self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])

   def testUnion(self):
     paths = []
     for i in xrange(10):
       paths.append(gc.Path("/foo", i))
     f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
-    self.assertEquals(
+    self.assertEqual(
         f(paths), [
             gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
             gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)

@@ -103,9 +108,9 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 9)
     ]
     mod = gc.negation(gc.mod_export_version(2))
-    self.assertEquals(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
     mod = gc.negation(gc.mod_export_version(3))
-    self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])

   def testPathsWithParse(self):
     base_dir = os.path.join(test.get_temp_dir(), "paths_parse")

@@ -115,7 +120,7 @@ class GcTest(test_util.TensorFlowTestCase):
     # add a base_directory to ignore
     gfile.MakeDirs(os.path.join(base_dir, "ignore"))

-    self.assertEquals(
+    self.assertEqual(
         gc.get_paths(base_dir, _create_parser(base_dir)),
         [
             gc.Path(os.path.join(base_dir, "0"), 0),
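The parser change above normalizes both sides of the regex to forward slashes so a single pattern matches POSIX and Windows paths. A compact sketch of the same idea (the paths are illustrative):

import re

base_dir = 'C:\\exports'                 # hypothetical Windows base directory
path = 'C:\\exports\\00000010'
match = re.match('^' + base_dir.replace('\\', '/') + '/(\\d+)$',
                 path.replace('\\', '/'))
assert match and int(match.group(1)) == 10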
@@ -57,6 +57,11 @@ REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_CPU), BytesLimitOp);
 REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_GPU).HostMemory("out"),
                         BytesLimitOp);

+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"),
+                        BytesLimitOp);
+#endif  // TENSORFLOW_USE_SYCL
+
 // Op that measures the peak memory in bytes.
 class MaxBytesInUseOp : public MemoryStatsOp {
  public:

@@ -76,4 +81,10 @@ REGISTER_KERNEL_BUILDER(
     Name("MaxBytesInUse").Device(DEVICE_GPU).HostMemory("out"),
     MaxBytesInUseOp);

+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+    Name("MaxBytesInUse").Device(DEVICE_SYCL).HostMemory("out"),
+    MaxBytesInUseOp);
+#endif  // TENSORFLOW_USE_SYCL
+
 }  // namespace tensorflow
@@ -20,6 +20,8 @@ limitations under the License.
 #include <string>
 #include <utility>

+#include "grpc/support/alloc.h"
+
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -61,7 +61,7 @@ void MPIUtils::InitMPI() {
   MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
   MPI_CHECK(MPI_Get_processor_name(my_host_name, &len));
   fprintf(stderr,
-          "MPI Environment initialised. Process id: %d Total processes: %d "
+          "MPI Environment initialized. Process id: %d Total processes: %d "
           "|| Hostname: %s \n",
           proc_id, number_of_procs, my_host_name);
 }
@@ -43,7 +43,7 @@ class AllReduceTest(test.TestCase):
       self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)

   def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
-    for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+    for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
       shape = (3, 4)
       np_ans = None
       tensors = []

@@ -84,7 +84,7 @@ class BroadcastTest(test.TestCase):
       # Create session inside outer loop to test use of
       # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess:
-        for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+        for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
           shape = (3, 4)
           sender = np.random.randint(0, len(devices) - 1)
           with ops.device(devices[sender]):

@@ -115,7 +115,7 @@ class CombinedTest(test.TestCase):
       # Create session inside outer loop to test use of
      # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess:
-        for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+        for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
           shape = (3, 4)

           # all-reduce
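These test updates move from the short '/gpu:N' alias to the fully qualified '/device:GPU:N' form that newer device-name parsing expects. Usage is otherwise unchanged; a minimal sketch:

import tensorflow as tf

# '/device:GPU:0' is the canonical spelling; '/gpu:0' was an accepted alias.
with tf.device('/device:GPU:0'):
  a = tf.constant([1.0, 2.0])
  b = a * 2.0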
@@ -15,6 +15,7 @@ py_library(
         "__init__.py",
         "python/__init__.py",
         "python/ops/__init__.py",
+        "python/ops/alpha_dropout.py",
         "python/ops/cross_entropy.py",
         "python/ops/sampling_ops.py",
     ],

@@ -44,6 +45,23 @@ py_test(
     ],
 )

+py_test(
+    name = "alpha_dropout_test",
+    size = "small",
+    srcs = ["python/ops/alpha_dropout_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":nn_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:nn",
+        "//tensorflow/python:random_ops",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Module for variants of ops in tf.nn.

+@@alpha_dropout
 @@deprecated_flipped_softmax_cross_entropy_with_logits
 @@deprecated_flipped_sparse_softmax_cross_entropy_with_logits
 @@deprecated_flipped_sigmoid_cross_entropy_with_logits

@@ -27,6 +28,7 @@ from __future__ import print_function
 # pylint: disable=unused-import,wildcard-import
 from tensorflow.contrib.nn.python.ops.cross_entropy import *
 from tensorflow.contrib.nn.python.ops.sampling_ops import *
+from tensorflow.contrib.nn.python.ops.alpha_dropout import *
 # pylint: enable=unused-import,wildcard-import

 from tensorflow.python.util.all_util import remove_undocumented
tensorflow/contrib/nn/python/ops/alpha_dropout.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numbers
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_impl
|
||||
|
||||
|
||||
def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
|
||||
"""Computes alpha dropout.
|
||||
|
||||
Alpha Dropout is a dropout that maintains the self-normalizing property. For
|
||||
an input with zero mean and unit standard deviation, the output of
|
||||
Alpha Dropout maintains the original mean and standard deviation of the input.
|
||||
|
||||
See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
|
||||
|
||||
Args:
|
||||
x: A tensor.
|
||||
keep_prob: A scalar `Tensor` with the same type as x. The probability
|
||||
that each element is kept.
|
||||
noise_shape: A 1-D `Tensor` of type `int32`, representing the
|
||||
shape for randomly generated keep/drop flags.
|
||||
seed: A Python integer. Used to create random seeds. See
|
||||
@{tf.set_random_seed} for behavior.
|
||||
name: A name for this operation (optional).
|
||||
|
||||
Returns:
|
||||
A Tensor of the same shape of `x`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `keep_prob` is not in `(0, 1]`.
|
||||
|
||||
"""
|
||||
with ops.name_scope(name, "alpha_dropout", [x]) as name:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1.:
|
||||
raise ValueError("keep_prob must be a scalar tensor or a float in the "
|
||||
"range (0, 1], got %g" % keep_prob)
|
||||
keep_prob = ops.convert_to_tensor(keep_prob,
|
||||
dtype=x.dtype,
|
||||
name="keep_prob")
|
||||
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
|
||||
|
||||
# Do nothing if we know keep_prob == 1
|
||||
if tensor_util.constant_value(keep_prob) == 1:
|
||||
return x
|
||||
|
||||
alpha = -1.7580993408473766
|
||||
|
||||
noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
|
||||
random_tensor = random_ops.random_uniform(noise_shape,
|
||||
seed=seed,
|
||||
dtype=x.dtype)
|
||||
kept_idx = gen_math_ops.greater_equal(random_tensor, 1 - keep_prob)
|
||||
kept_idx = math_ops.cast(kept_idx, x.dtype)
|
||||
# Mask
|
||||
x = x * kept_idx + alpha * (1 - kept_idx)
|
||||
|
||||
# Affine transformation parameters
|
||||
a = (keep_prob + keep_prob * (1 - keep_prob) * alpha ** 2) ** -0.5
|
||||
b = -a * alpha * (1 - keep_prob)
|
||||
|
||||
# Affine transformation
|
||||
return a * x + b
|
||||
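For orientation, a minimal usage sketch of the op added above (assumes a TF 1.x build that includes this change; sizes and `keep_prob` are arbitrary illustration values):

```python
# Sketch: alpha dropout on a standardized input approximately preserves
# mean and standard deviation, unlike regular dropout.
import tensorflow as tf
from tensorflow.contrib.nn.python.ops.alpha_dropout import alpha_dropout

x = tf.random_normal([128, 256])     # ~zero mean, unit std input
y = alpha_dropout(x, keep_prob=0.8)  # drop to alpha, then affine-rescale

with tf.Session() as sess:
  x_val, y_val = sess.run([x, y])
  print(x_val.mean(), x_val.std())   # ~0, ~1
  print(y_val.mean(), y_val.std())   # still ~0, ~1
```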
tensorflow/contrib/nn/python/ops/alpha_dropout_test.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for alpha_dropout.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.nn.python.ops.alpha_dropout import alpha_dropout
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test


class AlphaDropoutTest(test.TestCase):

  def testAlphaDropout(self):
    x_dim, y_dim = 40, 30
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = random_ops.random_normal([x_dim, y_dim])
        output = alpha_dropout(t, keep_prob)
        self.assertEqual([x_dim, y_dim], output.get_shape())
        t_mean, t_std = nn_impl.moments(t, axes=[0, 1])
        output_mean, output_std = nn_impl.moments(output, axes=[0, 1])
        self.assertLess(abs(t_mean.eval() - output_mean.eval()), 0.1)
        self.assertLess(abs(t_std.eval() - output_std.eval()), 0.1)

  def testShapedDropoutShapeError(self):
    # Runs shaped dropout and verifies an error is thrown on misshapen noise.
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
    with self.assertRaises(ValueError):
      _ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
    with self.assertRaises(ValueError):
      _ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
    with self.assertRaises(ValueError):
      _ = alpha_dropout(t, keep_prob, noise_shape=[x_dim + 3])
    with self.assertRaises(ValueError):
      _ = alpha_dropout(t, keep_prob, noise_shape=[x_dim])

    # test that broadcasting proceeds
    _ = alpha_dropout(t, keep_prob, noise_shape=[y_dim])
    _ = alpha_dropout(t, keep_prob, noise_shape=[1, y_dim])
    _ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, 1])
    _ = alpha_dropout(t, keep_prob, noise_shape=[1, 1])

  def testInvalidKeepProb(self):
    x_dim, y_dim = 40, 30
    t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
    with self.assertRaises(ValueError):
      alpha_dropout(t, -1.0)
    with self.assertRaises(ValueError):
      alpha_dropout(t, 1.1)
    with self.assertRaises(ValueError):
      alpha_dropout(t, [0.0, 1.0])
    with self.assertRaises(ValueError):
      alpha_dropout(t, array_ops.placeholder(dtypes.float64))
    with self.assertRaises(ValueError):
      alpha_dropout(t, array_ops.placeholder(dtypes.float32, shape=[2]))

  def testNoDropoutFast(self):
    x = array_ops.zeros((5,))
    for p in 1, constant_op.constant(1.0):
      y = alpha_dropout(x, keep_prob=p)
      self.assertTrue(x is y)


if __name__ == '__main__':
  test.main()
@@ -50,6 +50,10 @@ See @{$python/contrib.rnn} guide.
@@UGRNNCell
@@IntersectionRNNCell
@@PhasedLSTMCell
@@ConvLSTMCell
@@Conv1DLSTMCell
@@Conv2DLSTMCell
@@Conv3DLSTMCell
@@HighwayWrapper
@@GLSTMCell
@@ -40,6 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.framework import test_util

# pylint: enable=protected-access

@@ -445,11 +446,12 @@ class RNNCellTest(test.TestCase):
# Can't perform this test w/o a GPU
return

gpu_dev = test.gpu_device_name()
with self.test_session(use_gpu=True) as sess:
with variable_scope.variable_scope(
    "root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1, 3])
cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/gpu:0")
cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev)
with ops.device("/cpu:0"):
outputs, _ = rnn.dynamic_rnn(
    cell=cell, inputs=x, dtype=dtypes.float32)

@@ -461,7 +463,7 @@ class RNNCellTest(test.TestCase):
_ = sess.run(outputs, options=opts, run_metadata=run_metadata)

step_stats = run_metadata.step_stats
ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
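The pattern this test now exercises, sketched standalone (hypothetical cell size; `tf.test.gpu_device_name()` returns an empty string when no GPU is present):

```python
import tensorflow as tf

gpu_dev = tf.test.gpu_device_name()  # "" if no GPU is available
# Pin the cell's ops to the detected GPU instead of a hard-coded "/gpu:0".
cell = tf.contrib.rnn.DeviceWrapper(
    tf.contrib.rnn.GRUCell(3), gpu_dev or "/cpu:0")
x = tf.zeros([1, 1, 3])
outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
```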
@@ -42,7 +42,6 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest
from tensorflow.python.framework import test_util

class Plus1RNNCell(rnn_lib.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

@@ -2208,11 +2207,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available():
return  # Test requires access to a GPU

gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device=test_util.gpu_device_name())
rnn_device="/cpu:0", cell_device=gpu_dev)
step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
("sycl" in step_stats.dev_stats[0].device)) else 1
ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats

@@ -2233,12 +2232,12 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available():
return  # Test requires access to a GPU

gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device="/cpu:0",
input_device=test_util.gpu_device_name())
input_device=gpu_dev)
step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
("sycl" in step_stats.dev_stats[0].device)) else 1
ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats

@@ -2253,11 +2252,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available():
return  # Test requires access to a GPU

gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on(
input_device=test_util.gpu_device_name())
input_device=gpu_dev)
step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
("sycl" in step_stats.dev_stats[0].device)) else 1
ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats
@@ -357,7 +357,7 @@ def training_gru_block_vs_gru_cell(batch_size,
ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess:
# Specify the device which is being used.
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"):
with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):

# Random initializers.
seed = 1994

@@ -429,7 +429,7 @@ def inference_gru_block_vs_gru_cell(batch_size,
"""Benchmark inference speed between GRUBlockCell vs GRUCell."""
ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess:
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"):
with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):

# Random initializers.
seed = 1994

@@ -484,7 +484,7 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size,
"""Benchmark single bprop step speed between GRUBlockCell vs GRUCell."""
ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess:
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"):
with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):
initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989)
# Inputs
x = vs.get_variable("x", [batch_size, input_size])
@@ -875,6 +875,152 @@ class RNNCellTest(test.TestCase):
      self.assertAllClose(res[1].c, expected_state_c)
      self.assertAllClose(res[1].h, expected_state_h)

  def testConv1DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 1]
      filter_size = [3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[1.4375670191], [1.4375670191]],
           [[2.7542609292], [2.7542609292]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[0.6529865603], [0.6529865603]],
           [[0.8736877431], [0.8736877431]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, 1])
        cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
                                               kernel_shape=filter_size,
                                               output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state], {
            hidden[0].name:
                np.array([[[1.], [1.]],
                          [[2.], [2.]]]),
            x.name:
                np.array([[[1.], [1.]],
                          [[2.], [2.]]]),
        })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testConv2DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 2, 1]
      filter_size = [3, 3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[[1.4375670191], [1.4375670191]],
            [[1.4375670191], [1.4375670191]]],
           [[[2.7542609292], [2.7542609292]],
            [[2.7542609292], [2.7542609292]]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[[0.6529865603], [0.6529865603]],
            [[0.6529865603], [0.6529865603]]],
           [[[0.8736877431], [0.8736877431]],
            [[0.8736877431], [0.8736877431]]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
        cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
                                               kernel_shape=filter_size,
                                               output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state], {
            hidden[0].name:
                np.array([[[[1.], [1.]],
                           [[1.], [1.]]],
                          [[[2.], [2.]],
                           [[2.], [2.]]]]),
            x.name:
                np.array([[[[1.], [1.]],
                           [[1.], [1.]]],
                          [[[2.], [2.]],
                           [[2.], [2.]]]]),
        })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testConv3DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 2, 2, 1]
      filter_size = [3, 3, 3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[[[1.4375670191], [1.4375670191]],
             [[1.4375670191], [1.4375670191]]],
            [[[1.4375670191], [1.4375670191]],
             [[1.4375670191], [1.4375670191]]]],
           [[[[2.7542609292], [2.7542609292]],
             [[2.7542609292], [2.7542609292]]],
            [[[2.7542609292], [2.7542609292]],
             [[2.7542609292], [2.7542609292]]]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[[[0.6529865603], [0.6529865603]],
             [[0.6529865603], [0.6529865603]]],
            [[[0.6529865603], [0.6529865603]],
             [[0.6529865603], [0.6529865603]]]],
           [[[[0.8736877431], [0.8736877431]],
             [[0.8736877431], [0.8736877431]]],
            [[[0.8736877431], [0.8736877431]],
             [[0.8736877431], [0.8736877431]]]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
        cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
                                               kernel_shape=filter_size,
                                               output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state], {
            hidden[0].name:
                np.array([[[[[1.], [1.]],
                            [[1.], [1.]]],
                           [[[1.], [1.]],
                            [[1.], [1.]]]],
                          [[[[2.], [2.]],
                            [[2.], [2.]]],
                           [[[2.], [2.]],
                            [[2.], [2.]]]]]),
            x.name:
                np.array([[[[[1.], [1.]],
                            [[1.], [1.]]],
                           [[[1.], [1.]],
                            [[1.], [1.]]]],
                          [[[[2.], [2.]],
                            [[2.], [2.]]],
                           [[[2.], [2.]],
                            [[2.], [2.]]]]])
        })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testHighwayWrapper(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
@@ -26,6 +26,7 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
@@ -1921,6 +1922,181 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):

    return new_h, new_state


class ConvLSTMCell(rnn_cell_impl.RNNCell):
  """Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               use_bias=True,
               skip_connection=False,
               forget_bias=1.0,
               initializers=None,
               name="conv_lstm_cell"):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as int tuple, excluding the batch size.
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as in tuple (of size 1, 2 or 3).
      use_bias: Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      name: Name of the module.

    Raises:
      ValueError: If `skip_connection` is `True` and stride is different from 1
        or if `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(name=name)

    if conv_ndims != len(input_shape) - 1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection

    self._total_output_channels = output_channels
    if self._skip_connection:
      self._total_output_channels += self._input_shape[-1]

    state_size = tensor_shape.TensorShape(self._input_shape[:-1]
                                          + [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._total_output_channels])

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state, scope=None):
    cell, hidden = state
    new_hidden = _conv([inputs, hidden],
                       self._kernel_shape,
                       4 * self._output_channels,
                       self._use_bias)
    gates = array_ops.split(value=new_hidden,
                            num_or_size_splits=4,
                            axis=self._conv_ndims + 1)

    input_gate, new_input, forget_gate, output_gate = gates
    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
      output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state


class Conv1DLSTMCell(ConvLSTMCell):
  """1D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)


class Conv2DLSTMCell(ConvLSTMCell):
  """2D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)


class Conv3DLSTMCell(ConvLSTMCell):
  """3D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)


def _conv(args,
          filter_size,
          num_features,
          bias,
          bias_start=0.0):
  """Convolution.

  Args:
    args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
      batch x n, Tensors.
    filter_size: int tuple of filter height and width.
    num_features: int, number of features.
    bias: whether to add a bias term to the convolution output.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features].

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """

  # Calculate the total size of arguments on dimension 1.
  total_arg_size_depth = 0
  shapes = [a.get_shape().as_list() for a in args]
  shape_length = len(shapes[0])
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s"
                       % str(shapes))
    if len(shape) != len(shapes[0]):
      raise ValueError("Conv Linear expects all args to be of the same "
                       "dimension: %s" % str(shapes))
    else:
      total_arg_size_depth += shape[-1]
  dtype = [a.dtype for a in args][0]

  # Determine the correct conv operation.
  if shape_length == 3:
    conv_op = nn_ops.conv1d
    strides = 1
  elif shape_length == 4:
    conv_op = nn_ops.conv2d
    strides = shape_length * [1]
  elif shape_length == 5:
    conv_op = nn_ops.conv3d
    strides = shape_length * [1]

  # Now the computation.
  kernel = vs.get_variable(
      "kernel",
      filter_size + [total_arg_size_depth, num_features],
      dtype=dtype)
  if len(args) == 1:
    res = conv_op(args[0],
                  kernel,
                  strides,
                  padding='SAME')
  else:
    res = conv_op(array_ops.concat(axis=shape_length - 1, values=args),
                  kernel,
                  strides,
                  padding='SAME')
  if not bias:
    return res
  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(
          bias_start, dtype=dtype))
  return res + bias_term


class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).
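A hedged usage sketch for the new cells (assumes they are exported as `tf.contrib.rnn.Conv2DLSTMCell`, matching the `@@` doc entries above; shapes are illustrative):

```python
import tensorflow as tf

# One Conv2DLSTM layer over a [batch, time, H, W, C] sequence.
cell = tf.contrib.rnn.Conv2DLSTMCell(input_shape=[8, 8, 3],
                                     kernel_shape=[3, 3],
                                     output_channels=16)
inputs = tf.placeholder(tf.float32, [None, 10, 8, 8, 3])
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs: [batch, 10, 8, 8, 16]; state is an LSTMStateTuple of like shapes.
```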
@@ -78,7 +78,7 @@ class GatherTreeTest(test.TestCase):
sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time(
    [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/gpu:0"):
with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree(
    step_ids=step_ids, parent_ids=parent_ids,
    sequence_length=sequence_length)
@@ -979,9 +979,9 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments,
# alignments shape is
#   [batch_size, 1, memory_time]
# attention_mechanism.values shape is
#   [batch_size, memory_time, attention_mechanism.num_units]
#   [batch_size, memory_time, memory_size]
# the batched matmul is over memory_time, so the output shape is
#   [batch_size, 1, attention_mechanism.num_units].
#   [batch_size, 1, memory_size].
# we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])
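The corrected shapes, walked through with plain numpy (illustrative sizes; `memory_size` stands for the last dimension of `attention_mechanism.values`):

```python
import numpy as np

batch, memory_time, memory_size = 4, 7, 32
alignments = np.random.rand(batch, 1, memory_time)         # expanded alignments
values = np.random.rand(batch, memory_time, memory_size)   # attention values
context = np.matmul(alignments, values)   # batched matmul over memory_time
assert context.shape == (batch, 1, memory_size)
context = np.squeeze(context, axis=1)     # squeeze out the singleton dim
assert context.shape == (batch, memory_size)
```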
@@ -301,7 +301,12 @@ class Exporter(object):
if exports_to_keep:
# create a simple parser that pulls the export_version from the directory.
def parser(path):
match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
if os.name == 'nt':
match = re.match("^" + export_dir_base.replace('\\', '/') + "/(\\d{8})$",
                 path.path.replace('\\', '/'))
else:
match = re.match("^" + export_dir_base + "/(\\d{8})$",
                 path.path)
if not match:
return None
return path._replace(export_version=int(match.group(1)))
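Why the `os.name == 'nt'` branch matters, sketched with a hypothetical export path (on Windows, directory listings come back with backslashes, so the unmodified pattern never matched):

```python
import re

export_dir_base = 'C:/exports/model'       # hypothetical base directory
win_path = 'C:\\exports\\model\\20170814'  # what a Windows listing returns
match = re.match("^" + export_dir_base + "/(\\d{8})$",
                 win_path.replace('\\', '/'))
assert match and match.group(1) == '20170814'
```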
@@ -0,0 +1,187 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""SGDR learning rate decay function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops, control_flow_ops


def sgdr_decay(learning_rate, global_step, initial_period_steps,
               t_mul=2.0, m_mul=1.0, name=None):
  """Implements Stochastic Gradient Descent with Warm Restarts (SGDR).

  As described in "SGDR: Stochastic Gradient Descent
  with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of
  ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf

  The learning rate decreases according to cosine annealing:

  ```python
  learning_rate * 0.5 * (1 + cos(x_val * pi))  # for x_val defined in [0, 1]
  ```

  Thus, at the beginning (when the restart index i = 0),
  the learning rate decreases for `initial_period_steps` steps from the initial
  learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to
  0 (when `x_val=1`, we get `cos(pi)=-1`).

  The decrease within the i-th period takes `t_i` steps,
  where `t_0` = `initial_period_steps` is the user-defined number of batch
  iterations (not epochs as in the paper) to be performed before the first
  restart is launched.

  Then, we perform the first restart (i=1) by setting the learning rate to
  `learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default).
  The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new
  restart runs `t_mul` times longer than the previous one.

  Importantly, when one has no access to a validation set, SGDR suggests
  to report the best expected / recommended solution in the following way:
  When we are within our initial run (i=0), every new solution represents
  SGDR's recommended solution. Instead, when i>0, the recommended solution is
  the one obtained at the end of each restart.

  Note that the minimum learning rate is set to 0 for simplicity;
  you can adjust the code to deal with any positive minimum learning rate
  as defined in the paper.

  `initial_period_steps` is the duration of the first period measured in terms
  of the number of minibatch updates. If one wants to use epochs, one should
  compute the number of updates required for an epoch.

  For example, assume the following parameters and intention:
      Minibatch size: 100
      Training dataset size: 10000
      If the user wants the first decay period to span across 5 epochs, then
      `initial_period_steps` = 5 * 10000/100 = 500

      Train for 10000 batch iterations with the initial learning rate set to
      0.1, then restart to run 2 times longer, i.e., for 20000 batch iterations
      and with the initial learning rate 0.05, then restart again and again,
      doubling the runtime of each new period and with two times smaller
      initial learning rate.

  To accomplish the above, one would write:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = 0.1
  learning_rate = sgdr_decay(starter_learning_rate, global_step,
                             initial_period_steps=10000, t_mul=2, m_mul=0.5)
  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )

  # Step | 0     | 1000   | 5000 | 9000  | 9999 | 10000 | 11000  |
  # LR   | 0.1   | 0.097  | 0.05 | 0.002 | 0.00 | 0.05  | 0.0496 |

  # Step | 20000 | 29000  | 29999 | 30000 |
  # LR   | 0.025 | 0.0003 | 0.00  | 0.025 |
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation. Must not be negative.
    initial_period_steps: Duration of the first period measured as the number
      of minibatch updates; if one wants to use epochs, one should compute
      the number of updates required for an epoch.
    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Must be positive.
      Used to derive the number of iterations in the i-th period:
      `initial_period_steps * (t_mul^i)`. Defaults to 2.0.
    m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Must be positive.
      Used to derive the initial learning rate of the i-th period:
      `learning_rate * (m_mul^i)`. Defaults to 1.0.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`.
    The learning rate for a provided global_step.
  Raises:
    ValueError: if `global_step` is not supplied.
  """

  if global_step is None:
    raise ValueError("global_step is required for sgdr_decay.")
  with ops.name_scope(name, "SGDRDecay",
                      [learning_rate, global_step,
                       initial_period_steps, t_mul, m_mul]) as name:
    learning_rate = ops.convert_to_tensor(learning_rate,
                                          name="initial_learning_rate")
    dtype = learning_rate.dtype
    global_step = math_ops.cast(global_step, dtype)
    t_0 = math_ops.cast(initial_period_steps, dtype)
    t_mul = math_ops.cast(t_mul, dtype)
    m_mul = math_ops.cast(m_mul, dtype)

    c_one = math_ops.cast(constant_op.constant(1.0), dtype)
    c_half = math_ops.cast(constant_op.constant(0.5), dtype)
    c_pi = math_ops.cast(constant_op.constant(math.pi), dtype)

    # Find the normalized value of the current step.
    x_val = math_ops.div(global_step, t_0)

    def compute_step(x_val, geometric=False):
      if geometric:
        # Consider the geometric series where t_mul != 1:
        # 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul)

        # First find how many restarts were performed for a given x_val.
        # Find the maximal integer i_restart for which this inequality holds:
        # x_val >= (1 - t_mul^i_restart) / (1 - t_mul)
        # x_val * (1 - t_mul) <= (1 - t_mul^i_restart)
        # t_mul^i_restart <= (1 - x_val * (1 - t_mul))

        # TensorFlow only provides the natural logarithm, so:
        # i_restart <= log(1 - x_val * (1 - t_mul)) / log(t_mul)
        # Find how many restarts were performed.

        i_restart = math_ops.floor(
            math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul))
        # Compute the sum of all restarts before the current one.
        sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul)
        # Compute our position within the current restart.
        x_val = (x_val - sum_r) / t_mul ** i_restart

      else:
        # Find how many restarts were performed.
        i_restart = math_ops.floor(x_val)
        # Compute our position within the current restart.
        x_val = x_val - i_restart
      return i_restart, x_val

    i_restart, x_val = control_flow_ops.cond(
        math_ops.equal(t_mul, c_one),
        lambda: compute_step(x_val, geometric=False),
        lambda: compute_step(x_val, geometric=True))

    # If m_mul < 1, then the initial learning rate of every new restart will be
    # smaller, i.e., by a factor of m_mul ** i_restart at the i_restart-th
    # restart.
    m_fac = learning_rate * (m_mul ** i_restart)

    return math_ops.multiply(c_half * m_fac,
                             (math_ops.cos(x_val * c_pi) + c_one), name=name)
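The same schedule in plain Python, as a sketch of the math the graph ops implement (the `t_mul != 1` branch inverts the geometric series described in the comments above):

```python
import math

def sgdr_lr(lr, step, t_0, t_mul=2.0, m_mul=1.0):
  """Closed-form SGDR learning rate at a given step (illustration only)."""
  x = step / float(t_0)
  if t_mul == 1.0:
    i_restart = math.floor(x)
    x -= i_restart
  else:
    # Invert 1 + t_mul + ... + t_mul^(i-1) = (1 - t_mul^i) / (1 - t_mul).
    i_restart = math.floor(
        math.log(1 - x * (1 - t_mul)) / math.log(t_mul))
    sum_r = (1 - t_mul ** i_restart) / (1 - t_mul)
    x = (x - sum_r) / t_mul ** i_restart
  return 0.5 * lr * (m_mul ** i_restart) * (1 + math.cos(x * math.pi))

print(sgdr_lr(0.1, 0, 10000))      # 0.1 at step 0
print(sgdr_lr(0.1, 5000, 10000))   # 0.05, halfway through the first period
print(sgdr_lr(0.1, 10000, 10000))  # 0.1 again: first restart (m_mul=1.0)
```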
@@ -0,0 +1,145 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functional test for sgdr learning rate decay."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from sgdr_learning_rate_decay import sgdr_decay
from tensorflow.python.platform import googletest
from tensorflow.python.framework import test_util
from tensorflow.python.framework import dtypes
from tensorflow import placeholder


class SGDRDecayTest(test_util.TensorFlowTestCase):
  """Unit tests for SGDR learning rate decay."""

  def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
    """Get an array of learning rate values for consecutive steps using
    the original implementation
    (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
    t0 = math.pi / 2.0
    tt = 0
    te_next = t_e

    lr_values = []
    sh_lr = lr
    for epoch in range(epochs):
      for _ in range(iter_per_epoch):
        # In the original approach the training function is executed here.
        lr_values.append(sh_lr)
        dt = 2.0 * math.pi / float(2.0 * t_e)
        tt = tt + float(dt) / iter_per_epoch
        if tt >= math.pi:
          tt = tt - math.pi
        cur_t = t0 + tt
        new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0  # lr_min = 0, lr_max = lr
        sh_lr = new_lr
      if (epoch + 1) == te_next:  # time to restart
        sh_lr = lr
        tt = 0  # by setting to 0 we set lr to lr_max, see above
        t_e = t_e * mult_factor  # change the period of restarts
        te_next = te_next + t_e  # note the next restart's epoch

    return lr_values

  def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
    """Get an array of learning rate values for consecutive steps using
    the current tensorflow implementation."""
    with self.test_session():
      step = placeholder(dtypes.int32)

      decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
      lr_values = []
      for i in range(iters):
        lr_values.append(decay.eval(feed_dict={step: i}))

      return lr_values

  def testCompareToOriginal(self):
    """Compare values generated by the tensorflow implementation to the
    values generated by the original implementation
    (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
    with self.test_session():
      lr = 10.0
      init_steps = 2
      t_mul = 3
      iters = 10
      epochs = 50

      org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs)
      sgdr_lr = self.get_sgdr_values(lr, init_steps * iters, t_mul,
                                     iters * epochs)

      for org, sgdr in zip(org_lr, sgdr_lr):
        self.assertAllClose(org, sgdr)

  def testMDecay(self):
    """Test the m_mul argument. Check learning rate values at the beginning
    of the first, second, third and fourth periods."""
    with self.test_session():
      step = placeholder(dtypes.int32)

      lr = 0.1
      t_e = 10
      t_mul = 3
      m_mul = 0.9

      decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)

      test_step = 0
      self.assertAllClose(decay.eval(feed_dict={step: test_step}),
                          lr)

      test_step = t_e
      self.assertAllClose(decay.eval(feed_dict={step: test_step}),
                          lr * m_mul)

      test_step = t_e + t_e * t_mul
      self.assertAllClose(decay.eval(feed_dict={step: test_step}),
                          lr * m_mul**2)

      test_step = t_e + t_e * t_mul + t_e * (t_mul**2)
      self.assertAllClose(decay.eval(feed_dict={step: test_step}),
                          lr * (m_mul**3))

  def testCos(self):
    """Check learning rate values at the beginning, in the middle
    and at the end of the period."""
    with self.test_session():
      step = placeholder(dtypes.int32)
      lr = 0.2
      t_e = 1000
      t_mul = 1

      decay = sgdr_decay(lr, step, t_e, t_mul)

      test_step = 0
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)

      test_step = t_e // 2
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr / 2)

      test_step = t_e
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)

      test_step = t_e * 3 // 2
      self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr / 2)


if __name__ == "__main__":
  googletest.main()
@@ -17,6 +17,8 @@ limitations under the License.

#include "tensorflow/contrib/verbs/verbs_server_lib.h"

#include "grpc/support/alloc.h"

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
@@ -116,6 +116,7 @@ load(
    "tf_lib_proto_parsing_deps",
    "tf_additional_verbs_lib_defines",
    "tf_additional_mpi_lib_defines",
    "tf_additional_gdr_lib_defines",
    "tf_additional_gpu_tracer_srcs",
    "tf_additional_gpu_tracer_deps",
    "tf_additional_gpu_tracer_cuda_deps",
@@ -1245,72 +1246,36 @@ tf_proto_library_cc(
],
)

LIB_INTERNAL_WINDOWS_DEPS = glob(
[
"lib/**/*.h",
"lib/**/*.cc",
"platform/*.h",
"platform/*.cc",
"platform/profile_utils/**/*.h",
"platform/profile_utils/**/*.cc",
] + [
"framework/resource_handle.h",
"framework/resource_handle.cc",
"framework/variant_tensor_data.h",
"framework/variant_tensor_data.cc",
],
exclude = [
"**/*test*",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/env_time.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/load_library.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
)

cc_library(
name = "lib_internal",
srcs = select({
"//tensorflow:windows": LIB_INTERNAL_WINDOWS_DEPS,
"//tensorflow:windows_msvc": LIB_INTERNAL_WINDOWS_DEPS,
"//conditions:default": glob(
[
"lib/**/*.h",
"lib/**/*.cc",
"platform/*.h",
"platform/*.cc",
"platform/profile_utils/**/*.h",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.h",
"framework/resource_handle.cc",
],
exclude = [
"**/*test*",
"framework/variant.cc",
"platform/variant_coding.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/env_time.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/**/gpu_tracer.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
),
}) + tf_additional_lib_srcs(
srcs = glob(
[
"lib/**/*.h",
"lib/**/*.cc",
"platform/*.h",
"platform/*.cc",
"platform/profile_utils/**/*.h",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.h",
"framework/resource_handle.cc",
],
exclude = [
"**/*test*",
"framework/variant.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/env_time.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/**/gpu_tracer.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
) + tf_additional_lib_srcs(
exclude = [
"**/*test*",
"platform/**/cuda.h",
@@ -1370,9 +1335,12 @@ cc_library(
defines = tf_additional_lib_defines() + [
    "SNAPPY",
] + tf_additional_verbs_lib_defines() +
tf_additional_mpi_lib_defines(),
tf_additional_mpi_lib_defines() +
tf_additional_gdr_lib_defines(),
linkopts = select({
    "//tensorflow:freebsd": [],
    "//tensorflow:windows": [],
    "//tensorflow:windows_msvc": [],
    "//conditions:default": [
        "-ldl",
        "-lpthread",
@@ -1407,6 +1375,8 @@ cc_library(
copts = tf_copts(),
linkopts = select({
    "//tensorflow:freebsd": [],
    "//tensorflow:windows": [],
    "//tensorflow:windows_msvc": [],
    "//conditions:default": ["-ldl"],
}),
deps = [

@@ -1430,6 +1400,8 @@ cc_library(
copts = tf_copts(),
linkopts = select({
    "//tensorflow:freebsd": [],
    "//tensorflow:windows": [],
    "//tensorflow:windows_msvc": [],
    "//conditions:default": ["-ldl"],
}),
deps = [

@@ -1605,6 +1577,8 @@ tf_cuda_library(
copts = tf_copts(),
linkopts = select({
    "//tensorflow:freebsd": [],
    "//tensorflow:windows": [],
    "//tensorflow:windows_msvc": [],
    "//conditions:default": ["-ldl"],
}) + [
    "-lm",
@@ -22,7 +22,7 @@ limitations under the License.
// Device names
// * Every Device should have a unique name with the format:
//   /job:___/replica:___/task:___/(gpu|cpu):___
// An example name would be "/job:train/replica:0/task:3/gpu:2".
// An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas.
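The renamed canonical form in user code, for reference (a sketch; the short "/gpu:0" spelling remained accepted as an alias at the time):

```python
import tensorflow as tf

a = tf.random_normal([4, 4])
b = tf.random_normal([4, 4])
# Short single-process form of /job:___/replica:___/task:___/device:GPU:0.
with tf.device("/device:GPU:0"):
  c = tf.matmul(a, b)
```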
@@ -471,7 +471,7 @@ Status DirectSession::Run(const RunOptions& run_options,
args.step_id = step_id_counter_.fetch_add(1);

TF_RETURN_IF_ERROR(
    GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes,
    GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
                         &executors_and_keys, &run_state_args));
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);

@@ -711,7 +711,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
DebugOptions debug_options;
RunStateArgs run_state_args(debug_options);
run_state_args.is_partial_run = true;
TF_RETURN_IF_ERROR(GetOrCreateExecutors(pool, input_names, output_names,
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                        target_nodes, &executors_and_keys,
                                        &run_state_args));

@@ -1042,9 +1042,9 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
}

Status DirectSession::GetOrCreateExecutors(
    thread::ThreadPool* pool, gtl::ArraySlice<string> inputs,
    gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
    ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);

@@ -194,8 +194,8 @@ class DirectSession : public Session {
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
::tensorflow::Status GetOrCreateExecutors(
    thread::ThreadPool* pool, gtl::ArraySlice<string> inputs,
    gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes,
    ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);

// Creates several graphs given the existing graph_def_ and the
@@ -476,7 +476,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
vx.scalar<float>()() = 1.0;
Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def;
test::graph::ToGraphDef(&g, &def);

@@ -494,7 +494,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
vx.scalar<float>()() = 1.0;
Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def;
test::graph::ToGraphDef(&g, &def);
@@ -154,14 +154,14 @@ static void TestHWAccelerator(bool enableHWTrace) {
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {1, 1});
Node* x = test::graph::Constant(&graph, x_tensor);
x->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL
x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif  // TENSORFLOW_USE_SYCL

// y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif  // TENSORFLOW_USE_SYCL
@@ -114,14 +114,14 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
      << num_bytes << ". See error logs for more detailed info.";
  }
}
if (LogMemory::IsEnabled()) {
if (LogMemory::IsEnabled() && ret != nullptr) {
  LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
                                 allocator_);
}
return ret;
}
void deallocate(void* buffer) const override {
if (LogMemory::IsEnabled()) {
if (LogMemory::IsEnabled() && buffer != nullptr) {
  LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
                                   true);
}

@@ -588,7 +588,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
for (int i = 0; i < n; i++) {
  BaseGPUDevice* gpu_device;
  TF_RETURN_IF_ERROR(CreateGPUDevice(options,
                                     strings::StrCat(name_prefix, "/gpu:", i),
                                     strings::StrCat(name_prefix, "/device:GPU:", i),
                                     valid_gpu_ids[i], &gpu_device));
  TF_RETURN_IF_ERROR(gpu_device->Init(options));
  devices->push_back(gpu_device);

@@ -1049,7 +1049,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
size_t new_id = ids->size();
ids->push_back(visible_gpu_id);

LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
LOG(INFO) << "Creating TensorFlow device (/device:GPU:" << new_id << ") -> "
          << "(" << GetShortDeviceDescription(visible_gpu_id, desc) << ")";
}
@@ -141,7 +141,7 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Allocator* cpu_allocator) = 0;

// Returns into 'ids' the list of valid GPU ids, in the order that
// they should map to logical gpu ids "/gpu:0", "/gpu:1", etc, based
// they should map to logical gpu ids "/device:GPU:0", "/device:GPU:1", etc, based
// upon 'visible_device_list', a comma-separated list of 'visible
// gpu ids'.
Status GetValidDeviceIds(const string& visible_device_list,
@@ -106,9 +106,9 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
TEST_F(GpuStreamUtilTest, StreamOverrides) {
auto root = Scope::NewRootScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
           "/gpu:0");
           "/device:GPU:0");
Output n = ops::MatMul(root, {}, {});
ops::_Send(root.WithOpName("output"), n, "output", "/gpu:0", 0, "/cpu:0");
ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0");
Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g));
@@ -167,7 +167,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
if (!status.ok()) {
  LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
Allocator* allocator;
VisitableAllocator* allocator;
if (use_bfc_allocator) {
  // TODO(reedwm): evaluate whether 64GB by default is the best choice.
  int64 cpu_mem_limit_in_mb = -1;

@@ -192,7 +192,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
if (LogMemory::IsEnabled()) {
  // Wrap the allocator to track allocation ids for better logging
  // at the cost of performance.
  allocator = new TrackingAllocator(allocator, true);
  allocator = new TrackingVisitableAllocator(allocator, true);
}
cpu_allocators_.push_back(allocator);
}

@@ -237,14 +237,14 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
  LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
}
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
Allocator* allocator =
VisitableAllocator* allocator =
    new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
                     true /*allow_growth*/, "cuda_host_bfc" /*name*/);

if (LogMemory::IsEnabled()) {
  // Wrap the allocator to track allocation ids for better logging
  // at the cost of performance.
  allocator = new TrackingAllocator(allocator, true);
  allocator = new TrackingVisitableAllocator(allocator, true);
}
cuda_host_allocators_.push_back(allocator);
if (FLAGS_brain_gpu_record_mem_types) {
@@ -53,7 +53,7 @@ TEST(MemoryTypeChecker, Int32NotOk) {
EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_GPU, g)));

// But we can insert _HostSend/_HostRecv to ensure the invariant.
TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/gpu:0", g));
TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/device:GPU:0", g));
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
@@ -682,7 +682,7 @@ Status SimplePlacer::Run() {
int dst_root_id = colocation_graph.FindRoot(dst->id());
auto& src_root = colocation_graph.members_[src_root_id];
auto& dst_root = colocation_graph.members_[dst_root_id];
// If both the source node and this node have paritally
// If both the source node and this node have partially
// specified a device, then 'node's device should be
// cleared: the reference edge forces 'node' to be on the
// same device as the source node.
@ -19,7 +19,7 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
SYCLAllocator::SYCLAllocator(Eigen::QueueInterface *queue)
|
||||
SYCLAllocator::SYCLAllocator(Eigen::QueueInterface* queue)
|
||||
: sycl_device_(new Eigen::SyclDevice(queue)) {
|
||||
cl::sycl::queue& sycl_queue = sycl_device_->sycl_queue();
|
||||
const cl::sycl::device& device = sycl_queue.get_device();
|
||||
|
|
@ -28,14 +28,15 @@ SYCLAllocator::SYCLAllocator(Eigen::QueueInterface *queue)
|
|||
}
|
||||
|
||||
SYCLAllocator::~SYCLAllocator() {
|
||||
if(sycl_device_) {
|
||||
if (sycl_device_) {
|
||||
delete sycl_device_;
|
||||
}
|
||||
}
|
||||
|
||||
string SYCLAllocator::Name() { return "device:SYCL"; }
|
||||
|
||||
void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
|
||||
void* SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
|
||||
mutex_lock lock(mu_);
|
||||
assert(sycl_device_);
|
||||
if (num_bytes == 0) {
|
||||
// Cannot allocate no bytes in SYCL, so instead allocate a single byte
|
||||
|
|
@@ -45,7 +46,6 @@ void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  const auto& allocated_buffer = sycl_device_->get_sycl_buffer(p);
  const std::size_t bytes_allocated = allocated_buffer.get_range().size();

-  mutex_lock lock(mu_);
  ++stats_.num_allocs;
  stats_.bytes_in_use += bytes_allocated;
  stats_.max_bytes_in_use =
@@ -56,12 +56,12 @@ void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  return p;
}

-void SYCLAllocator::DeallocateRaw(void *ptr) {
-  const auto& buffer_to_delete = sycl_device_->get_sycl_buffer(ptr);
-  const std::size_t dealloc_size = buffer_to_delete.get_range().size();
+void SYCLAllocator::DeallocateRaw(void* ptr) {
  mutex_lock lock(mu_);
-  stats_.bytes_in_use -= dealloc_size;
-  sycl_device_->deallocate(ptr);
+  if (sycl_device_) {
+    const auto& buffer_to_delete = sycl_device_->get_sycl_buffer(ptr);
+    const std::size_t dealloc_size = buffer_to_delete.get_range().size();
+    stats_.bytes_in_use -= dealloc_size;
+    sycl_device_->deallocate(ptr);
+  }
}
@@ -72,6 +72,10 @@ void SYCLAllocator::GetStats(AllocatorStats* stats) {
}

size_t SYCLAllocator::RequestedSize(void* ptr) {
+  mutex_lock lock(mu_);
+  if(!sycl_device_) {
+    return 0;
+  }
  const auto& buffer = sycl_device_->get_sycl_buffer(ptr);
  return buffer.get_size();
}
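The three SYCLAllocator hunks above converge on one pattern: take mu_ at the top of every entry point, then tolerate a sycl_device_ that ClearSYCLDevice() has already deleted. A generic sketch of that guard with stand-in types (the int* member stands in for Eigen::SyclDevice*):

#include <cstddef>
#include <mutex>

class GuardedDevice {
 public:
  std::size_t RequestedSize(void* /*ptr*/) {
    std::lock_guard<std::mutex> lock(mu_);
    if (!device_) return 0;  // device already cleared: degrade gracefully
    return 64;               // placeholder for the real buffer-size query
  }
  void Clear() {  // analogous to ClearSYCLDevice()
    std::lock_guard<std::mutex> lock(mu_);
    delete device_;
    device_ = nullptr;
  }

 private:
  std::mutex mu_;
  int* device_ = nullptr;  // stand-in for Eigen::SyclDevice* (owned)
};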
@@ -29,15 +29,20 @@ namespace tensorflow {

class SYCLAllocator : public Allocator {
 public:
-  SYCLAllocator(Eigen::QueueInterface *queue);
+  SYCLAllocator(Eigen::QueueInterface* queue);
  virtual ~SYCLAllocator() override;
  string Name() override;
-  void *AllocateRaw(size_t alignment, size_t num_bytes) override;
-  void DeallocateRaw(void *ptr) override;
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+  void DeallocateRaw(void* ptr) override;

  virtual bool ShouldAllocateEmptyTensors() override final { return true; }
-  void Synchronize() { sycl_device_->synchronize(); }
-  bool Ok() { return sycl_device_->ok(); }
+  void Synchronize() {
+    mutex_lock lock(mu_);
+    if (sycl_device_) {
+      sycl_device_->synchronize();
+    }
+  }
+  bool Ok() { return sycl_device_ && sycl_device_->ok(); }
  void GetStats(AllocatorStats* stats) override;
  // The SYCL buffers keep track of their size, so we already have tracking.
  bool TracksAllocationSizes() override { return true; }
@@ -46,10 +51,19 @@ class SYCLAllocator : public Allocator {
  // AllocatedSize(void* ptr) by default.
  size_t RequestedSize(void* ptr) override;
  Eigen::SyclDevice* getSyclDevice() { return sycl_device_; }
+  // Clear the SYCL device used by the Allocator
+  void ClearSYCLDevice() {
+    mutex_lock lock(mu_);
+    if(sycl_device_) {
+      delete sycl_device_;
+      sycl_device_ = nullptr;
+    }
+  }

 private:
-  Eigen::SyclDevice *sycl_device_;  // owned
-
  mutable mutex mu_;
+  Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_);  // owned
  AllocatorStats stats_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator);
@@ -22,20 +22,10 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {
-std::mutex GSYCLInterface::mutex_;
-GSYCLInterface *GSYCLInterface::s_instance = 0;
-
-void ShutdownSycl() {
-  GSYCLInterface::Reset();
-}
-
-void SYCLDevice::RegisterDevice() {
-  atexit(ShutdownSycl);
-}
-
SYCLDevice::~SYCLDevice() {}

-void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
+void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  assert(context);
  if (port::Tracing::IsActive()) {
    // TODO(pbar) We really need a useful identifier of the graph node.
@@ -46,16 +36,16 @@ void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
  op_kernel->Compute(context);
}

-Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) {
+Allocator* SYCLDevice::GetAllocator(AllocatorAttributes attr) {
  if (attr.on_host())
    return cpu_allocator_;
  else
    return sycl_allocator_;
}

-Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
+Status SYCLDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
-                                       Tensor *tensor) {
+                                       Tensor* tensor) {
  AllocatorAttributes attr;
  attr.set_on_host(true);
  Allocator* host_alloc = GetAllocator(attr);
@@ -79,18 +69,18 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
    }

    device_context_->CopyCPUTensorToDevice(
-        &parsed, this, &copy, [&status](const Status &s) { status = s; });
+        &parsed, this, &copy, [&status](const Status& s) { status = s; });
    *tensor = copy;
  }
  return status;
}

-Status SYCLDevice::FillContextMap(const Graph *graph,
-                                  DeviceContextMap *device_context_map) {
+Status SYCLDevice::FillContextMap(const Graph* graph,
+                                  DeviceContextMap* device_context_map) {
  // Fill in the context map. It is OK for this map to contain
  // duplicate DeviceContexts so long as we increment the refcount.
  device_context_map->resize(graph->num_node_ids());
-  for (Node *n : graph->nodes()) {
+  for (Node* n : graph->nodes()) {
    device_context_->Ref();
    (*device_context_map)[n->id()] = device_context_;
  }
@@ -27,201 +27,190 @@ limitations under the License.

namespace tensorflow {

-class GSYCLInterface
-{
+class GSYCLInterface {
  std::vector<Eigen::QueueInterface*> m_queue_interface_;  // owned
  std::vector<Allocator*> m_cpu_allocator_;                // not owned
  std::vector<SYCLAllocator*> m_sycl_allocator_;           // owned
-  std::vector<SYCLDeviceContext*> m_sycl_context_;         // owned
-
-  static std::mutex mutex_;
-  static GSYCLInterface* s_instance;
+  std::vector<SYCLDeviceContext*> m_sycl_context_;         // ref counted
  GSYCLInterface() {
-    bool found_device =false;
+    bool found_device = false;
    auto device_list = Eigen::get_sycl_supported_devices();
    // Obtain list of supported devices from Eigen
    for (const auto& device : device_list) {
-      if(device.is_gpu()) {
+      if (device.is_gpu()) {
        // returns first found GPU
        AddDevice(device);
        found_device = true;
      }
    }

-    if(!found_device) {
+    if (!found_device) {
      // Currently Intel GPU is not supported
-      LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, trying OpenCL CPU";
+      LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, "
+                      "trying OpenCL CPU";
    }

    for (const auto& device : device_list) {
-      if(device.is_cpu()) {
+      if (device.is_cpu()) {
        // returns first found CPU
        AddDevice(device);
        found_device = true;
      }
    }

-    if(!found_device) {
+    if (!found_device) {
      // Currently Intel GPU is not supported
-      LOG(FATAL) << "No OpenCL GPU nor CPU found that is supported by ComputeCpp";
+      LOG(FATAL)
+          << "No OpenCL GPU nor CPU found that is supported by ComputeCpp";
    } else {
      LOG(INFO) << "Found following OpenCL devices:";
      for (int i = 0; i < device_list.size(); i++) {
        LOG(INFO) << GetShortDeviceDescription(i);
      }
    }
  }

  ~GSYCLInterface() {
    m_cpu_allocator_.clear();

    for (auto p : m_sycl_allocator_) {
      p->Synchronize();
-      delete p;
+      p->ClearSYCLDevice();
+      // Cannot delete the Allocator instances, as the Allocator lifetime
+      // needs to exceed any Tensor created by it. There is no way of
+      // knowing when all Tensors have been deallocated, as they are
+      // RefCounted and wait until all instances of a Tensor have been
+      // destroyed before calling Allocator.Deallocate. This could happen at
+      // program exit, which can set up a race condition between destroying
+      // Tensors and Allocators when the program is cleaning up.
    }
    m_sycl_allocator_.clear();

-    for(auto p : m_sycl_context_) {
+    for (auto p : m_sycl_context_) {
      p->Unref();
    }
    m_sycl_context_.clear();

    for (auto p : m_queue_interface_) {
      p->deallocate_all();
      delete p;
-      p = nullptr;
    }
    m_queue_interface_.clear();
  }

-  void AddDevice(const cl::sycl::device & d) {
+  void AddDevice(const cl::sycl::device& d) {
    m_queue_interface_.push_back(new Eigen::QueueInterface(d));
    m_cpu_allocator_.push_back(cpu_allocator());
    m_sycl_allocator_.push_back(new SYCLAllocator(m_queue_interface_.back()));
    m_sycl_context_.push_back(new SYCLDeviceContext());
  }

 public:
-  static GSYCLInterface *instance()
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (!s_instance) {
-      s_instance = new GSYCLInterface();
-    }
-    return s_instance;
-  }
-
-  static void Reset()
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if(s_instance) {
-      delete s_instance;
-      s_instance = NULL;
-    }
-  }
+  static const GSYCLInterface* instance() {
+    // c++11 guarantees that this will be constructed in a thread safe way
+    static const GSYCLInterface instance;
+    return &instance;
+  }

-  Eigen::QueueInterface * GetQueueInterface(size_t i = 0) {
-    if(!m_queue_interface_.empty()) {
+  Eigen::QueueInterface* GetQueueInterface(size_t i = 0) const {
+    if (!m_queue_interface_.empty()) {
      return m_queue_interface_[i];
    } else {
      std::cerr << "No cl::sycl::device has been added" << std::endl;
      return nullptr;
    }
  }

-  SYCLAllocator * GetSYCLAllocator(size_t i = 0) {
-    if(!m_sycl_allocator_.empty()) {
+  SYCLAllocator* GetSYCLAllocator(size_t i = 0) const {
+    if (!m_sycl_allocator_.empty()) {
      return m_sycl_allocator_[i];
    } else {
      std::cerr << "No cl::sycl::device has been added" << std::endl;
      return nullptr;
    }
  }

-  Allocator * GetCPUAllocator(size_t i = 0) {
-    if(!m_cpu_allocator_.empty()) {
+  Allocator* GetCPUAllocator(size_t i = 0) const {
+    if (!m_cpu_allocator_.empty()) {
      return m_cpu_allocator_[i];
    } else {
      std::cerr << "No cl::sycl::device has been added" << std::endl;
      return nullptr;
    }
  }

-  SYCLDeviceContext * GetSYCLContext(size_t i = 0) {
-    if(!m_sycl_context_.empty()) {
+  SYCLDeviceContext* GetSYCLContext(size_t i = 0) const {
+    if (!m_sycl_context_.empty()) {
      return m_sycl_context_[i];
    } else {
      std::cerr << "No cl::sycl::device has been added" << std::endl;
      return nullptr;
    }
  }

-  string GetShortDeviceDescription(int device_id = 0) {
-    auto _device = GetSYCLAllocator(device_id)
-                       ->getSyclDevice()
-                       ->sycl_queue()
-                       .get_device();
-    auto _name = _device.get_info<cl::sycl::info::device::name>();
-    auto _vendor = _device.get_info<cl::sycl::info::device::vendor>();
-    auto _profile = _device.get_info<cl::sycl::info::device::profile>();
-
-    std::string _type;
-    if (_device.is_host()) {
-      _type = "Host";
-    } else if (_device.is_cpu()) {
-      _type = "CPU";
-    } else if (_device.is_gpu()) {
-      _type = "GPU";
-    } else if (_device.is_accelerator()) {
-      _type = "Accelerator";
+  string GetShortDeviceDescription(int device_id = 0) const {
+    Eigen::QueueInterface* queue_ptr = GetQueueInterface(device_id);
+    if (!queue_ptr) {
+      LOG(ERROR)
+          << "Device name cannot be given after Eigen QueueInterface destroyed";
+      return "";
+    }
+    auto device = queue_ptr->sycl_queue().get_device();
+    auto name = device.get_info<cl::sycl::info::device::name>();
+    auto vendor = device.get_info<cl::sycl::info::device::vendor>();
+    auto profile = device.get_info<cl::sycl::info::device::profile>();
+
+    std::string type;
+    if (device.is_host()) {
+      type = "Host";
+    } else if (device.is_cpu()) {
+      type = "CPU";
+    } else if (device.is_gpu()) {
+      type = "GPU";
+    } else if (device.is_accelerator()) {
+      type = "Accelerator";
    } else {
-      _type = "Unknown";
+      type = "Unknown";
    }

-    return strings::StrCat("id: ", device_id, " ,type: ", _type, " ,name: ",
-                           _name.c_str(), " ,vendor: ", _vendor.c_str(),
-                           " ,profile: ", _profile.c_str());
+    return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ",
+                           name.c_str(), ", vendor: ", vendor.c_str(),
+                           ", profile: ", profile.c_str());
  }
};

class SYCLDevice : public LocalDevice {
 public:
-  SYCLDevice(const SessionOptions &options, const string &name,
-             Bytes memory_limit, const DeviceLocality &locality,
-             const string &physical_device_desc, SYCLAllocator * sycl_allocator,
-             Allocator *cpu_allocator, SYCLDeviceContext* ctx)
-      : LocalDevice(
-            options,
-            Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit,
-                                          locality, physical_device_desc)),
+  SYCLDevice(const SessionOptions& options, const string& name,
+             Bytes memory_limit, const DeviceLocality& locality,
+             const string& physical_device_desc, SYCLAllocator* sycl_allocator,
+             Allocator* cpu_allocator, SYCLDeviceContext* ctx)
+      : LocalDevice(options, Device::BuildDeviceAttributes(
+                                 name, DEVICE_SYCL, memory_limit, locality,
+                                 physical_device_desc)),
        cpu_allocator_(cpu_allocator),
        sycl_allocator_(sycl_allocator),
        device_context_(ctx) {
-    RegisterDevice();
    set_eigen_sycl_device(sycl_allocator->getSyclDevice());
  }

  ~SYCLDevice() override;

-  void Compute(OpKernel *op_kernel, OpKernelContext *context) override;
-  Allocator *GetAllocator(AllocatorAttributes attr) override;
-  Status MakeTensorFromProto(const TensorProto &tensor_proto,
+  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+  Allocator* GetAllocator(AllocatorAttributes attr) override;
+  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
-                             Tensor *tensor) override;
+                             Tensor* tensor) override;

-  Status FillContextMap(const Graph *graph,
-                        DeviceContextMap *device_context_map) override;
+  Status FillContextMap(const Graph* graph,
+                        DeviceContextMap* device_context_map) override;

  Status Sync() override;

 private:
-  void RegisterDevice();
-
-  Allocator *cpu_allocator_;           // not owned
-  SYCLAllocator *sycl_allocator_;      // not owned
-  SYCLDeviceContext *device_context_;
+  Allocator* cpu_allocator_;           // not owned
+  SYCLAllocator* sycl_allocator_;      // not owned
+  SYCLDeviceContext* device_context_;  // not owned
};

} // namespace tensorflow
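The header above drops the hand-rolled singleton (static mutex_, s_instance, Reset(), and the atexit(ShutdownSycl) hook removed in sycl_device.cc) in favor of a function-local static, whose initialization C++11 guarantees to be thread safe and whose destructor runs automatically at program exit. A minimal sketch of that pattern; Registry is a stand-in for GSYCLInterface:

class Registry {
 public:
  static const Registry* instance() {
    // C++11 guarantees thread-safe initialization of a function-local
    // static; its destructor runs at exit, replacing Reset()/atexit().
    static const Registry instance;
    return &instance;
  }

 private:
  Registry() {}   // constructed on first use
  ~Registry() {}  // torn down automatically at program exit
};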
@@ -21,17 +21,60 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_

#include "tensorflow/core/common_runtime/device.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
-inline void* GetBase(const Tensor* src) {
-  return const_cast<void*>(DMAHelper::base(src));
-}
-
-inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
-}
+inline void const* GetBase(const Tensor* src) { return DMAHelper::base(src); }
+inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
+
+inline void SYCLmemcpy(Eigen::SyclDevice const& device,
+                       Tensor const& src_tensor, Tensor* dst_tensor) {
+  const size_t size = src_tensor.TotalBytes();
+  void* dst_ptr = GetBase(dst_tensor);
+  void const* src_ptr = GetBase(&src_tensor);
+
+#define COPY_WITH_TYPE(T) \
+  device.memcpy(dst_ptr, static_cast<T const*>(src_ptr), size);
+  switch (src_tensor.dtype()) {
+    case DT_COMPLEX128:
+      COPY_WITH_TYPE(cl::sycl::cl_ulong2);
+      break;
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_INT64:
+      COPY_WITH_TYPE(cl::sycl::cl_ulong);
+      break;
+    case DT_FLOAT:
+    case DT_INT32:
+    case DT_QINT32:
+      COPY_WITH_TYPE(cl::sycl::cl_uint);
+      break;
+    case DT_INT16:
+    case DT_UINT16:
+    case DT_BFLOAT16:
+    case DT_QINT16:
+    case DT_QUINT16:
+    case DT_HALF:
+      COPY_WITH_TYPE(cl::sycl::cl_ushort);
+      break;
+    case DT_BOOL:
+      COPY_WITH_TYPE(bool);
+      break;
+    case DT_UINT8:
+    case DT_INT8:
+    case DT_QINT8:
+    case DT_QUINT8:
+      COPY_WITH_TYPE(cl::sycl::cl_uchar);
+      break;
+    default:
+      LOG(FATAL) << "Unknown data type " << src_tensor.dtype();
+      break;
+  }
+#undef COPY_WITH_TYPE
+}
+
+}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
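SYCLmemcpy above dispatches on dtype only to pick an OpenCL scalar type of the same byte width, so every case is the same raw copy expressed through device.memcpy. A hedged usage sketch, assuming a constructed Eigen::SyclDevice and two tensors of identical shape and dtype; the setup lines are illustrative, not runnable as-is:

// Assumed context: a live SYCL queue wrapped in an Eigen::SyclDevice.
// Eigen::SyclDevice& device = ...;  // e.g. from a SYCLAllocator's queue
// Tensor src(DT_FLOAT, TensorShape({2, 2}));
// Tensor dst(DT_FLOAT, TensorShape({2, 2}));
// SYCLmemcpy(device, src, &dst);    // DT_FLOAT dispatches to cl_uint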
@@ -86,7 +86,7 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
  // Determine if the tensor is on device (GPU) or host (CPU).
  // The second part of the check is necessary because even an OpKernel on
  // may have output tensors allocated on CPU.
-  if ((device->name().find("gpu:") != string::npos || device->name().find("SYCL:") != string::npos) &&
+  if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) &&
      !ctx->output_alloc_attr(output_slot).on_host()) {
    // GPU tensors: Copy it to host (CPU).
    DeviceContext* device_ctxt = ctx->op_device_context();
@@ -47,7 +47,7 @@ class SessionDebugMinusAXTest : public ::testing::Test {
  Graph graph(OpRegistry::Global());

#if GOOGLE_CUDA
-  const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+  const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL)
  const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else

@@ -505,7 +505,7 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
  Graph graph(OpRegistry::Global());

#if GOOGLE_CUDA
-  const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+  const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL)
  const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else

@@ -607,7 +607,7 @@ class SessionDebugVariableTest : public ::testing::Test {
  Graph graph(OpRegistry::Global());

#if GOOGLE_CUDA
-  const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+  const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL)
  const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else

@@ -879,7 +879,7 @@ class SessionDebugGPUSwitchTest : public ::testing::Test {
  Graph graph(OpRegistry::Global());

#ifdef GOOGLE_CUDA
-  const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+  const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif TENSORFLOW_USE_SYCL
  const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#endif
@@ -53,14 +53,14 @@ class DebugIOUtilsTest : public ::testing::Test {
};

TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
-  DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/gpu:2",
+  DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2",
                              "hidden_1/MatMul", 0, "DebugIdentity");
-  EXPECT_EQ("/job:worker/replica:1/task:0/gpu:2", debug_node_key.device_name);
+  EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name);
  EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name);
  EXPECT_EQ(0, debug_node_key.output_slot);
  EXPECT_EQ("DebugIdentity", debug_node_key.debug_op);
  EXPECT_EQ("hidden_1/MatMul:0:DebugIdentity", debug_node_key.debug_node_name);
-  EXPECT_EQ("_tfdbg_device_,job_worker,replica_1,task_0,gpu_2",
+  EXPECT_EQ("_tfdbg_device_,job_worker,replica_1,task_0,device_GPU_2",
            debug_node_key.device_path);
}
@@ -140,7 +140,7 @@ Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
}

#define ALICE "/job:j/replica:0/task:0/cpu:0"
-#define BOB "/job:j/replica:0/task:0/gpu:0"
+#define BOB "/job:j/replica:0/task:0/device:GPU:0"

TEST_F(ExecutorTest, SimpleAdd) {
  // c = a + b
@@ -31,9 +31,9 @@ TEST(GrpcChannelTest, IsSameAddressSpace) {
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
                           "/job:mnist/replica:10/task:10/cpu:1"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
-                           "/job:mnist/replica:10/task:10/gpu:2"));
+                           "/job:mnist/replica:10/task:10/device:GPU:2"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10",
-                           "/job:mnist/replica:10/task:10/gpu:2"));
+                           "/job:mnist/replica:10/task:10/device:GPU:2"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1",
                           "/job:mnist/replica:10/task:10"));
@@ -129,28 +129,14 @@ class GrpcRemoteWorker : public WorkerInterface {
                       TensorResponse* response, StatusCallback done) override {
    VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
    int64 start_usec = Env::Default()->NowMicros();
-    // Don't propagate dma_ok over gRPC.
-    RecvTensorRequest* req_copy = nullptr;
-    if (request->dma_ok()) {
-      req_copy = new RecvTensorRequest;
-      *req_copy = *request;
-      req_copy->set_dma_ok(false);
-    }
    // Type-specialized logging for this method.
    bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
    StatusCallback wrapper_done;
    const StatusCallback* cb_to_use;
-    if (!logging_active && req_copy == nullptr) {
+    if (!logging_active) {
      cb_to_use = &done;  // No additional work to do, so just use done directly
-    } else if (!logging_active) {
-      wrapper_done = [req_copy, done](Status s) {
-        delete req_copy;
-        done(s);
-      };
-      cb_to_use = &wrapper_done;
    } else {
-      wrapper_done = [this, request, req_copy, response, done,
-                      start_usec](Status s) {
+      wrapper_done = [this, request, response, done, start_usec](Status s) {
        if (logger_->LoggingActive()) {
          int64 end_usec = Env::Default()->NowMicros();
          int64 step_id = request->step_id();

@@ -189,14 +175,12 @@ class GrpcRemoteWorker : public WorkerInterface {
        }
        VLOG(2) << "done callback, req: " << request->DebugString()
                << " response " << response->metadata().DebugString();
-        delete req_copy;
        done(s);
      };
      cb_to_use = &wrapper_done;
    }

-    IssueRequest(req_copy ? req_copy : request, response, recvtensor_,
-                 *cb_to_use, call_opts);
+    IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
  }

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
@@ -105,7 +105,8 @@ GrpcServer::~GrpcServer() {

Status GrpcServer::Init(
    ServiceInitFunction service_func,
-    const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
+    const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+    const WorkerCreationFunction& worker_func) {
  mutex_lock l(mu_);
  CHECK_EQ(state_, NEW);
  master_env_.env = env_;

@@ -183,7 +184,8 @@ Status GrpcServer::Init(
  master_impl_ = CreateMaster(&master_env_);
  master_service_ = NewGrpcMasterService(
      master_impl_.get(), config.operation_timeout_in_ms(), &builder);
-  worker_impl_ = NewGrpcWorker(&worker_env_);
+  worker_impl_ =
+      worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
  worker_service_ =
      NewGrpcWorkerService(worker_impl_.get(), &builder).release();
  // extra service:

@@ -239,7 +241,13 @@ Status GrpcServer::Init(
  return Status::OK();
}

-Status GrpcServer::Init() { return Init(nullptr, nullptr); }
+Status GrpcServer::Init(
+    ServiceInitFunction service_func,
+    const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
+  return Init(service_func, rendezvous_mgr_func, nullptr);
+}
+
+Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
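The grpc_server_lib.cc hunks keep both older Init() signatures compiling by turning them into thin wrappers that forward to the new three-argument overload with nullptr defaults. A generic sketch of that forwarding-overload pattern; Server and Hook are stand-ins, not the GrpcServer types:

#include <functional>

using Hook = std::function<void()>;  // stand-in for the creation functions

struct Server {
  int Init(Hook service, Hook rendezvous, Hook worker) { return 0; }  // widest
  int Init(Hook service, Hook rendezvous) {
    return Init(service, rendezvous, nullptr);  // forward with a default
  }
  int Init() { return Init(nullptr, nullptr, nullptr); }
};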
@@ -45,6 +45,10 @@ typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
    ServiceInitFunction;

+// function that creates a grpc based worker implementation.
+typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)>
+    WorkerCreationFunction;
+
class GrpcServer : public ServerInterface {
 protected:
  GrpcServer(const ServerDef& server_def, Env* env);

@@ -64,6 +68,10 @@ class GrpcServer : public ServerInterface {
  const string target() const override;

 protected:
+  Status Init(ServiceInitFunction service_func,
+              const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+              const WorkerCreationFunction& worker_func);
+
  Status Init(ServiceInitFunction service_func,
              const RendezvousMgrCreationFunction& rendezvous_mgr_func);
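The new WorkerCreationFunction gives GrpcServer subclasses an injection point: when the function is non-empty, Init() uses it, otherwise it falls back to NewGrpcWorker. A sketch of that fallback logic with stand-in types; only the std::function shape mirrors the header above:

#include <functional>
#include <memory>

struct WorkerEnv {};  // stand-in
struct GrpcWorker {
  explicit GrpcWorker(WorkerEnv*) {}
};
using WorkerCreationFunction =
    std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)>;

std::unique_ptr<GrpcWorker> MakeWorker(WorkerEnv* env,
                                       const WorkerCreationFunction& f) {
  // Mirrors the .cc change: prefer the injected factory, else the default.
  return f ? f(env) : std::unique_ptr<GrpcWorker>(new GrpcWorker(env));
}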
@@ -347,32 +347,25 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
      if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
        const DeviceContext* send_dev_context = send_args.device_context;
-        RecvTensorResponse* tmp = new RecvTensorResponse;
-        tmp->set_is_dead(is_dead);
+        AllocatorAttributes alloc_attrs;
+        alloc_attrs.set_gpu_compatible(true);
+        alloc_attrs.set_on_host(true);
+        Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
+        Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
        CHECK(send_dev_context)
            << "send dev name: " << src_dev->name()
            << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
-        // "val" is on a GPU. Uses GPUUtil to fill the response proto.
-        StatusCallback response_ready = [response, done,
-                                         tmp](const Status& s) {
+        // "val" is on a GPU. Uses GPUUtil to fill the copy on host.
+        StatusCallback copy_ready = [response, done, copy,
+                                     is_dead](const Status& s) {
          // The value is now ready to be returned on the wire.
-          tmp->set_send_start_micros(Env::Default()->NowMicros());
-
-          grpc::EncodeRecvTensorResponseToByteBuffer(*tmp, response);
+          grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
          done(s);
-          delete tmp;
+          delete copy;
        };

-        // TODO (jeff,sanjay,mrry): Avoid copy on GPU path by
-        // modifying GPUUtil::SetProtoFromGPU to accept a
-        // ::grpc::ByteBuffer to serialize to, rather than
-        // encoding into a protocol buffer and then
-        // serializing that (i.e. figure out how to use
-        // EncodeTensorToByteBuffer on this path rather than
-        // EncodeRecvTensorResponseToByteBuffer)
-        GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
-                                 tmp->mutable_tensor(), is_dead,
-                                 response_ready);
+        GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
+                                    copy_ready);
#else
        done(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
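The rewritten GPU receive path allocates a host Tensor, lets GPUUtil::CopyGPUTensorToCPU fill it asynchronously, and deletes it inside the completion callback once the bytes are encoded onto the wire. A generic sketch of that hand-off with illustrative names only, not the gRPC worker's actual types:

#include <functional>

using Done = std::function<void(int /*status*/)>;

void AsyncFill(int* buf, const Done& ready) {  // stand-in for the async copy
  buf[0] = 42;
  ready(0);
}

void RecvSketch(const Done& done) {
  int* copy = new int[4];  // heap payload handed to the callback
  Done ready = [copy, done](int s) {
    // ... encode *copy into the response here ...
    delete[] copy;  // the callback owns the cleanup, as with `copy` above
    done(s);
  };
  AsyncFill(copy, ready);
}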
@@ -34,8 +34,10 @@ class GrpcWorker : public Worker {
  GrpcWorker(WorkerEnv* env);

  // Specialized version of RecvTensor for gRPC, which avoids a copy.
-  void GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
-                           ::grpc::ByteBuffer* response, StatusCallback done);
+  virtual void GrpcRecvTensorAsync(CallOptions* opts,
+                                   const RecvTensorRequest* request,
+                                   ::grpc::ByteBuffer* response,
+                                   StatusCallback done);

  WorkerEnv* env();
};