Merge changes from github.

END_PUBLIC

---
Commit 9f81374c3 authored by raymondxyang<zihao.yang@microsoft.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Add option for build more python tests in Cmake (#11853)

* Ignore Windows built project

* Fix deprecated methods in tf.contrib.python

* Fix regex match for Windows build in contrib.keras

* Fix Regex match for Windows build in session_bundle

* * Fix deprecated methods
* Fix regex match for Windows
* Fix compatibility issue with Python 3.x

* Add missing ops into Windows build for test

* Enabled more testcases for Windows build

* Clean code and fix typo

* Add conditional cmake mode for enabling more unit testcase

* Add Cmake mode for major Contrib packages

* Add supplementary info in RAEDME for new cmake option

* * Update tf_tests after testing with TF 1.3
* Clean code and resolve conflicts

* Fix unsafe regex matches and format code

* Update exclude list after testing with latest master branch

* Fix missing module

---
Commit 98f0e1efe authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Dynamic ksize and strides with MaxPool (#11875)

* Dynamic ksize with max_pool

This fix tries to fix the issue raised in 4746 where ksize
is static (attr) with max_pool.
This fix changes ksize to input tensor so that it is dynamic now.

This fix fixes 4746.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add dynamic ksize to MaxPoolGrad and MaxPoolGradGrad

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for max_pool_v2

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix GPU Jenkins issue.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Enable MaxPoolV2 in GPU

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Hide MaxPoolV2 and other fixes.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---
Commit 02d6bc185 authored by Bairen Yi<byronyi@users.noreply.github.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
remove useless variable (#12212)

---
Commit ed6b0d905 authored by namrata-ibm<bhavenamrata@gmail.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Adding support for s390x in calculation of cpu_frequency (#12201)

---
Commit 627dfc9dd authored by Taehoon Lee<taehoonlee@snu.ac.kr>
Committed by Taehoon Lee<taehoonlee@snu.ac.kr>:
Fix typos

---
Commit c0f9b0a91 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
In fast-math mode emit a tanh that has a faster min/max.

PiperOrigin-RevId: 164943597

---
Commit 87605f3d6 authored by Kay Zhu<kayzhu@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[TF:XLA] Use HloEvaluator for ComputeConstant, remove the need of a dedicated
compute constant backend.

PiperOrigin-RevId: 164940970

---
Commit 881de45c2 authored by Taehoon Lee<me@taehoonlee.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Add bool type supports for GPU kernels (#11927)

* Add bool type supports for GPU kernels

* Add bool type test codes for GPU kernels

---
Commit eeacdcdb1 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add missing "CPU" suffix in registrations.

PiperOrigin-RevId: 164939527

---
Commit de01be952 authored by namrata-ibm<bhavenamrata@gmail.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Adding support for Big Endian in graph_constructor_test and wav_io (#12179)

---
Commit 26719d29f authored by QingYing Chen<pkudysj@126.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Implement CRF decode (Viterbi decode) for tensor (#12056)

* Implement CRF decoding for tensors

* add test code for tensor version's CRF decoding

* made modifications according to pylint

* add some comments for crf decode

* remove useless code

* add comments at the top comment of crf module and add more comments in crf_test

* capitalize first char of first word in comments

* replace crf_decode test code with a deterministic example

---
Commit f9a81ca2f authored by Pete Warden<pete@petewarden.com>
Committed by gunan<gunan@google.com>:
Create CI build script for Raspberry Pi (#12190)

* Create CI build script for Raspberry Pi

* Moved location of Pi build script

---
Commit e2a163a90 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.

PiperOrigin-RevId: 164929133

---
Commit 08bbfa187 authored by Taehoon Lee<me@taehoonlee.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Fix typos (#12195)

---
Commit ab96f41fb authored by Luke Iwanski<luke@codeplay.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
[OpenCL] Extends matmul_benchmark.py to cover SYCL (#11697)

* [OpenCL] Extends matmul_benchmark.py to cover SYCL

* Fixed typo

* /gpu:0 -> /device:GPU:0

* Fixes control_flow_ops_py_test

* /gpu: -> /device:GPU:

* Fixes //tensorflow/python/profiler/internal:run_metadata_test

* gpu: -> GPU:

* Fixes tfprof_node

* [OpenCL] Fixes device path to name with many colons (#123)

The device path is constructed from a device name by replacing all
colons with underscores. Some device names contain more than one colon,
for example 'device:SYCL:0' which gives a path 'device_SYCL_0'. The
previous code would not convert this back to the original device name,
but rather to 'device:SYCL_0'.

An alternative fix would be to convert all underscores to colons in the
device name (i.e. remove the restriction inside `replace("_", ":", 1)`),
however I'm not sure if there are any device names which contain
underscores.

* If no gpu device aviable fake one

* gpu: -> device:GPU

* Fixes profiler test

* /gpu:x -> /device:GPU:x

* Fixes debug_io_utils_test.cc test

* Fixes device_name_utils_test.cc

---
Commit 35e7a3665 authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
Remove unneeded casting of int64 for reverse_sequence (#12192)

This fix remove unneeded cast of int64 for reverse_sequence:
```
lengths = math_ops.to_int64(lengths)
```
as int32 has already been enabled for reverse_sequence.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
Commit 9fba8c185 authored by Anna R<annarev@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Add benchmark dashboard link to benchmarks doc. Also, I added a link and
description for Benchmarks page to Community index page.

PiperOrigin-RevId: 164924906

---
Commit bb6f32fa7 authored by Mark Heffernan<meheff@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make HloAliasAnalysis updatable after changes to the HLO graph.
As part of this change make HloAliasAnalysis a thinner layer which
basically only holds a map from HloValue to HloBuffer and vice versa.

PiperOrigin-RevId: 164923041

---
Commit 9103096c1 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by Thomas K?ppe<tkoeppe@google.com>:
Merged commit includes the following changes:
164923041  by meheff:

    Make HloAliasAnalysis updatable after changes to the HLO graph.
    As part of this change make HloAliasAnalysis a thinner layer which
    basically only holds a map from HloValue to HloBuffer and vice versa.

--

PiperOrigin-RevId: 164923041

---
Commit 822603aed authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Merging sibling fusion instruction using multi_output_fusion

PiperOrigin-RevId: 164920220

---
Commit c035aa2a8 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Go: Update generated wrapper functions for TensorFlow ops.

PiperOrigin-RevId: 164917891

---
Commit e1e81d9ba authored by Luke Iwanski<luke@codeplay.com>
Committed by Rasmus Munk Larsen<rmlarsen@google.com>:
[OpenCL] Fixes double memcpy bug (#151) (#12173)

* [OpenCL] Fixes double memcpy bug (#151)

As the debg CopyOp is called on a Tensor without type, we need to use
the DataType enum to get type information, and use this to pass the type
on to Eigen. This is a workaround Eigen's need to have a type when
calling memcpy. If the Eigen memcpy can be provided without a type
requirement, then the memcpy in sycl_util is unnecessary.

* Acts on feedback from: #12173/files/32cb12a9001b672425867b5a3110fd98e737a20b#r132496277

---
Commit d9ca2d86d authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Internal change

PiperOrigin-RevId: 164916465

---
Commit b8d13d218 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Remove more parts of DCASGD missed in the first pass. (47949b)

PiperOrigin-RevId: 164914552

---
Commit 73b3d52c7 authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
cmake fix

PiperOrigin-RevId: 164911656

---
Commit 2173b5b0a authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Allow TFE_TensorHandleCopyToDevice to have the same device as src and
destination. It will reuse the same underlying buffer in those cases.

PiperOrigin-RevId: 164909906

---
Commit 13eb3b90e authored by Alexandre Passos<apassos@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.

PiperOrigin-RevId: 164902588

---
Commit 7dfabcc01 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Initialize ExecutionOptions in ComputeConstant to default values.

PiperOrigin-RevId: 164894867

---
Commit c8897e9bc authored by Benoit Steiner<bsteiner@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Static required time computation

PiperOrigin-RevId: 164894645

---
Commit 076158f9b authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Enable implicit->explicit conversion by default.

PiperOrigin-RevId: 164890915

---
Commit 58c4a4cb1 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Bugfix: number of input channels is not necessarily in the last dimension, after introduction of data_format param.

PiperOrigin-RevId: 164889729

---
Commit 8f9b1af8a authored by Igor Saprykin<isaprykin@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Recover MonitoredSession when the Coordinator is requested to stop with one of the _PREEMPTION_ERRORS.

When SyncReplicasOptimizer is used, a preemption in the Coordinator may result in two cases:
Case 1) the session gets silently marked as complete
Case 2) the session gets stuck

This CL aims to solve and verify solutions for both of these problems. Fix 1 changes the should_stop logic. Fix 2 changes the CoordinatedSession.run() logic.

SyncReplicasOptimizer runs a separate set of threads using a Coordinator instance. Those threads do FIFOQueue.enqueue; the main thread does a blocking FIFOQueue.dequeue.

`sync_token_q` FIFOQueue is on parameter-servers. When one of the PS instances gets preempted, an AbortedError causes the Coordinator to stop via request_stop(ex). That by itself changes the state of MonitoredSession.should_stop() to True (Fix 1).

Results of the blocking Dequeue operation are sent to the chief worker via Recv. What happens next depends on the amount of tokens in `sync_token_q`. If there are enough for the next call to Dequeue to return, then the low-level "tf session run() call" returns. The next iteration of the `while not MonitoredSession.should_stop()` loop decides that the training is complete (Case 1).

If there are not enough tokens in `sync_token_q`, then the blocking Dequeue is going to keep waiting for them. This results in the graph execution getting stuck and the whole session getting garbage collected after 10 minutes (Case 2).

We decided to fix that by re-creating a session after it gets garbage collected (Fix 2). An alternative was to try to cancel the pending Dequeue operation, but it's not clear that it is the right thing to do and it is also not easy.

PiperOrigin-RevId: 164888390

---
Commit 46e4de6e5 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Undo loop fusion changes for now as they seem to be altering a few results.
END_PUBLIC
RELNOTES: n/a

BEGIN_PUBLIC
BEGIN_PUBLIC
Automated g4 rollback of changelist 164825735

PiperOrigin-RevId: 165340331
This commit is contained in:
A. Unique TensorFlower 2017-08-15 12:08:29 -07:00 committed by TensorFlower Gardener
parent 03a33c08dd
commit 28ce1d163e
257 changed files with 6427 additions and 2057 deletions

2
.gitignore vendored
View File

@ -13,3 +13,5 @@ node_modules
__pycache__ __pycache__
*.swp *.swp
.vscode/ .vscode/
cmake_build/
.idea/**

View File

@ -1,52 +1,53 @@
# NOTE: Disabled temporarily because it's too noisy on pushes.
# Where component owners are known, add them here. # Where component owners are known, add them here.
tensorflow/core/platform/windows/* @mrry #tensorflow/core/platform/windows/* @mrry
tensorflow/java/* @asimshankar #tensorflow/java/* @asimshankar
tensorflow/tensorboard/* @jart @dandelionmane #tensorflow/tensorboard/* @jart @dandelionmane
tensorflow/tools/docs/* @markdaoust #tensorflow/tools/docs/* @markdaoust
# contrib # contrib
# NEED OWNER: tensorflow/contrib/avro/* # NEED OWNER: tensorflow/contrib/avro/*
tensorflow/contrib/batching/* @alextp @chrisolston #tensorflow/contrib/batching/* @alextp @chrisolston
tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon #tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon
tensorflow/contrib/cmake/* @mrry @benoitsteiner #tensorflow/contrib/cmake/* @mrry @benoitsteiner
tensorflow/contrib/copy_graph/* @tucker @poxvoculi #tensorflow/contrib/copy_graph/* @tucker @poxvoculi
tensorflow/contrib/crf/* @kentonl #tensorflow/contrib/crf/* @kentonl
tensorflow/contrib/data/* @mrry #tensorflow/contrib/data/* @mrry
tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi #tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi
tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo #tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo
tensorflow/contrib/ffmpeg/* @fredbertsch #tensorflow/contrib/ffmpeg/* @fredbertsch
# NEED OWNER: tensorflow/contrib/framework/* # NEED OWNER: tensorflow/contrib/framework/*
tensorflow/contrib/graph_editor/* @purpledog #tensorflow/contrib/graph_editor/* @purpledog
# NEED OWNER: tensorflow/contrib/grid_rnn/* # NEED OWNER: tensorflow/contrib/grid_rnn/*
tensorflow/contrib/hvx/* @satok16 #tensorflow/contrib/hvx/* @satok16
tensorflow/contrib/imperative/* @keveman #tensorflow/contrib/imperative/* @keveman
tensorflow/contrib/integrate/* @shoyer #tensorflow/contrib/integrate/* @shoyer
tensorflow/contrib/kernel_methods/* @petrosmol #tensorflow/contrib/kernel_methods/* @petrosmol
tensorflow/contrib/ios_examples/* @petewarden #tensorflow/contrib/ios_examples/* @petewarden
tensorflow/contrib/labeled_tensor/* @shoyer #tensorflow/contrib/labeled_tensor/* @shoyer
tensorflow/contrib/layers/* @fchollet @martinwicke #tensorflow/contrib/layers/* @fchollet @martinwicke
tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp #tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp
tensorflow/contrib/linalg/* @langmore #tensorflow/contrib/linalg/* @langmore
tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis #tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis
tensorflow/contrib/lookup/* @ysuematsu @andreasst #tensorflow/contrib/lookup/* @ysuematsu @andreasst
tensorflow/contrib/losses/* @alextp @ispirmustafa #tensorflow/contrib/losses/* @alextp @ispirmustafa
tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg #tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg
tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa #tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa
tensorflow/contrib/nccl/* @cwhipkey @zheng-xq #tensorflow/contrib/nccl/* @cwhipkey @zheng-xq
tensorflow/contrib/opt/* @strategist333 #tensorflow/contrib/opt/* @strategist333
tensorflow/contrib/pi_examples/* @maciekcc #tensorflow/contrib/pi_examples/* @maciekcc
tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman #tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman
tensorflow/contrib/rnn/* @ebrevdo #tensorflow/contrib/rnn/* @ebrevdo
tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh #tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh
tensorflow/contrib/seq2seq/* @lukaszkaiser #tensorflow/contrib/seq2seq/* @lukaszkaiser
tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh #tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh
tensorflow/contrib/slim/* @sguada @thenbasilmanran #tensorflow/contrib/slim/* @sguada @thenbasilmanran
tensorflow/contrib/stateless/* @girving #tensorflow/contrib/stateless/* @girving
tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst #tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst
tensorflow/contrib/testing/* @dandelionmane #tensorflow/contrib/testing/* @dandelionmane
tensorflow/contrib/timeseries/* @allenlavoie #tensorflow/contrib/timeseries/* @allenlavoie
tensorflow/contrib/tpu/* @frankchn @saeta @jhseu #tensorflow/contrib/tpu/* @frankchn @saeta @jhseu
tensorflow/contrib/training/* @joel-shor @ebrevdo #tensorflow/contrib/training/* @joel-shor @ebrevdo
tensorflow/contrib/util/* @sherrym #tensorflow/contrib/util/* @sherrym

View File

@ -30,16 +30,16 @@ tracking requests and bugs. So please see
and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
## Installation ## Installation
*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
People who are a little more adventurous can also try our nightly binaries: People who are a little more adventurous can also try our nightly binaries:
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) * Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) * Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) * Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) * Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))

View File

@ -9,6 +9,7 @@
* `DNNLinearCombinedClassifier` * `DNNLinearCombinedClassifier`
* `DNNLinearCombinedRegressor`. * `DNNLinearCombinedRegressor`.
* All our prebuilt binaries have been built with cuDNN 6. * All our prebuilt binaries have been built with cuDNN 6.
* `import tensorflow` now goes much faster.
* Adds a file cache to the GCS filesystem with configurable max staleness for file contents. This permits caching of file contents across close/open boundaries. * Adds a file cache to the GCS filesystem with configurable max staleness for file contents. This permits caching of file contents across close/open boundaries.
* Added an axis parameter to `tf.gather`. * Added an axis parameter to `tf.gather`.
* Added a `constant_values` keyword argument to `tf.pad`. * Added a `constant_values` keyword argument to `tf.pad`.
@ -31,6 +32,7 @@
* GPU kernels and speed improvements for for unary `tf.where` and `tf.nn.top_k`. * GPU kernels and speed improvements for for unary `tf.where` and `tf.nn.top_k`.
* Monotonic Attention wrappers added to `tf.contrib.seq2seq`. * Monotonic Attention wrappers added to `tf.contrib.seq2seq`.
* Added `tf.contrib.signal`, a library for signal processing primitives. * Added `tf.contrib.signal`, a library for signal processing primitives.
* Added `tf.contrib.resampler`, containing CPU and GPU ops for differentiable resampling of images.
## Breaking Changes to the API ## Breaking Changes to the API
* `tf.RewriterConfig` was removed from the Python API after being available in 1.2 release candidates (it was never in an actual release). Graph rewriting is still available, just not as `tf.RewriterConfig`. Instead add an explicit import. * `tf.RewriterConfig` was removed from the Python API after being available in 1.2 release candidates (it was never in an actual release). Graph rewriting is still available, just not as `tf.RewriterConfig`. Instead add an explicit import.
@ -64,7 +66,7 @@
* Exported model signatures using the 'predict' method will no longer have their input and output keys silently ignored and rewritten to 'inputs' and 'outputs'. If a model was exported with different names before 1.2, and is now served with tensorflow/serving, it will accept requests using 'inputs' and 'outputs'. Starting at 1.2, such a model will accept the keys specified during export. Therefore, inference requests using 'inputs' and 'outputs' may start to fail. To fix this, either update any inference clients to send requests with the actual input and output keys used by the trainer code, or conversely, update the trainer code to name the input and output Tensors 'inputs' and 'outputs', respectively. Signatures using the 'classify' and 'regress' methods are not affected by this change; they will continue to standardize their input and output keys as before. * Exported model signatures using the 'predict' method will no longer have their input and output keys silently ignored and rewritten to 'inputs' and 'outputs'. If a model was exported with different names before 1.2, and is now served with tensorflow/serving, it will accept requests using 'inputs' and 'outputs'. Starting at 1.2, such a model will accept the keys specified during export. Therefore, inference requests using 'inputs' and 'outputs' may start to fail. To fix this, either update any inference clients to send requests with the actual input and output keys used by the trainer code, or conversely, update the trainer code to name the input and output Tensors 'inputs' and 'outputs', respectively. Signatures using the 'classify' and 'regress' methods are not affected by this change; they will continue to standardize their input and output keys as before.
* Add in-memory caching to the Dataset API. * Add in-memory caching to the Dataset API.
* Set default end_of_sequence variable in datasets iterators to false. * Set default end_of_sequence variable in datasets iterators to false.
* [Performance] Increase performance of `tf.layers.con2d` when setting use_bias=True by 2x by using nn.bias_add. * [Performance] Increase performance of `tf.layers.conv2d` when setting use_bias=True by 2x by using nn.bias_add.
* Update iOS examples to use CocoaPods, and moved to tensorflow/examples/ios. * Update iOS examples to use CocoaPods, and moved to tensorflow/examples/ios.
* Adds a family= attribute in `tf.summary` ops to allow controlling the tab name used in Tensorboard for organizing summaries. * Adds a family= attribute in `tf.summary` ops to allow controlling the tab name used in Tensorboard for organizing summaries.
* When GPU is configured, do not require --config=cuda, instead, automatically build for GPU if this is requested in the configure script. * When GPU is configured, do not require --config=cuda, instead, automatically build for GPU if this is requested in the configure script.

View File

@ -384,12 +384,16 @@ def set_action_env_var(environ_cp,
def convert_version_to_int(version): def convert_version_to_int(version):
"""Convert a version number to a integer that can be used to compare. """Convert a version number to a integer that can be used to compare.
Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The
'xxxxx' part, for instance 'homebrew' on OS/X, is ignored.
Args: Args:
version: a version to be covnerted version: a version to be converted
Returns: Returns:
An integer if converted successfully, otherwise return None. An integer if converted successfully, otherwise return None.
""" """
version = version.split('-')[0]
version_segments = version.split('.') version_segments = version.split('.')
for seg in version_segments: for seg in version_segments:
if not seg.isdigit(): if not seg.isdigit():
@ -428,6 +432,8 @@ def check_bazel_version(min_version):
print('Make sure you are running at least bazel %s' % min_version) print('Make sure you are running at least bazel %s' % min_version)
return curr_version return curr_version
print("You have bazel %s installed." % curr_version)
if curr_version_int < min_version_int: if curr_version_int < min_version_int:
print('Please upgrade your bazel installation to version %s or higher to ' print('Please upgrade your bazel installation to version %s or higher to '
'build TensorFlow!' % min_version) 'build TensorFlow!' % min_version)
@ -938,6 +944,8 @@ def main():
'with_hdfs_support', False) 'with_hdfs_support', False)
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
False) False)
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
False)
set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
False) False)

View File

@ -182,6 +182,12 @@ config_setting(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting(
name = "with_gdr_support",
values = {"define": "with_gdr_support=true"},
visibility = ["//visibility:public"],
)
config_setting( config_setting(
name = "with_verbs_support", name = "with_verbs_support",
values = {"define": "with_verbs_support=true"}, values = {"define": "with_verbs_support=true"},

View File

@ -146,7 +146,7 @@ class TF_ManagedBuffer : public TensorBuffer {
void* allocate_tensor(const char* operation, size_t len) { void* allocate_tensor(const char* operation, size_t len) {
void* data = void* data =
tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
if (tensorflow::LogMemory::IsEnabled()) { if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
tensorflow::LogMemory::RecordRawAllocation( tensorflow::LogMemory::RecordRawAllocation(
operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
len, data, tensorflow::cpu_allocator()); len, data, tensorflow::cpu_allocator());
@ -155,7 +155,7 @@ void* allocate_tensor(const char* operation, size_t len) {
} }
void deallocate_buffer(void* data, size_t len, void* arg) { void deallocate_buffer(void* data, size_t len, void* arg) {
if (tensorflow::LogMemory::IsEnabled()) { if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
tensorflow::LogMemory::RecordRawDeallocation( tensorflow::LogMemory::RecordRawDeallocation(
"TensorFlow C Api", "TensorFlow C Api",
tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,

View File

@ -101,7 +101,7 @@ void ConcurrentSteps(const Options* opts, int session_index) {
std::unique_ptr<Session> session(NewSession(options)); std::unique_ptr<Session> session(NewSession(options));
GraphDef def = CreateGraphDef(); GraphDef def = CreateGraphDef();
if (options.target.empty()) { if (options.target.empty()) {
graph::SetDefaultDevice(opts->use_gpu ? "/gpu:0" : "/cpu:0", &def); graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
} }
TF_CHECK_OK(session->Create(def)); TF_CHECK_OK(session->Create(def));

View File

@ -222,7 +222,7 @@ class MatcherBase {
TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
}; };
// WhileConditionComputationMatcher attempst to match a target computation // WhileConditionComputationMatcher attempts to match a target computation
// pattern in the while condition sub-computation. // pattern in the while condition sub-computation.
// If the target pattern is matched, two pieces of information are extracted // If the target pattern is matched, two pieces of information are extracted
// from 'tagged' instructions returned by the matcher: // from 'tagged' instructions returned by the matcher:

View File

@ -626,7 +626,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable()); CHECK(instruction_to_fuse->IsFusable());
if (GetModule()) { if (GetModule()) {
XLA_VLOG_LINES(1, GetModule()->ToString()); XLA_VLOG_LINES(3, GetModule()->ToString());
} }
HloInstruction* clone = nullptr; HloInstruction* clone = nullptr;
if (called_computations_.empty()) { if (called_computations_.empty()) {
@ -1909,9 +1909,10 @@ bool HloInstruction::IsFusable() const {
case HloOpcode::kRecv: case HloOpcode::kRecv:
return false; return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated // Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. // will be different in each fusion. If it is the root (user count = 0)
// then it is the equivalent of having one user.
case HloOpcode::kRng: case HloOpcode::kRng:
return users_.size() == 1; return users_.size() <= 1;
default: default:
return true; return true;
} }

View File

@ -1077,6 +1077,48 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
root2->operand(1)->operand(0)->shape())); root2->operand(1)->operand(0)->shape()));
} }
TEST_F(HloInstructionTest, IsRandomFusable) {
auto shape = ShapeUtil::MakeShape(F32, {2, 2});
{
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(0.0)));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(1.0)));
auto rng = builder.AddInstruction(HloInstruction::CreateRng(
shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
auto* computation = hlo_module->AddEntryComputation(builder.Build());
computation->CreateFusionInstruction({rng, const0, const1},
HloInstruction::FusionKind::kLoop);
auto* root = computation->root_instruction();
EXPECT_EQ(HloOpcode::kFusion, root->opcode());
}
{
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(0.0)));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(1.0)));
auto rng = builder.AddInstruction(HloInstruction::CreateRng(
shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
builder.AddInstruction(HloInstruction::CreateUnary(
shape, HloOpcode::kNegate, rng));
auto* computation = hlo_module->AddEntryComputation(builder.Build());
computation->CreateFusionInstruction({rng, const0, const1},
HloInstruction::FusionKind::kLoop);
auto* root = computation->root_instruction();
EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode());
}
}
TEST_F(HloInstructionTest, CloneSuffixNames) { TEST_F(HloInstructionTest, CloneSuffixNames) {
// Test that the suffix string added to cloned instructions is not // Test that the suffix string added to cloned instructions is not
// duplicated. Rather a numeric incrementing value should be appended. That // duplicated. Rather a numeric incrementing value should be appended. That

View File

@ -57,7 +57,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
auto result = builder.Neg(a); auto result = builder.Neg(a);
@ -66,7 +66,7 @@ TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({-1, 0, 1, 324, auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min(),
@ -126,7 +126,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
@ -185,7 +185,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
@ -204,7 +204,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000}); auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
auto b = builder.ConstantR1<int32>({-1, 2, 1, -1}); auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
@ -222,7 +222,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
ComputeAndCompareR1<int32>(&builder, {}, {}); ComputeAndCompareR1<int32>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
@ -241,7 +241,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, DivS32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
// clang-format off // clang-format off
// Some interesting values to test. // Some interesting values to test.
std::vector<int32> vals = { std::vector<int32> vals = {
@ -316,7 +316,7 @@ TEST_F(ArrayElementwiseOpTest, DivS32s) {
} }
} }
TEST_F(ArrayElementwiseOpTest, DivU32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
// clang-format off // clang-format off
// Some interesting values to test. // Some interesting values to test.
std::vector<uint32> vals = { std::vector<uint32> vals = {
@ -420,7 +420,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
@ -439,7 +439,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
std::vector<int32> data = {0, std::vector<int32> data = {0,
1, 1,
-1, -1,
@ -474,7 +474,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
ComputeAndCompareR1<int32>(&builder, {}, {}); ComputeAndCompareR1<int32>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234, std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
0x1a243514, 0xFFFFFFFF, 0x80808080}; 0x1a243514, 0xFFFFFFFF, 0x80808080};
@ -496,7 +496,7 @@ TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
ComputeAndCompareR1<uint32>(&builder, expected, {}); ComputeAndCompareR1<uint32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, LogicalAnd) { XLA_TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, false, true, true}); auto a = builder.ConstantR1<bool>({false, false, true, true});
auto b = builder.ConstantR1<bool>({false, true, false, true}); auto b = builder.ConstantR1<bool>({false, true, false, true});
@ -514,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {}); ComputeAndCompareR1<bool>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, LogicalOr) { XLA_TEST_F(ArrayElementwiseOpTest, LogicalOr) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, false, true, true}); auto a = builder.ConstantR1<bool>({false, false, true, true});
auto b = builder.ConstantR1<bool>({false, true, false, true}); auto b = builder.ConstantR1<bool>({false, true, false, true});
@ -532,7 +532,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {}); ComputeAndCompareR1<bool>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, LogicalNot) { XLA_TEST_F(ArrayElementwiseOpTest, LogicalNot) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<bool>({false, true, true, false}); auto a = builder.ConstantR1<bool>({false, true, true, false});
auto out = builder.LogicalNot(a); auto out = builder.LogicalNot(a);
@ -548,7 +548,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) {
ComputeAndCompareR1<bool>(&builder, {}, {}); ComputeAndCompareR1<bool>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
@ -567,7 +567,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
ComputeAndCompareR1<bool>(&builder, {}, {}); ComputeAndCompareR1<bool>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
@ -577,7 +577,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {}); ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
@ -587,7 +587,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {}); ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
@ -597,7 +597,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {}); ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
@ -607,7 +607,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {}); ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -629,7 +629,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
ComputeAndCompareR1<bool>(&builder, {}, {}); ComputeAndCompareR1<bool>(&builder, {}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
// Disable fast-math because we're operating on NaNs. // Disable fast-math because we're operating on NaNs.
SetFastMathDisabled(true); SetFastMathDisabled(true);
@ -641,7 +641,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {}); ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -653,7 +653,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
&builder, {false, true, true, true, false, true, true, true, false}, {}); &builder, {false, true, true, true, false, true, true, true, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -665,7 +665,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
&builder, {true, false, false, true, true, false, true, true, true}, {}); &builder, {true, false, false, true, true, false, true, true, true}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -678,7 +678,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -690,7 +690,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
&builder, {true, true, true, false, true, true, false, false, true}, {}); &builder, {true, true, true, false, true, true, false, false, true}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -703,7 +703,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -715,7 +715,7 @@ TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -726,7 +726,7 @@ TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
&builder, {false, true, true, true, false, true, true, true, false}, {}); &builder, {false, true, true, true, false, true, true, true, false}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -737,7 +737,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
&builder, {true, false, false, true, true, false, true, true, true}, {}); &builder, {true, false, false, true, true, false, true, true, true}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -749,7 +749,7 @@ TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -760,7 +760,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
&builder, {true, true, true, false, true, true, false, false, true}, {}); &builder, {true, true, true, false, true, true, false, false, true}, {});
} }
TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max}); auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
@ -772,7 +772,7 @@ TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
{}); {});
} }
TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
SetFastMathDisabled(true); SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto lhs = auto lhs =
@ -795,7 +795,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
} }
// Some Pow cases that can be implemented more efficiently. // Some Pow cases that can be implemented more efficiently.
TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
@ -823,7 +823,7 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_); ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
@ -848,7 +848,7 @@ TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
@ -873,7 +873,7 @@ TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
@ -898,7 +898,7 @@ TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
@ -923,7 +923,7 @@ TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
@ -955,7 +955,7 @@ TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
&b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
@ -988,7 +988,7 @@ TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
&b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
@ -1021,7 +1021,7 @@ TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
&b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, Div4F32) { XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
@ -1081,7 +1081,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, SquareIn4D) { XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
Array4D<float> values(2, 2, 2, 2); Array4D<float> values(2, 2, 2, 2);
@ -1120,7 +1120,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
// //
// TODO(b/28180546): Make this compile in a way that is consistent // TODO(b/28180546): Make this compile in a way that is consistent
// among backends. // among backends.
TEST_F(ArrayElementwiseOpTest, MinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
#if !defined(XLA_TEST_BACKEND_CPU) #if !defined(XLA_TEST_BACKEND_CPU)
auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
@ -1174,7 +1174,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
// TODO(b/28180546): Make this compile in a way that is consistent // TODO(b/28180546): Make this compile in a way that is consistent
// among backends. See comment on MinF32s test above. // among backends. See comment on MinF32s test above.
TEST_F(ArrayElementwiseOpTest, MaxF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
#if !defined(XLA_TEST_BACKEND_CPU) #if !defined(XLA_TEST_BACKEND_CPU)
auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
@ -1226,7 +1226,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
{}, error_spec_); {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, MaxS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -1241,7 +1241,7 @@ TEST_F(ArrayElementwiseOpTest, MaxS32s) {
ComputeAndCompareR1<int32>(&builder, expected, {}); ComputeAndCompareR1<int32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
const int32 min = std::numeric_limits<int32>::min(); const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max(); const int32 max = std::numeric_limits<int32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -1256,7 +1256,7 @@ TEST_F(ArrayElementwiseOpTest, MinS32s) {
ComputeAndCompareR1<int32>(&builder, expected, {}); ComputeAndCompareR1<int32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max}); auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
@ -1267,7 +1267,7 @@ TEST_F(ArrayElementwiseOpTest, MaxU32s) {
ComputeAndCompareR1<uint32>(&builder, expected, {}); ComputeAndCompareR1<uint32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, MinU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
const uint32 max = std::numeric_limits<uint32>::max(); const uint32 max = std::numeric_limits<uint32>::max();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max}); auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
@ -1278,7 +1278,7 @@ TEST_F(ArrayElementwiseOpTest, MinU32s) {
ComputeAndCompareR1<uint32>(&builder, expected, {}); ComputeAndCompareR1<uint32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
@ -1311,7 +1311,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
} }
} }
TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
auto m = auto m =
@ -1354,7 +1354,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
ComputeAndCompareR3<int32>(&builder, expected, {}); ComputeAndCompareR3<int32>(&builder, expected, {});
} }
TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto m = auto m =
builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
@ -1431,7 +1431,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {}); ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
} }
TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
@ -1442,7 +1442,7 @@ TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto minimum = builder.ConstantR0<float>(0.0f); auto minimum = builder.ConstantR0<float>(0.0f);
auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
@ -1453,7 +1453,7 @@ TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto min_scalar = builder.ConstantR0<float>(0.0f); auto min_scalar = builder.ConstantR0<float>(0.0f);
auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
@ -1472,7 +1472,7 @@ TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> param0_literal = std::unique_ptr<Literal> param0_literal =
@ -1516,7 +1516,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
&builder, expected, {param0_data.get(), param1_data.get()}, error_spec_); &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> param0_literal = std::unique_ptr<Literal> param0_literal =
@ -1550,7 +1550,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, TanhF32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f}); auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
auto result = builder.Tanh(a); auto result = builder.Tanh(a);
@ -1559,7 +1559,7 @@ TEST_F(ArrayElementwiseOpTest, TanhF32s) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
// This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
// the input tensor is large enough to exercise the vectorized tanh // the input tensor is large enough to exercise the vectorized tanh
// implementation. // implementation.
@ -1603,7 +1603,7 @@ TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
ErrorSpec(0.004, 0.004)); ErrorSpec(0.004, 0.004));
} }
TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
// a ------ (add) --------- (add) // a ------ (add) --------- (add)
// / / // / /
// b -----/ / // b -----/ /
@ -1621,7 +1621,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
// b ------ (add) --------- (add) // b ------ (add) --------- (add)
// / / // / /
// c -----/ / // c -----/ /
@ -1639,7 +1639,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, AddWithNeg) { XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
// a ----- (neg) ----- (add) // a ----- (neg) ----- (add)
// / // /
// b ----- (neg) ----/ // b ----- (neg) ----/
@ -1656,7 +1656,7 @@ TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
// a ------ (add) ------------\ // a ------ (add) ------------\
// / \ // / \
// b -----/ (add) // b -----/ (add)
@ -1679,7 +1679,7 @@ TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
error_spec_); error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = auto a =
builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
@ -1704,7 +1704,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
// Add a matrix + scalar. // Add a matrix + scalar.
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto a = auto a =
@ -1820,7 +1820,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
EXPECT_EQ(expected, ExecuteToString(&builder, {})); EXPECT_EQ(expected, ExecuteToString(&builder, {}));
} }
TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
// Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
// arguments is reversed. // arguments is reversed.
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -1831,7 +1831,7 @@ TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
// Tests broadcasting for arrays with degenerate (size == 1) dimensions. // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
// m's shape in XLA notation is {3, 2} // m's shape in XLA notation is {3, 2}
@ -1891,7 +1891,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
// Add together a (2,2) array and a (2) array, using dimension 1 for // Add together a (2,2) array and a (2) array, using dimension 1 for
// broadcasting (though there are two ways to broadcast these shapes). // broadcasting (though there are two ways to broadcast these shapes).
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -1902,7 +1902,7 @@ TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
// Binary add of two R3s together // Binary add of two R3s together
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
@ -2033,7 +2033,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
EXPECT_EQ(expected, ExecuteToString(&builder, {})); EXPECT_EQ(expected, ExecuteToString(&builder, {}));
} }
TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
@ -2060,7 +2060,7 @@ TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
@ -2088,7 +2088,7 @@ TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
} }
TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
constexpr int d0 = 16; constexpr int d0 = 16;
constexpr int d1 = 16; constexpr int d1 = 16;
constexpr int d2 = 2; constexpr int d2 = 2;
@ -2119,7 +2119,7 @@ TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
} }
// Show that we can't add two opaques. // Show that we can't add two opaques.
TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto shape = ShapeUtil::MakeOpaqueShape(); auto shape = ShapeUtil::MakeOpaqueShape();
auto x = builder.Parameter(0, shape, "x"); auto x = builder.Parameter(0, shape, "x");
@ -2133,7 +2133,7 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
// Regression test for b/31927799. "slice - y" is fused and requires implicit // Regression test for b/31927799. "slice - y" is fused and requires implicit
// broadcast. // broadcast.
TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x_literal = Literal::CreateR1<float>({1, 2, 3}); auto x_literal = Literal::CreateR1<float>({1, 2, 3});
auto y_literal = Literal::CreateR1<float>({4, 5}); auto y_literal = Literal::CreateR1<float>({4, 5});

View File

@ -31,6 +31,7 @@ def xla_test(name,
args=[], args=[],
tags=[], tags=[],
copts=[], copts=[],
data=[],
backend_tags={}, backend_tags={},
backend_args={}, backend_args={},
**kwargs): **kwargs):
@ -114,6 +115,7 @@ def xla_test(name,
this_backend_tags = ["xla_%s" % backend] this_backend_tags = ["xla_%s" % backend]
this_backend_copts = [] this_backend_copts = []
this_backend_args = backend_args.get(backend, []) this_backend_args = backend_args.get(backend, [])
this_backend_data = []
if backend == "cpu": if backend == "cpu":
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
@ -130,6 +132,7 @@ def xla_test(name,
this_backend_copts += plugins[backend]["copts"] this_backend_copts += plugins[backend]["copts"]
this_backend_tags += plugins[backend]["tags"] this_backend_tags += plugins[backend]["tags"]
this_backend_args += plugins[backend]["args"] this_backend_args += plugins[backend]["args"]
this_backend_data += plugins[backend]["data"]
else: else:
fail("Unknown backend %s" % backend) fail("Unknown backend %s" % backend)
@ -145,6 +148,7 @@ def xla_test(name,
this_backend_copts, this_backend_copts,
args=args + this_backend_args, args=args + this_backend_args,
deps=deps + backend_deps, deps=deps + backend_deps,
data=data + this_backend_data,
**kwargs) **kwargs)
test_names.append(test_name) test_names.append(test_name)
@ -227,14 +231,18 @@ def generate_backend_test_macros(backends=[]):
if not backends: if not backends:
backends = all_backends backends = all_backends
for backend in filter_backends(backends): for backend in filter_backends(backends):
manifest = ""
if backend in plugins:
manifest = plugins[backend]["disabled_manifest"]
native.cc_library( native.cc_library(
name="test_macros_%s" % backend, name="test_macros_%s" % backend,
testonly = True, testonly = True,
srcs = ["test_macros.cc"], srcs = ["test_macros.cc"],
hdrs = ["test_macros.h"], hdrs = ["test_macros.h"],
copts = [ copts = [
"-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
"-DXLA_DISABLED_MANIFEST=\\\"\\\"" "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
], ],
deps = [ deps = [
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",

View File

@ -22,9 +22,13 @@
# "//tensorflow/compiler/plugin/foo:foo_lib", # "//tensorflow/compiler/plugin/foo:foo_lib",
# "//tensorflow/compiler/plugin/foo:test_macros", # "//tensorflow/compiler/plugin/foo:test_macros",
# ], # ],
# "disabled_manifest": "tensorflow/compiler/plugin/foo/disabled_test_manifest.txt",
# "copts": [], # "copts": [],
# "tags": [], # "tags": [],
# "args": [] # "args": []
# "data": [
# "//tensorflow/compiler/plugin/foo:disabled_test_manifest.txt",
# ],
# }, # },
# } # }

View File

@ -69,35 +69,35 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
} }
}; };
TEST_F(ScalarComputationsTest, NegateScalarF32) { XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<float>(2.1f)); builder.Neg(builder.ConstantR0<float>(2.1f));
ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_); ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, NegateScalarS32) { XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<int32>(2)); builder.Neg(builder.ConstantR0<int32>(2));
ComputeAndCompareR0<int32>(&builder, -2, {}); ComputeAndCompareR0<int32>(&builder, -2, {});
} }
TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f)); builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)); builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
ComputeAndCompareR0<int32>(&builder, 7, {}); ComputeAndCompareR0<int32>(&builder, 7, {});
} }
TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57)); builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));
@ -137,21 +137,21 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
ComputeAndCompareR0<double>(&builder, 3.75, {}); ComputeAndCompareR0<double>(&builder, 3.75, {});
} }
TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f)); builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_); ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)); builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
ComputeAndCompareR0<int32>(&builder, -3, {}); ComputeAndCompareR0<int32>(&builder, -3, {});
} }
TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f), builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
builder.ConstantR0<float>(5.5f)), builder.ConstantR0<float>(5.5f)),
@ -160,7 +160,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
std::vector<int32> data = {0, std::vector<int32> data = {0,
1, 1,
-1, -1,
@ -184,7 +184,7 @@ TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
} }
} }
TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234, std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
0x1a243514, 0xFFFFFFFF, 0x80808080}; 0x1a243514, 0xFFFFFFFF, 0x80808080};
@ -199,7 +199,7 @@ TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
} }
} }
TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Mul( builder.Mul(
builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)), builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
@ -208,7 +208,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
ComputeAndCompareR0<int32>(&builder, 10, {}); ComputeAndCompareR0<int32>(&builder, 10, {});
} }
TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f); std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f);
std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f); std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f);
@ -231,7 +231,7 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
error_spec_); error_spec_);
} }
TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f)); builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));
@ -337,7 +337,7 @@ INSTANTIATE_TEST_CASE_P(
DivS32Params{INT32_MIN, -0x40000000, 2, 0}, // DivS32Params{INT32_MIN, -0x40000000, 2, 0}, //
DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff})); DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
TEST_F(ScalarComputationsTest, DivU32s) { XLA_TEST_F(ScalarComputationsTest, DivU32s) {
// clang-format off // clang-format off
// Some interesting values to test. // Some interesting values to test.
std::vector<uint32> vals = { std::vector<uint32> vals = {
@ -378,7 +378,7 @@ TEST_F(ScalarComputationsTest, DivU32s) {
} }
} }
TEST_F(ScalarComputationsTest, RemU32s) { XLA_TEST_F(ScalarComputationsTest, RemU32s) {
// clang-format off // clang-format off
// Some interesting values to test. // Some interesting values to test.
std::vector<uint32> vals = { std::vector<uint32> vals = {
@ -419,7 +419,7 @@ TEST_F(ScalarComputationsTest, RemU32s) {
} }
} }
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
builder.Rem(x, builder.ConstantR0<int32>(80000)); builder.Rem(x, builder.ConstantR0<int32>(80000));
@ -446,7 +446,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
ComputeAndCompareR0<uint32>(&builder, 2, {}); ComputeAndCompareR0<uint32>(&builder, 2, {});
} }
TEST_F(ScalarComputationsTest, LogicalAnd) { XLA_TEST_F(ScalarComputationsTest, LogicalAnd) {
for (bool x : {false, true}) { for (bool x : {false, true}) {
for (bool y : {false, true}) { for (bool y : {false, true}) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -458,7 +458,7 @@ TEST_F(ScalarComputationsTest, LogicalAnd) {
} }
} }
TEST_F(ScalarComputationsTest, LogicalOr) { XLA_TEST_F(ScalarComputationsTest, LogicalOr) {
for (bool x : {false, true}) { for (bool x : {false, true}) {
for (bool y : {false, true}) { for (bool y : {false, true}) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -470,7 +470,7 @@ TEST_F(ScalarComputationsTest, LogicalOr) {
} }
} }
TEST_F(ScalarComputationsTest, LogicalNot) { XLA_TEST_F(ScalarComputationsTest, LogicalNot) {
for (bool x : {false, true}) { for (bool x : {false, true}) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.LogicalNot(builder.ConstantR0<bool>(x)); builder.LogicalNot(builder.ConstantR0<bool>(x));
@ -479,7 +479,7 @@ TEST_F(ScalarComputationsTest, LogicalNot) {
} }
} }
TEST_F(ScalarComputationsTest, SelectScalarTrue) { XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Select(builder.ConstantR0<bool>(true), // The predicate. builder.Select(builder.ConstantR0<bool>(true), // The predicate.
builder.ConstantR0<float>(123.0f), // The value on true. builder.ConstantR0<float>(123.0f), // The value on true.
@ -488,7 +488,7 @@ TEST_F(ScalarComputationsTest, SelectScalarTrue) {
ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, SelectScalarFalse) { XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Select(builder.ConstantR0<bool>(false), // The predicate. builder.Select(builder.ConstantR0<bool>(false), // The predicate.
builder.ConstantR0<float>(123.0f), // The value on true. builder.ConstantR0<float>(123.0f), // The value on true.
@ -499,7 +499,7 @@ TEST_F(ScalarComputationsTest, SelectScalarFalse) {
// This test is an explicit version of what is happening in the following // This test is an explicit version of what is happening in the following
// templatized comparison tests. // templatized comparison tests.
TEST_F(ScalarComputationsTest, CompareGtScalar) { XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f)); builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
@ -507,30 +507,30 @@ TEST_F(ScalarComputationsTest, CompareGtScalar) {
} }
// S32 comparisons. // S32 comparisons.
TEST_F(ScalarComputationsTest, CompareEqS32Greater) { XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq); TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq);
} }
TEST_F(ScalarComputationsTest, CompareEqS32Equal) { XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq); TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq);
} }
TEST_F(ScalarComputationsTest, CompareNeS32) { XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne); TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne);
} }
TEST_F(ScalarComputationsTest, CompareGeS32) { XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge); TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGtS32) { XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt); TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt);
} }
TEST_F(ScalarComputationsTest, CompareLeS32) { XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
TestCompare<int32>(2, 1, false, &ComputationBuilder::Le); TestCompare<int32>(2, 1, false, &ComputationBuilder::Le);
} }
TEST_F(ScalarComputationsTest, CompareLtS32) { XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt); TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt);
TestCompare<int32>(std::numeric_limits<int32>::min(), TestCompare<int32>(std::numeric_limits<int32>::min(),
std::numeric_limits<int32>::max(), true, std::numeric_limits<int32>::max(), true,
@ -538,105 +538,105 @@ TEST_F(ScalarComputationsTest, CompareLtS32) {
} }
// U32 comparisons. // U32 comparisons.
TEST_F(ScalarComputationsTest, CompareEqU32False) { XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq); TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq);
} }
TEST_F(ScalarComputationsTest, CompareNeU32) { XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne); TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne);
} }
TEST_F(ScalarComputationsTest, CompareGeU32Greater) { XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge); TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGeU32Equal) { XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge); TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGtU32) { XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt); TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt);
TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt); TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt);
TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt); TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt);
} }
TEST_F(ScalarComputationsTest, CompareLeU32) { XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le); TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le);
} }
TEST_F(ScalarComputationsTest, CompareLtU32) { XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt); TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt);
TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
&ComputationBuilder::Lt); &ComputationBuilder::Lt);
} }
// F32 comparisons. // F32 comparisons.
TEST_F(ScalarComputationsTest, CompareEqF32False) { XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq); TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq);
} }
TEST_F(ScalarComputationsTest, CompareNeF32) { XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne); TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne);
} }
TEST_F(ScalarComputationsTest, CompareGeF32Greater) { XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge); TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGeF32Equal) { XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge); TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGtF32) { XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt); TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt);
} }
TEST_F(ScalarComputationsTest, CompareLeF32) { XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le); TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le);
} }
TEST_F(ScalarComputationsTest, CompareLtF32) { XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt); TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt);
} }
// F32 comparisons with exceptional values. The test names encode the // F32 comparisons with exceptional values. The test names encode the
// left/right operands at the end, and use Minf and Mzero for -inf and -0.0. // left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt); TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt);
} }
TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt); TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt);
} }
TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt); TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt);
} }
TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge); TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge); TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge); TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
} }
TEST_F(ScalarComputationsTest, ExpScalar) { XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Exp(builder.ConstantR0<float>(2.0f)); builder.Exp(builder.ConstantR0<float>(2.0f));
ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, LogScalar) { XLA_TEST_F(ScalarComputationsTest, LogScalar) {
ComputationBuilder builder(client_, "log"); ComputationBuilder builder(client_, "log");
builder.Log(builder.ConstantR0<float>(2.0f)); builder.Log(builder.ConstantR0<float>(2.0f));
ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, TanhScalar) { XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Tanh(builder.ConstantR0<float>(2.0f)); builder.Tanh(builder.ConstantR0<float>(2.0f));
@ -650,14 +650,14 @@ XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_); ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, PowScalar) { XLA_TEST_F(ScalarComputationsTest, PowScalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f)); builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, ClampScalarHigh) { XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
builder.ConstantR0<float>(5.0f), // The operand to be clamped. builder.ConstantR0<float>(5.0f), // The operand to be clamped.
@ -666,7 +666,7 @@ TEST_F(ScalarComputationsTest, ClampScalarHigh) {
ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, ClampScalarMiddle) { XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
builder.ConstantR0<float>(2.5f), // The operand to be clamped. builder.ConstantR0<float>(2.5f), // The operand to be clamped.
@ -675,7 +675,7 @@ TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, ClampScalarLow) { XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound. builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
builder.ConstantR0<float>(-5.0f), // The operand to be clamped. builder.ConstantR0<float>(-5.0f), // The operand to be clamped.
@ -684,57 +684,57 @@ TEST_F(ScalarComputationsTest, ClampScalarLow) {
ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_); ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, MinS32Above) { XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min); TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MinS32Below) { XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min); TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MaxS32Above) { XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max); TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, MaxS32Below) { XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max); TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, MinU32Above) { XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
const uint32 large = std::numeric_limits<int32>::max(); const uint32 large = std::numeric_limits<int32>::max();
TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min); TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MinU32Below) { XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min); TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MaxU32Above) { XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
const uint32 large = std::numeric_limits<int32>::max(); const uint32 large = std::numeric_limits<int32>::max();
TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max); TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, MaxU32Below) { XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max); TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, MinF32Above) { XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min); TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MinF32Below) { XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min);
} }
TEST_F(ScalarComputationsTest, MaxF32Above) { XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, MaxF32Below) { XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max);
} }
TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
// Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
b.Div( b.Div(
@ -747,7 +747,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_); ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
} }
TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
// Compute the expression 1 * (3 - 1) * (7 + 0) - 4. // Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
b.Sub(b.Mul(b.ConstantR0<int32>(1), b.Sub(b.Mul(b.ConstantR0<int32>(1),
@ -758,7 +758,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
ComputeAndCompareR0<int32>(&b, 10, {}); ComputeAndCompareR0<int32>(&b, 10, {});
} }
TEST_F(ScalarComputationsTest, SqrtF320) { XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
Literal zero_literal = Literal::Zero(PrimitiveType::F32); Literal zero_literal = Literal::Zero(PrimitiveType::F32);

View File

@ -85,12 +85,12 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
AbsSize0TestHelper<float>(); AbsSize0TestHelper<float>();
} }
TEST_F(UnaryOpTest, AbsTestR1) { XLA_TEST_F(UnaryOpTest, AbsTestR1) {
AbsTestHelper<int>(); AbsTestHelper<int>();
AbsTestHelper<float>(); AbsTestHelper<float>();
} }
TEST_F(UnaryOpTest, AbsTestR0) { XLA_TEST_F(UnaryOpTest, AbsTestR0) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto argi = builder.ConstantR0<int>(-5); auto argi = builder.ConstantR0<int>(-5);
auto absi = builder.Abs(argi); auto absi = builder.Abs(argi);
@ -104,7 +104,7 @@ TEST_F(UnaryOpTest, AbsTestR0) {
ComputeAndCompareR0<float>(&builder, 8.0f, {}); ComputeAndCompareR0<float>(&builder, 8.0f, {});
} }
TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR0) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto argi = builder.ConstantR0<int>(-5); auto argi = builder.ConstantR0<int>(-5);
auto absi = builder.Sign(argi); auto absi = builder.Sign(argi);
@ -118,17 +118,17 @@ TEST_F(UnaryOpTest, SignTestR0) {
ComputeAndCompareR0<float>(&builder, -2.0f, {}); ComputeAndCompareR0<float>(&builder, -2.0f, {});
} }
TEST_F(UnaryOpTest, SignTestR1) { XLA_TEST_F(UnaryOpTest, SignTestR1) {
SignTestHelper<int>(); SignTestHelper<int>();
SignTestHelper<float>(); SignTestHelper<float>();
} }
TEST_F(UnaryOpTest, SignAbsTestR1) { XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
SignAbsTestHelper<int>(); SignAbsTestHelper<int>();
SignAbsTestHelper<float>(); SignAbsTestHelper<float>();
} }
TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto arg = builder.ConstantR1<unsigned int>( auto arg = builder.ConstantR1<unsigned int>(
{2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
@ -138,7 +138,7 @@ TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
&builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {}); &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
} }
TEST_F(UnaryOpTest, UnsignedSignTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto arg = builder.ConstantR1<unsigned int>( auto arg = builder.ConstantR1<unsigned int>(
{2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
@ -147,7 +147,7 @@ TEST_F(UnaryOpTest, UnsignedSignTestR1) {
ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {}); ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
} }
TEST_F(UnaryOpTest, SignAbsTestR2) { XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}}); auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
auto sign = builder.Sign(arg); auto sign = builder.Sign(arg);

View File

@ -48,7 +48,7 @@ class VecOpsSimpleTest : public ClientLibraryTestBase {
ErrorSpec error_spec_{0.0001}; ErrorSpec error_spec_{0.0001};
}; };
TEST_F(VecOpsSimpleTest, ExpTenValues) { XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -61,7 +61,7 @@ TEST_F(VecOpsSimpleTest, ExpTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(VecOpsSimpleTest, ExpManyValues) { XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::vector<float> exponents; std::vector<float> exponents;
@ -83,7 +83,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
} }
} }
TEST_F(VecOpsSimpleTest, ExpIn4D) { XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
Array4D<float> exponents(2, 2, 2, 2); Array4D<float> exponents(2, 2, 2, 2);
@ -105,7 +105,7 @@ TEST_F(VecOpsSimpleTest, ExpIn4D) {
ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
} }
TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -116,7 +116,7 @@ TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
builder.Neg(x); builder.Neg(x);
@ -125,7 +125,7 @@ TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
ComputeAndCompareR1<int32>(&builder, expected, {}); ComputeAndCompareR1<int32>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, NegateUint32Values) { XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<uint32>( auto x = builder.ConstantR1<uint32>(
{0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)}); {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
@ -135,7 +135,7 @@ TEST_F(VecOpsSimpleTest, NegateUint32Values) {
ComputeAndCompareR1<uint32>(&builder, expected, {}); ComputeAndCompareR1<uint32>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, SquareTenValues) { XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -146,7 +146,7 @@ TEST_F(VecOpsSimpleTest, SquareTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -187,7 +187,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto add = CreateScalarAddComputation(F32, &builder); auto add = CreateScalarAddComputation(F32, &builder);
@ -202,7 +202,7 @@ TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
} }
TEST_F(VecOpsSimpleTest, MaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -215,7 +215,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
// Similar to MaxTenValues, except that the inputs come from params rather // Similar to MaxTenValues, except that the inputs come from params rather
// than constants. // than constants.
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -233,7 +233,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
error_spec_); error_spec_);
} }
TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
// Similar to MaxTenValuesFromParams, except that the data size passed in and // Similar to MaxTenValuesFromParams, except that the data size passed in and
// out is large. // out is large.
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
@ -273,7 +273,7 @@ TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
error_spec_); error_spec_);
} }
TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -285,7 +285,7 @@ TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, MinTenValues) { XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto x = builder.ConstantR1<float>( auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@ -298,7 +298,7 @@ TEST_F(VecOpsSimpleTest, MinTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto zero = builder.ConstantR0<float>(0); auto zero = builder.ConstantR0<float>(0);
auto one = builder.ConstantR0<float>(1); auto one = builder.ConstantR0<float>(1);
@ -311,7 +311,7 @@ TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto zero = builder.ConstantR0<float>(0); auto zero = builder.ConstantR0<float>(0);
auto one = builder.ConstantR0<float>(1); auto one = builder.ConstantR0<float>(1);
@ -324,7 +324,7 @@ TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto zero = builder.ConstantR1<float>({0.0f, 0.0f}); auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
auto one = builder.ConstantR1<float>({1.0f, 1.0f}); auto one = builder.ConstantR1<float>({1.0f, 1.0f});
@ -335,7 +335,7 @@ TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
auto one = builder.ConstantR0<float>(1); auto one = builder.ConstantR0<float>(1);
auto two = builder.ConstantR0<float>(2); auto two = builder.ConstantR0<float>(2);
@ -348,7 +348,7 @@ TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
ComputeAndCompareR1<float>(&builder, expected, {}); ComputeAndCompareR1<float>(&builder, expected, {});
} }
TEST_F(VecOpsSimpleTest, MapTenValues) { XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
Computation add_half; Computation add_half;
{ {
// add_half(x) = x + 0.5 // add_half(x) = x + 0.5

View File

@ -67,7 +67,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros,
So, for example, in the following code So, for example, in the following code
``` ```python
@batch_function(1, 2, 3) @batch_function(1, 2, 3)
def layer(a): def layer(a):
return tf.matmul(a, a) return tf.matmul(a, a)

View File

@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON) option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF) option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF) option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF)
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")

View File

@ -241,6 +241,13 @@ Step-by-step Windows build
``` ```
ctest -C RelWithDebInfo ctest -C RelWithDebInfo
``` ```
* `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on
serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
After building the python wheel, you need to install the new wheel before running the tests.
To execute the tests, use
```
ctest -C RelWithDebInfo
```
4. Invoke MSBuild to build TensorFlow. 4. Invoke MSBuild to build TensorFlow.

View File

@ -76,7 +76,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"

View File

@ -156,6 +156,21 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py"
) )
if (tensorflow_BUILD_MORE_PYTHON_TESTS)
# Adding other major packages
file(GLOB_RECURSE tf_test_src_py
${tf_test_src_py}
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py"
)
endif()
# exclude the ones we don't want # exclude the ones we don't want
set(tf_test_src_py_exclude set(tf_test_src_py_exclude
# Python source line inspection tests are flaky on Windows (b/36375074). # Python source line inspection tests are flaky on Windows (b/36375074).
@ -183,6 +198,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Loading resources in contrib doesn't seem to work on Windows # Loading resources in contrib doesn't seem to work on Windows
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py"
# dask need fix
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py"
# Test is flaky on Windows GPU builds (b/38283730). # Test is flaky on Windows GPU builds (b/38283730).
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
) )
@ -215,11 +233,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
# training tests # training tests
"${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix. "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix.
"${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # Needs tf.contrib fix.
"${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker.
"${tensorflow_source_dir}/tensorflow/python/training/monitored_session_test.py" # Needs tf.contrib fix.
"${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
"${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # Overflow error.
"${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
"${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
@ -233,6 +248,45 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support
# Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug. # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug.
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
# Dask.Dataframe bugs on Window Build
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/io_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/graph_actions_test.py"
# Need extra build
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py"
# Windows Path
"${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py"
# Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py"
# Scipy needed
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py"
# Failing with TF 1.3 (TODO)
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
) )
endif() endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np import numpy as np
from tensorflow.contrib.crf.python.ops import crf from tensorflow.contrib.crf.python.ops import crf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
self.assertEqual(actual_max_sequence, self.assertEqual(actual_max_sequence,
expected_max_sequence[:sequence_lengths]) expected_max_sequence[:sequence_lengths])
def testCrfDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
tf_all_sequence_scores = sess.run(all_sequence_scores)
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])
self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -16,13 +16,24 @@
The following snippet is an example of a CRF layer on top of a batched sequence The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most of unary scores (logits for every word). This example also decodes the most
likely sequence at test time: likely sequence at test time. There are two ways to do decoding. One
is using crf_decode to do decoding in Tensorflow , and the other one is using
viterbi_decode in Numpy.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood( log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths) unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood) loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Decoding in Tensorflow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
unary_scores, transition_params, sequence_lengths)
tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
[viterbi_sequence, viterbi_score, train_op])
# Decoding in Numpy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op]) [unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
# Compute the highest score and its tag sequence. # Compute the highest score and its tag sequence.
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode( tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params) tf_unary_scores_, tf_transition_params)
""" """
@ -43,6 +54,7 @@ import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell
@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs
__all__ = [ __all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood", "crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode" "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
"CrfDecodeBackwardRnnCell"
] ]
@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):
viterbi_score = np.max(trellis[-1]) viterbi_score = np.max(trellis[-1])
return viterbi, viterbi_score return viterbi, viterbi_score
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
"""Computes the forward decoding in a linear-chain CRF.
"""
def __init__(self, transition_params):
"""Initialize the CrfDecodeForwardRnnCell.
Args:
transition_params: A [num_tags, num_tags] matrix of binary
potentials. This matrix is expanded into a
[1, num_tags, num_tags] in preparation for the broadcast
summation occurring within the cell.
"""
self._transition_params = array_ops.expand_dims(transition_params, 0)
self._num_tags = transition_params.get_shape()[0].value
@property
def state_size(self):
return self._num_tags
@property
def output_size(self):
return self._num_tags
def __call__(self, inputs, state, scope=None):
"""Build the CrfDecodeForwardRnnCell.
Args:
inputs: A [batch_size, num_tags] matrix of unary potentials.
state: A [batch_size, num_tags] matrix containing the previous step's
score values.
scope: Unused variable scope of this cell.
Returns:
backpointers: [batch_size, num_tags], containing backpointers.
new_state: [batch_size, num_tags], containing new score values.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
state = array_ops.expand_dims(state, 2) # [B, O, 1]
# This addition op broadcasts self._transitions_params along the zeroth
# dimension and state along the second dimension.
# [B, O, 1] + [1, O, O] -> [B, O, O]
transition_scores = state + self._transition_params # [B, O, O]
new_state = inputs + math_ops.reduce_max(transition_scores, [1]) # [B, O]
backpointers = math_ops.argmax(transition_scores, 1)
backpointers = math_ops.cast(backpointers, dtype=dtypes.int32) # [B, O]
return backpointers, new_state
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Computes backward decoding in a linear-chain CRF.
"""
def __init__(self, num_tags):
"""Initialize the CrfDecodeBackwardRnnCell.
Args:
num_tags
"""
self._num_tags = num_tags
@property
def state_size(self):
return 1
@property
def output_size(self):
return 1
def __call__(self, inputs, state, scope=None):
"""Build the CrfDecodeBackwardRnnCell.
Args:
inputs: [batch_size, num_tags], backpointer of next step (in time order).
state: [batch_size, 1], next position's tag index.
scope: Unused variable scope of this cell.
Returns:
new_tags, new_tags: A pair of [batch_size, num_tags]
tensors containing the new tag indices.
"""
state = array_ops.squeeze(state, axis=[1]) # [B]
batch_size = array_ops.shape(inputs)[0]
b_indices = math_ops.range(batch_size) # [B]
indices = array_ops.stack([b_indices, state], axis=1) # [B, 2]
new_tags = array_ops.expand_dims(
gen_array_ops.gather_nd(inputs, indices), # [B]
axis=-1) # [B, 1]
return new_tags, new_tags
def crf_decode(potentials, transition_params, sequence_length):
"""Decode the highest scoring sequence of tags in TensorFlow.
This is a function for tensor.
Args:
potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
unary potentials.
transition_params: A [num_tags, num_tags] tensor, matrix of
binary potentials.
sequence_length: A [batch_size] tensor, containing sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
Contains the highest scoring tag indicies.
best_score: A [batch_size] tensor, containing the score of decode_tags.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value
# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn(
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, O], [B, O]
backpointers = gen_array_ops.reverse_sequence(
backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
dtype=dtypes.int32) # [B]
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn(
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, 1]
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
decode_tags = gen_array_ops.reverse_sequence(
decode_tags, sequence_length, seq_dim=1) # [B, T]
best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score

View File

@ -93,7 +93,7 @@ class CudnnRNNBenchmark(test.Benchmark):
batch_size = config["batch_size"] batch_size = config["batch_size"]
seq_length = config["seq_length"] seq_length = config["seq_length"]
with ops.Graph().as_default(), ops.device("/gpu:0"): with ops.Graph().as_default(), ops.device("/device:GPU:0"):
model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, num_units) model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, num_units)
params_size_t = model.params_size() params_size_t = model.params_size()
input_data = variables.Variable( input_data = variables.Variable(
@ -125,7 +125,7 @@ class CudnnRNNBenchmark(test.Benchmark):
batch_size = config["batch_size"] batch_size = config["batch_size"]
seq_length = config["seq_length"] seq_length = config["seq_length"]
with ops.Graph().as_default(), ops.device("/gpu:0"): with ops.Graph().as_default(), ops.device("/device:GPU:0"):
inputs = seq_length * [ inputs = seq_length * [
array_ops.zeros([batch_size, num_units], dtypes.float32) array_ops.zeros([batch_size, num_units], dtypes.float32)
] ]
@ -153,7 +153,7 @@ class CudnnRNNBenchmark(test.Benchmark):
batch_size = config["batch_size"] batch_size = config["batch_size"]
seq_length = config["seq_length"] seq_length = config["seq_length"]
with ops.Graph().as_default(), ops.device("/gpu:0"): with ops.Graph().as_default(), ops.device("/device:GPU:0"):
inputs = seq_length * [ inputs = seq_length * [
array_ops.zeros([batch_size, num_units], dtypes.float32) array_ops.zeros([batch_size, num_units], dtypes.float32)
] ]

View File

@ -286,14 +286,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
save_path = os.path.join(self.get_temp_dir(), save_path = os.path.join(self.get_temp_dir(),
"save-restore-variable-test") "save-restore-variable-test")
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
params_v = sess.run(params) params_v = sess.run(params)
val = saver.save(sess, save_path) val = saver.save(sess, save_path)
self.assertEqual(save_path, val) self.assertEqual(save_path, val)
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
reset_params = state_ops.assign( reset_params = state_ops.assign(
@ -328,14 +328,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
save_path = os.path.join(self.get_temp_dir(), save_path = os.path.join(self.get_temp_dir(),
"save-restore-variable-test") "save-restore-variable-test")
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
params_v = sess.run(param_vars) params_v = sess.run(param_vars)
val = saver.save(sess, save_path) val = saver.save(sess, save_path)
self.assertEqual(save_path, val) self.assertEqual(save_path, val)
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
reset_params = [ reset_params = [
@ -398,14 +398,14 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
params=params, params=params,
is_training=False) is_training=False)
total_sum = sum(map(math_ops.reduce_sum, outputs)) total_sum = sum(map(math_ops.reduce_sum, outputs))
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
total_sum_v = sess.run(total_sum) total_sum_v = sess.run(total_sum)
val = saver.save(sess, save_path) val = saver.save(sess, save_path)
self.assertEqual(save_path, val) self.assertEqual(save_path, val)
# Passing graph explictly, otherwise an old sess would be reused. # Passing graph explicitly, otherwise an old sess would be reused.
with self.test_session( with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess: use_gpu=True, graph=ops.get_default_graph()) as sess:
reset_params = state_ops.assign( reset_params = state_ops.assign(

View File

@ -258,11 +258,12 @@ class Iterator(object):
# initializers that simply reset their state to the beginning. # initializers that simply reset their state to the beginning.
raise ValueError("Iterator does not have an initializer.") raise ValueError("Iterator does not have an initializer.")
def make_initializer(self, dataset): def make_initializer(self, dataset, name=None):
"""Returns a `tf.Operation` that initializes this iterator on `dataset`. """Returns a `tf.Operation` that initializes this iterator on `dataset`.
Args: Args:
dataset: A `Dataset` with compatible structure to this iterator. dataset: A `Dataset` with compatible structure to this iterator.
name: (Optional.) A name for the created operation.
Returns: Returns:
A `tf.Operation` that can be run to initialize this iterator on the given A `tf.Operation` that can be run to initialize this iterator on the given
@ -272,22 +273,25 @@ class Iterator(object):
TypeError: If `dataset` and this iterator do not have a compatible TypeError: If `dataset` and this iterator do not have a compatible
element structure. element structure.
""" """
nest.assert_same_structure(self._output_types, dataset.output_types) with ops.name_scope(name, "make_initializer") as name:
nest.assert_same_structure(self._output_shapes, dataset.output_shapes) nest.assert_same_structure(self._output_types, dataset.output_types)
for iterator_dtype, dataset_dtype in zip( nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
nest.flatten(self._output_types), nest.flatten(dataset.output_types)): for iterator_dtype, dataset_dtype in zip(
if iterator_dtype != dataset_dtype: nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
raise TypeError( if iterator_dtype != dataset_dtype:
"Expected output types %r but got dataset with output types %r." % raise TypeError(
(self._output_types, dataset.output_types)) "Expected output types %r but got dataset with output types %r." %
for iterator_shape, dataset_shape in zip( (self._output_types, dataset.output_types))
nest.flatten(self._output_shapes), nest.flatten(dataset.output_shapes)): for iterator_shape, dataset_shape in zip(
if not iterator_shape.is_compatible_with(dataset_shape): nest.flatten(self._output_shapes),
raise TypeError("Expected output shapes compatible with %r but got " nest.flatten(dataset.output_shapes)):
"dataset with output shapes %r." % if not iterator_shape.is_compatible_with(dataset_shape):
(self._output_shapes, dataset.output_shapes)) raise TypeError("Expected output shapes compatible with %r but got "
return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(), "dataset with output shapes %r." %
self._iterator_resource) (self._output_shapes, dataset.output_shapes))
return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
self._iterator_resource,
name=name)
def get_next(self, name=None): def get_next(self, name=None):
"""Returns a nested structure of `tf.Tensor`s containing the next element. """Returns a nested structure of `tf.Tensor`s containing the next element.

View File

@ -49,6 +49,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import * from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.test_util import *
from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.contrib.distributions.python.ops.wishart import *

View File

@ -634,7 +634,7 @@ class MixtureBenchmark(test.Benchmark):
np.random.seed(127) np.random.seed(127)
with session.Session(config=config, graph=ops.Graph()) as sess: with session.Session(config=config, graph=ops.Graph()) as sess:
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
with ops.device("/gpu:0" if use_gpu else "/cpu:0"): with ops.device("/device:GPU:0" if use_gpu else "/cpu:0"):
mixture = create_distribution( mixture = create_distribution(
num_components=num_components, num_components=num_components,
batch_size=batch_size, batch_size=batch_size,

View File

@ -17,7 +17,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor

View File

@ -20,7 +20,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import re import re
import numpy as np import numpy as np
from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as variables_lib2 from tensorflow.contrib.framework.python.ops import variables as variables_lib2
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op

View File

@ -37,6 +37,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.platform import resource_loader from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util from tensorflow.python.training import training_util
from tensorflow.python.util.deprecation import deprecated
__all__ = ['add_model_variable', __all__ = ['add_model_variable',
@ -82,7 +83,7 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"):
resource_loader.get_path_to_datafile("_variable_ops.so")) resource_loader.get_path_to_datafile("_variable_ops.so"))
return gen_variable_ops.zero_initializer(ref, name=name) return gen_variable_ops.zero_initializer(ref, name=name)
@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor): def assert_global_step(global_step_tensor):
training_util.assert_global_step(global_step_tensor) training_util.assert_global_step(global_step_tensor)
@ -110,11 +111,11 @@ def assert_or_get_global_step(graph=None, global_step_tensor=None):
assert_global_step(global_step_tensor) assert_global_step(global_step_tensor)
return global_step_tensor return global_step_tensor
@deprecated(None, "Please switch to tf.train.get_global_step")
def get_global_step(graph=None): def get_global_step(graph=None):
return training_util.get_global_step(graph) return training_util.get_global_step(graph)
@deprecated(None, "Please switch to tf.train.create_global_step")
def create_global_step(graph=None): def create_global_step(graph=None):
"""Create global step tensor in graph. """Create global step tensor in graph.
@ -132,7 +133,7 @@ def create_global_step(graph=None):
""" """
return training_util.create_global_step(graph) return training_util.create_global_step(graph)
@deprecated(None, "Please switch to tf.train.get_or_create_global_step")
def get_or_create_global_step(graph=None): def get_or_create_global_step(graph=None):
"""Returns and create (if necessary) the global step tensor. """Returns and create (if necessary) the global step tensor.
@ -561,7 +562,7 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
grouped_vars[ckpt_name].append(var) grouped_vars[ckpt_name].append(var)
else: else:
for ckpt_name, value in var_list.iteritems(): for ckpt_name, value in var_list.items():
if isinstance(value, (tuple, list)): if isinstance(value, (tuple, list)):
grouped_vars[ckpt_name] = value grouped_vars[ckpt_name] = value
else: else:

View File

@ -443,19 +443,19 @@ class VariablesTest(test.TestCase):
e = variables_lib2.variable('e', initializer=e_init) e = variables_lib2.variable('e', initializer=e_init)
# The values below highlight how the VariableDeviceChooser puts initial # The values below highlight how the VariableDeviceChooser puts initial
# values on the same device as the variable job. # values on the same device as the variable job.
self.assertDeviceEqual(a.device, '/gpu:0') self.assertDeviceEqual(a.device, '/device:GPU:0')
self.assertEqual(a.initial_value.op.colocation_groups(), self.assertEqual(a.initial_value.op.colocation_groups(),
a.op.colocation_groups()) a.op.colocation_groups())
self.assertDeviceEqual(b.device, '/gpu:0') self.assertDeviceEqual(b.device, '/device:GPU:0')
self.assertEqual(b.initial_value.op.colocation_groups(), self.assertEqual(b.initial_value.op.colocation_groups(),
b.op.colocation_groups()) b.op.colocation_groups())
self.assertDeviceEqual(c.device, '/cpu:12') self.assertDeviceEqual(c.device, '/cpu:12')
self.assertEqual(c.initial_value.op.colocation_groups(), self.assertEqual(c.initial_value.op.colocation_groups(),
c.op.colocation_groups()) c.op.colocation_groups())
self.assertDeviceEqual(d.device, '/gpu:0') self.assertDeviceEqual(d.device, '/device:GPU:0')
self.assertEqual(d.initial_value.op.colocation_groups(), self.assertEqual(d.initial_value.op.colocation_groups(),
d.op.colocation_groups()) d.op.colocation_groups())
self.assertDeviceEqual(e.device, '/gpu:0') self.assertDeviceEqual(e.device, '/device:GPU:0')
self.assertDeviceEqual(e.initial_value.device, '/cpu:99') self.assertDeviceEqual(e.initial_value.device, '/cpu:99')

View File

@ -0,0 +1,125 @@
# Description:
# GPU Direct RDMA Out-of-Band Tensor transport for TensorFlow.
package(default_visibility = [
"//tensorflow:__subpackages__",
])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load(
"//tensorflow:tensorflow.bzl",
"tf_cuda_library",
)
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library_cc",
)
tf_proto_library_cc(
name = "gdr_proto",
srcs = ["gdr.proto"],
cc_api_version = 2,
visibility = [
"//tensorflow:__subpackages__",
],
)
tf_cuda_library(
name = "gdr_memory_manager",
srcs = ["gdr_memory_manager.cc"],
hdrs = ["gdr_memory_manager.h"],
linkopts = select({
"//tensorflow:with_gdr_support": [
"-libverbs",
"-lrdmacm",
],
"//conditions:default": [],
}),
deps = [
":gdr_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cuda_library(
name = "gdr_worker",
srcs = ["gdr_worker.cc"],
hdrs = ["gdr_worker.h"],
deps = [
":gdr_memory_manager",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
],
)
cc_library(
name = "gdr_rendezvous_mgr",
srcs = ["gdr_rendezvous_mgr.cc"],
hdrs = ["gdr_rendezvous_mgr.h"],
deps = [
":gdr_memory_manager",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "gdr_server_lib",
srcs = ["gdr_server_lib.cc"],
hdrs = ["gdr_server_lib.h"],
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
deps = [
":gdr_memory_manager",
":gdr_rendezvous_mgr",
":gdr_worker",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
alwayslink = 1,
)

View File

@ -0,0 +1,122 @@
Introduction
===
This is an implementation of GDR out-of-band transport for TensorFlow distributed runtime, complementary to current gRPC transport. It uses gRPC as control plane to setup rendezvous for each tensor transmission, and utilizes [GPU Direct RDMA](https://developer.nvidia.com/gpudirect) whenever possible to transmit tensors in remote GPU memory through network interface card (NIC), bypassing host memory and CPU entirely. It gracefully falls back to ordinary RDMA or even gRPC when GDR is not available.
Design
===
The GDR out-of-band transport is designed to avoid any unnecessary memory copies, especially for large tensors (>100MB). That typically requires registration of tensor buffers to NIC in an ad-hoc manner, which is rather slow as described in the design trade-off of the verbs runtime. The verbs runtime thus chooses to manage its own NIC-registered buffers and copy the tensors from/to those buffers for every single tensor transfer.
We show that, however, such design trade-off is not always relevant. In this patch, we manage both computation and communication buffers in a unified manner. By pre-registration of large buffers to NIC and allocating small tensors from the buffer pool using a BFC allocator, it is possible to avoid both ad-hoc buffer registration and memory copies all together.
For the actual tensor transport, we rely on gRPC to transmit the [remote buffer information](gdr.proto). This greatly simplifies our design, and there are only 2 types of RDMA messages: a single READ to retrieve the tensor data (bypassing remote CPU), and another invalidate using WRITE with IMM to release the tensor buffer on the remote side. The remote side will only be polling the invalidate message and `Unref` the tensor buffers that read by its peer.
Environment
===
To fully utilize GDR, the target environment has to meet 3 conditions:
1. There is an RDMA capable device with corresponding [OFED package](https://www.openfabrics.org/index.php/overview.html) installed (detailed information is available from your [Infiniband/RoCE](http://www.mellanox.com/page/products_dyn?product_family=116)/[iWarp](http://www.chelsio.com/gpudirect-rdma/) vendor), which could be verified through `ibv_devinfo`, e.g.
```
$ ibv_devinfo
hca_id: mlx4_0
transport: InfiniBand (0)
fw_ver: 2.40.7000
node_guid: 248a:0703:00f6:3370
sys_image_guid: 248a:0703:00f6:3370
vendor_id: 0x02c9
vendor_part_id: 4099
hw_ver: 0x1
board_id: MT_1090110023
phys_port_cnt: 2
Device ports:
port: 1
state: PORT_ACTIVE (4)
max_mtu: 4096 (5)
active_mtu: 1024 (3)
sm_lid: 0
port_lid: 0
port_lmc: 0x00
link_layer: Ethernet
port: 2
state: PORT_ACTIVE (4)
max_mtu: 4096 (5)
active_mtu: 1024 (3)
sm_lid: 0
port_lid: 0
port_lmc: 0x00
link_layer: Ethernet
```
2. There is a GDR capable GPU, i.e. of Fermi, Kepler or later architecture with [corresponding driver](http://docs.nvidia.com/cuda/gpudirect-rdma/index.html) installed. The PCI-e topology could be confirmed by `nvidia-smi topo -m`. For example, in the following topology, `GPU2` and `GPU3` are adjacent to `mlx4_0`, and tensors on these devices could benefit from GDR in current implementation.
```
$ nvidia-smi topo -m
GPU0 GPU1 GPU2 GPU3 mlx4_0 CPU Affinity
GPU0 X PHB SOC SOC SOC 0-5
GPU1 PHB X SOC SOC SOC 0-5
GPU2 SOC SOC X PHB PHB 6-11
GPU3 SOC SOC PHB X PHB 6-11
mlx4_0 SOC SOC PHB PHB X
Legend:
X = Self
SOC = Connection traversing PCIe as well as the SMP link between CPU sockets(e.g. QPI)
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe switches (without traversing the PCIe Host Bridge)
PIX = Connection traversing a single PCIe switch
NV# = Connection traversing a bonded set of # NVLinks
```
3. The [`nv_peer_mem`](https://github.com/Mellanox/nv_peer_memory) kernel module is installed.
How to build and run in GDR mode
===
To test it out on a GDR capable environment, choose to enable GDR in your configure script.
```
Do you wish to build TensorFlow with GDR support? [y/N]: y
GDR support will be enabled for TensorFlow.
```
Change your `protocol` to `grpc+gdr` to enable GDR in your deployment.
```
server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+gdr') # default protocol is 'grpc'
```
Currently the out-of-band transport service listens to the same IP and port address as specified in gRPC.
A successful initialization looks like this:
```
2017-08-05 19:10:38.601718: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40m, pci bus id: 0000:02:00.0)
2017-08-05 19:10:38.601728: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:1) -> (device: 1, name: Tesla K40m, pci bus id: 0000:03:00.0)
2017-08-05 19:10:38.601736: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:2) -> (device: 2, name: Tesla K40m, pci bus id: 0000:82:00.0)
2017-08-05 19:10:38.601742: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:3) -> (device: 3, name: Tesla K40m, pci bus id: 0000:83:00.0)
2017-08-05 19:10:39.591026: I tensorflow/contrib/gdr/gdr_memory_manager.cc:235] RDMA server is listening on 10.40.2.200:5001
2017-08-05 19:10:39.591071: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cuda_host_bfc
2017-08-05 19:10:39.591083: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cpu_pool
2017-08-05 19:10:39.591095: I tensorflow/contrib/gdr/gdr_memory_manager.cc:285] Instrumenting CPU allocator cpu_rdma_bfc
2017-08-05 19:10:39.591278: I tensorflow/contrib/gdr/gdr_memory_manager.cc:78] NUMA node for device: mlx4_0 is 1
2017-08-05 19:10:39.740253: I tensorflow/contrib/gdr/gdr_memory_manager.cc:296] Instrumenting GPU allocator with bus_id 2
```
The last line suggests that the GPUs with bus id 2 (mapped to pci bus id prefixed 0000:8) will benefit from GDR and host memory bypass, which is `/gpu:2` and `/gpu:3` in this case.
Caveats
===
In current implementation, only tensors that reside in host memory or in GPU memory such that the GPU is adjacent to an RDMA capable NIC will use direct RDMA as its transport. When RDMA is available but not GDR, a temporary tensor copy on host memory will be used as RDMA source/destination (and copied from/to the target device). When there is no RDMA device present, it can even fallback to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly setup so the GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary).
In the original design (as in the reference), tensor buffers are only registered to NIC when we could determine that the tensor will be either a source of Send or a sink of Recv across physical machine boundary. However, to implement the precise allocations, we need to change all the devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid the unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocate all host tensors in GPU-registered buffers no matter they will be transferred from/to GPUs or not.
Reference
===
Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3123907

View File

@ -0,0 +1,13 @@
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
message RemoteMemoryRegion {
string host = 1;
string port = 2;
uint64 addr = 3;
uint32 rkey = 4;
uint32 tensor_key = 5;
uint64 checksum = 6;
}

View File

@ -0,0 +1,682 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef TENSORFLOW_USE_GDR
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>
#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>
#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace {
bool IsGDRAvailable() {
#if defined(__APPLE__)
return false;
#elif defined(PLATFORM_WINDOWS)
return false;
#else
std::ifstream ifs("/proc/modules");
string line;
while (std::getline(ifs, line)) {
auto sep = line.find(' ');
CHECK_NE(sep, std::string::npos);
if (line.substr(0, sep) == "nv_peer_mem") {
return true;
}
}
return false;
#endif
}
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
return 0;
#elif defined(PLATFORM_WINDOWS)
// Windows support for NUMA is not currently implemented. Return node 0.
return 0;
#else
VLOG(2) << "Trying to read NUMA node for device: " << device->name;
static const int kUnknownNumaNode = -1;
auto filename = string(device->ibdev_path) + "/device/numa_node";
std::ifstream ifs(filename.c_str());
string content;
CHECK(std::getline(ifs, content));
int32 value;
if (strings::safe_strto32(content, &value)) {
if (value < 0) {
LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
<< value << "), but there must be at least one NUMA node"
", so returning NUMA node zero";
return 0;
}
LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
return value;
}
return kUnknownNumaNode;
#endif
}
void EndpointDeleter(rdma_cm_id* id) {
if (id) {
rdma_destroy_ep(id);
}
}
void MRDeleter(ibv_mr* mr) {
if (mr) {
rdma_dereg_mr(mr);
}
}
using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;
using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;
class GdrMemoryManager : public RemoteMemoryManager {
public:
GdrMemoryManager(const string& host, const string& port);
virtual ~GdrMemoryManager();
virtual Status Init() override;
virtual void Run() override;
virtual void Stop() override;
virtual Status TransportOptionsFromTensor(
::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
Device* device, DeviceContext* device_context, bool on_host) override;
virtual Status TensorFromTransportOptions(
Tensor* tensor, const ::google::protobuf::Any& transport_options,
Device* device, DeviceContext* device_context, bool on_host) override;
protected:
Status CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint);
static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
return ptr < reinterpret_cast<char*>(other->addr) + other->length;
}
ibv_mr* FindMemoryRegion(void* addr, size_t length);
void InsertMemoryRegion(void* addr, size_t length);
#if GOOGLE_CUDA
void InsertCUDAMemoryRegion(void* addr, size_t length);
#endif
void EvictMemoryRegion(void* addr, size_t length);
private:
const string host_;
const string port_;
RdmaEndpointPtr listening_;
std::atomic<bool> stopped_;
int epfd_;
// Server side endpoints
// Accessed sequentially in Run() so not protected by lock
std::list<RdmaEndpointPtr> server_clients_;
using TensorKey = uint32_t;
std::atomic<TensorKey> next_key_;
// Server side on-the-fly tensor buffers
mutex server_mu_;
std::map<TensorKey, const TensorBuffer*> tensor_buffers_
GUARDED_BY(server_mu_);
// Client side endpoints
mutex client_mu_;
std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
GUARDED_BY(cient_mu_);
// Managed memory regions
mutex alloc_mu_;
std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};
// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
public:
~BasicCPUAllocator() override {}
void* Alloc(size_t alignment, size_t num_bytes) override {
return port::AlignedMalloc(num_bytes, alignment);
}
void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};
// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
public:
BFCRdmaAllocator()
: BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
}
};
REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);
GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
: host_(host),
port_(port),
listening_(nullptr, EndpointDeleter),
stopped_(true),
next_key_(0) {}
GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
Status GdrMemoryManager::Init() {
epfd_ = epoll_create1(0);
if (epfd_ == -1) {
return errors::Unavailable(strerror(errno), ": ", "epoll_create");
}
rdma_addrinfo* addrinfo;
rdma_addrinfo hints = {};
hints.ai_port_space = RDMA_PS_TCP;
hints.ai_flags = RAI_PASSIVE;
if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
host_, ":", port_);
}
ibv_qp_init_attr init_attr = {};
init_attr.qp_type = IBV_QPT_RC;
init_attr.cap.max_recv_wr = 32;
init_attr.cap.max_send_wr = 1;
init_attr.cap.max_recv_sge = 1;
init_attr.cap.max_send_sge = 1;
// Create listening endpoint
rdma_cm_id* id;
if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
host_, ":", port_);
}
listening_.reset(id);
rdma_freeaddrinfo(addrinfo);
// Listen without backlog
if (rdma_listen(listening_.get(), 0)) {
return errors::Unavailable(strerror(errno), ": ",
"cannot listen on rdma://", host_, ":", port_);
}
LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;
if (listening_->verbs == nullptr) {
return errors::Unimplemented(
"Unsupported address ", host_, ":", port_,
" as it does not bind to a particular RDMA device");
}
int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
return errors::Unavailable(strerror(errno), ": ",
"cannot set server to non-blocking mode");
}
epoll_event event = {};
event.events = EPOLLIN | EPOLLPRI;
event.data.ptr = listening_.get();
if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
return errors::Unavailable(strerror(errno), ": ",
"cannot add server to epoll");
}
Allocator* allocators[] = {
#if GOOGLE_CUDA
ProcessState::singleton()->GetCUDAHostAllocator(0),
ProcessState::singleton()->GetCPUAllocator(0),
#endif // GOOGLE_CUDA
cpu_allocator(),
};
using namespace std::placeholders;
VisitableAllocator::Visitor alloc_visitor =
std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
VisitableAllocator::Visitor free_visitor =
std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);
std::set<Allocator*> instrumented_;
// Host memory allocators
for (Allocator* allocator : allocators) {
auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
CHECK(visitable_allocator) << "is not visitable for instrumentation"
<< allocator->Name();
// Make sure we don't instrument the same allocator twice
if (instrumented_.find(allocator) == std::end(instrumented_)) {
visitable_allocator->AddAllocVisitor(alloc_visitor);
visitable_allocator->AddFreeVisitor(free_visitor);
instrumented_.insert(allocator);
LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
}
}
#if GOOGLE_CUDA
VisitableAllocator::Visitor cuda_alloc_visitor =
std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
if (IsGDRAvailable()) {
// Note we don't free allocated GPU memory so there is no free visitor
int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
}
#endif // GOOGLE_CUDA
return Status::OK();
}
void GdrMemoryManager::Run() {
stopped_ = false;
while (!stopped_) {
epoll_event events[32];
int ret = epoll_wait(epfd_, events, 32, 1);
if (ret == -1) {
LOG(ERROR) << "epoll_wait: " << strerror(errno);
return;
}
for (int i = 0; i < ret; i++) {
rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
if (id == listening_.get()) {
// Accept incoming connections
if (!rdma_get_request(listening_.get(), &id)) {
if (!rdma_accept(id, nullptr)) {
LOG(INFO) << "Accepted new RDMA connection";
if (ibv_req_notify_cq(id->recv_cq, 0)) {
LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
EndpointDeleter(id);
continue;
}
for (int i = 0; i < 32; i++) {
if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
EndpointDeleter(id);
continue;
}
}
int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
LOG(ERROR) << strerror(errno)
<< ": cannot set server_client to non-blocking mode";
EndpointDeleter(id);
continue;
}
epoll_event event = {};
event.events = EPOLLIN | EPOLLPRI;
event.data.ptr = id;
if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
&event)) {
LOG(ERROR) << strerror(errno)
<< ": cannot add server client to epoll";
EndpointDeleter(id);
continue;
}
server_clients_.push_back({id, EndpointDeleter});
}
}
} else {
// Polling work completions
ibv_cq* cq;
void* context;
if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
ibv_ack_cq_events(id->recv_cq, 1);
if (ibv_req_notify_cq(id->recv_cq, 0)) {
LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
continue;
}
ibv_wc wc[32];
int ret = ibv_poll_cq(id->recv_cq, 32, wc);
if (ret < 0) {
LOG(ERROR) << "ibv_poll_cq failed";
continue;
}
for (int i = 0; i < ret; i++) {
if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
}
if (wc[i].status != 0) {
LOG(ERROR) << ibv_wc_status_str(wc[i].status);
}
TensorKey tensor_key = ntohl(wc[i].imm_data);
{
mutex_lock l(server_mu_);
auto iter = tensor_buffers_.find(tensor_key);
if (iter == std::end(tensor_buffers_)) {
LOG(ERROR) << "Cannot find tensor buffer for tensor key "
<< tensor_key;
} else {
const TensorBuffer* buffer = iter->second;
buffer->Unref();
tensor_buffers_.erase(iter);
}
}
if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
perror("rdma_post_recvv");
LOG(ERROR) << "rdma_post_recvv failed";
continue;
}
}
}
}
}
}
}
void GdrMemoryManager::Stop() { stopped_ = true; }
Status GdrMemoryManager::TransportOptionsFromTensor(
::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
Device* device, DeviceContext* device_context, bool on_host) {
auto buffer = DMAHelper::buffer(&tensor);
void* addr = buffer->data();
size_t length = buffer->size();
if (length == 0) {
return errors::Unavailable("Cannot register tensor buffer of size 0");
}
ibv_mr* mr = FindMemoryRegion(addr, length);
Tensor host_copy;
#if GOOGLE_CUDA
if (!on_host && mr != nullptr) {
TF_RETURN_IF_ERROR(GPUUtil::Sync(device));
} else if (!on_host) {
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
host_copy = Tensor(alloc, tensor.dtype(), tensor.shape());
Status s;
Notification n;
GPUUtil::CopyGPUTensorToCPU(device, device_context, &tensor, &host_copy,
[&s, &n](const Status& status) {
s.Update(status);
n.Notify();
});
n.WaitForNotification();
if (!s.ok()) {
return s;
}
buffer = DMAHelper::buffer(&host_copy);
addr = buffer->data();
length = buffer->size();
mr = FindMemoryRegion(addr, length);
}
#endif
if (mr == nullptr) {
return errors::Unavailable("Cannot find pinned memory region");
}
buffer->Ref();
TensorKey tensor_key = next_key_++;
{
mutex_lock l(server_mu_);
tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
}
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
if (device->tensorflow_gpu_device_info() && (!on_host)) {
if (host_copy.NumElements() > 0) {
checksum = GPUUtil::Checksum(device, device_context, host_copy);
} else {
checksum = GPUUtil::Checksum(device, device_context, tensor);
}
} else {
checksum = GPUUtil::Checksum(tensor);
}
#endif
}
RemoteMemoryRegion remote_mr;
remote_mr.set_host(host_);
remote_mr.set_port(port_);
remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
remote_mr.set_rkey(mr->rkey);
remote_mr.set_tensor_key(tensor_key);
remote_mr.set_checksum(checksum);
mutable_transport_options->PackFrom(remote_mr);
return Status::OK();
}
Status GdrMemoryManager::TensorFromTransportOptions(
Tensor* tensor, const ::google::protobuf::Any& transport_options,
Device* device, DeviceContext* device_context, bool on_host) {
RemoteMemoryRegion remote_mr;
if (!transport_options.UnpackTo(&remote_mr)) {
return errors::NotFound("No RDMA transport options found");
}
auto buffer = DMAHelper::buffer(tensor);
void* addr = buffer->data();
size_t length = buffer->size();
ibv_mr* mr = FindMemoryRegion(addr, length);
Tensor host_copy;
#if GOOGLE_CUDA
if (!on_host && mr != nullptr) {
TF_RETURN_IF_ERROR(GPUUtil::Sync(device));
} else if (!on_host) {
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
buffer = DMAHelper::buffer(&host_copy);
addr = buffer->data();
length = buffer->size();
mr = FindMemoryRegion(addr, length);
}
#endif // GOOGLE_CUDA
if (mr == nullptr) {
return errors::Unavailable("Cannot find pinned memory region");
}
decltype(clients_)::iterator iter;
bool success;
{
mutex_lock l(client_mu_);
std::tie(iter, success) = clients_.insert(
std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
RdmaEndpointPtr(nullptr, EndpointDeleter)));
if (success || iter->second.get() == nullptr) {
TF_RETURN_IF_ERROR(
CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second));
}
}
rdma_cm_id* id = iter->second.get();
uint64_t start = Env::Default()->NowMicros();
if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
remote_mr.addr(), remote_mr.rkey())) {
return errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed");
}
ibv_send_wr wr = {};
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.imm_data = htonl(remote_mr.tensor_key());
wr.send_flags = IBV_SEND_FENCE | IBV_SEND_SIGNALED;
ibv_send_wr* bad_wr;
if (ibv_post_send(id->qp, &wr, &bad_wr)) {
return errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed");
}
ibv_wc wc = {};
int ret = rdma_get_send_comp(id, &wc);
if (ret < 0 || wc.status) {
return errors::Unavailable(ibv_wc_status_str(wc.status));
}
#if GOOGLE_CUDA
if (host_copy.NumElements() > 0) {
Status s;
Notification n;
GPUUtil::CopyCPUTensorToGPU(&host_copy, device_context, device, tensor,
[&s, &n](const Status& status) {
s.Update(status);
n.Notify();
});
n.WaitForNotification();
if (!s.ok()) {
return s;
}
}
#endif // GOOGLE_CUDA
uint64_t end = Env::Default()->NowMicros();
VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
<< " of size " << buffer->size() << " with tensor key "
<< remote_mr.tensor_key() << " took " << (end - start) << " micros";
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
if (device->tensorflow_gpu_device_info() && (!on_host)) {
if (host_copy.NumElements() > 0) {
checksum = GPUUtil::Checksum(device, device_context, host_copy);
} else {
checksum = GPUUtil::Checksum(device, device_context, *tensor);
}
} else {
checksum = GPUUtil::Checksum(*tensor);
}
CHECK(checksum == remote_mr.checksum()) << "Checksum mismatch: " << checksum
<< "!=" << remote_mr.checksum();
#endif
}
return Status::OK();
}
Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint) {
rdma_addrinfo* addrinfo;
rdma_addrinfo hints = {};
hints.ai_port_space = RDMA_PS_TCP;
if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
return errors::InvalidArgument(
strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
}
ibv_qp_init_attr init_attr = {};
init_attr.qp_type = IBV_QPT_RC;
init_attr.cap.max_recv_wr = 1;
init_attr.cap.max_send_wr = 32;
init_attr.cap.max_recv_sge = 1;
init_attr.cap.max_send_sge = 1;
rdma_cm_id* id;
if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
rdma_freeaddrinfo(addrinfo);
return errors::Unavailable(strerror(errno), ": ",
"cannot create endpoint to rdma://", host, ":",
port);
}
rdma_freeaddrinfo(addrinfo);
if (rdma_connect(id, nullptr)) {
rdma_destroy_ep(id);
return errors::Unavailable(strerror(errno), ": ",
"cannot connect to rdma://", host, ":", port);
}
LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
endpoint = RdmaEndpointPtr(id, EndpointDeleter);
return Status::OK();
}
ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
if (length == 0) return nullptr;
mutex_lock l(alloc_mu_);
auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
if (iter == std::end(mrs_) || iter->get()->addr > addr) {
return nullptr;
} else {
return iter->get();
}
}
void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
if (length == 0) return;
ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
if (mr != nullptr) {
mutex_lock l(alloc_mu_);
auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
mrs_.insert(iter, {mr, &MRDeleter});
} else {
LOG(WARNING) << "Cannot register memory region";
}
}
void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
if (length == 0) return;
mutex_lock l(alloc_mu_);
auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
if (iter != std::end(mrs_) && iter->get()->addr == addr) {
mrs_.erase(iter);
} else {
LOG(WARNING) << "Failed to de-register memory region";
}
}
} // namespace
RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
const string& port) {
return new GdrMemoryManager(host, port);
}
} // namespace tensorflow
#endif // TENSORFLOW_USE_GDR

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef GDR_MEMORY_MANAGER_H_
#define GDR_MEMORY_MANAGER_H_
#include "tensorflow/core/lib/core/status.h"
namespace google {
namespace protobuf {
class Any;
}
}
namespace tensorflow {
class Device;
class DeviceContext;
class Tensor;
// Abstract interface that handles out-of-band tensor transport.
//
// The transport options are encoded into a protocol buffer and transmitted via
// some other communication channels like RPC.
// See RecvTensorRequest in tensorflow/core/protobuf/worker.proto
class RemoteMemoryManager {
public:
virtual ~RemoteMemoryManager() {}
virtual Status Init() = 0;
virtual void Run() = 0;
virtual void Stop() = 0;
// Encodes the tensor information to an arbitrary protocol buffer
// The protocol buffer needs to be transmitted via some other channel
virtual Status TransportOptionsFromTensor(
::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
Device* device, DeviceContext* device_context, bool on_host) = 0;
// Retrieve the tensor from the encoded protocol buffer
// Note that the tensor has to be allocated, but not initialized
virtual Status TensorFromTransportOptions(
Tensor* tensor, const ::google::protobuf::Any& transport_options,
Device* device, DeviceContext* device_context, bool on_host) = 0;
};
RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
const string& port);
} // namespace tensorflow
#endif // GDR_MEMORY_MANAGER_H_

View File

@ -0,0 +1,201 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "google/protobuf/any.pb.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
class GdrRecvTensorCall : public BaseRecvTensorCall {
public:
GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device,
RemoteMemoryManager* remote_memory_manager,
const Rendezvous::Args& recv_args, int64 step_id,
StringPiece key)
: wi_(wi),
dst_device_(dst_device),
remote_memory_manager_(remote_memory_manager),
recv_args_(recv_args) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(key.data(), key.size());
}
~GdrRecvTensorCall() override {}
void Start(std::function<void()> recv_done) override {
req_.set_dma_ok(true);
resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs);
StatusCallback cb = [this, recv_done](const Status& s) {
bool dma_ok = resp_.metadata().has_transport_options();
if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) {
auto transport_options = resp_.metadata().transport_options();
const bool on_host =
(dst_device_->tensorflow_gpu_device_info() == nullptr) ||
recv_args_.alloc_attrs.on_host();
Status s = remote_memory_manager_->TensorFromTransportOptions(
const_cast<Tensor*>(&tensor()), transport_options, dst_device_,
recv_args_.device_context, on_host);
if (!s.ok()) {
mutex_lock l(mu_);
status_.Update(s);
LOG(ERROR)
<< "Cannot find pinned memory region from allocator "
<< dst_device_->GetAllocator(recv_args_.alloc_attrs)->Name();
}
}
if (!s.ok()) {
mutex_lock l(mu_);
status_.Update(s);
}
recv_done();
};
wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
}
void StartAbort(const Status& s) override {
{
mutex_lock l(mu_);
status_.Update(s);
}
opts_.StartCancel();
}
Status status() const override {
mutex_lock l(mu_);
return status_;
}
const Tensor& tensor() const { return resp_.tensor(); }
bool is_dead() const { return resp_.metadata().is_dead(); }
Device* dst_device() const { return dst_device_; }
const Rendezvous::Args& recv_args() const { return recv_args_; }
private:
WorkerInterface* wi_;
Device* dst_device_;
RemoteMemoryManager* remote_memory_manager_;
CallOptions opts_;
RecvTensorRequest req_;
TensorResponse resp_;
Rendezvous::Args recv_args_;
mutable mutex mu_;
Status status_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall);
};
class GdrRemoteRendezvous : public BaseRemoteRendezvous {
public:
GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id,
RemoteMemoryManager* remote_memory_manager)
: BaseRemoteRendezvous(env, step_id),
remote_memory_manager_(remote_memory_manager) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) override {
CHECK(is_initialized());
string src_worker;
string src_rel_device;
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
&src_rel_device)) {
Status s = errors::Internal(parsed.src_device,
" is invalid remote source device.");
done(s, Args(), recv_args, Tensor{}, false);
return;
}
WorkerSession* sess = session();
WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
if (rwi == nullptr) {
Status s = errors::Internal("No worker known as ", src_worker);
done(s, Args(), recv_args, Tensor{}, false);
return;
}
Device* dst_device;
Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
if (!s.ok()) {
sess->worker_cache->ReleaseWorker(src_worker, rwi);
done(s, Args(), recv_args, Tensor{}, false);
return;
}
// Prepare a RecvTensor call that can handle being aborted.
GdrRecvTensorCall* call =
new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_,
recv_args, step_id_, parsed.FullKey());
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
// Start "call".
Ref();
call->Start([this, call, src_worker, rwi, done]() {
// Removes "call" from active_. Prevent StartAbort().
DeregisterCall(call);
// If StartAbort was called prior to DeregisterCall, then the
// current status should be bad.
Status s = call->status();
done(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
session()->worker_cache->ReleaseWorker(src_worker, rwi);
delete call;
Unref();
});
}
private:
~GdrRemoteRendezvous() override {}
RemoteMemoryManager* remote_memory_manager_;
TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous);
};
} // namespace
GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env,
RemoteMemoryManager* remote_memory_manager)
: BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {}
BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env) {
return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_);
}
} // end namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef GDR_RENDEZVOUS_MGR_H_
#define GDR_RENDEZVOUS_MGR_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
class GdrRendezvousMgr : public BaseRendezvousMgr {
public:
explicit GdrRendezvousMgr(const WorkerEnv* env,
RemoteMemoryManager* remote_memory_manager);
protected:
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
private:
RemoteMemoryManager* remote_memory_manager_; // Not owned
TF_DISALLOW_COPY_AND_ASSIGN(GdrRendezvousMgr);
};
} // end namespace tensorflow
#endif // GDR_RENDEZVOUS_MGR_H_

View File

@ -0,0 +1,127 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_server_lib.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "tensorflow/contrib/gdr/gdr_worker.h"
#include "net/grpc/public/include/grpc/support/alloc.h"
namespace tensorflow {
GdrServer::GdrServer(const ServerDef& server_def, Env* env)
: GrpcServer(server_def, env) {
string host;
string port;
for (const auto& job : server_def.cluster().job()) {
if (job.name() == server_def.job_name()) {
auto iter = job.tasks().find(server_def.task_index());
if (iter != job.tasks().end()) {
const std::vector<string> hostname_port =
str_util::Split(iter->second, ':');
if (hostname_port.size() == 2) {
host = hostname_port[0];
port = hostname_port[1];
}
}
}
}
remote_memory_manager_ = std::unique_ptr<RemoteMemoryManager>(
CreateRemoteMemoryManager(host, port));
}
GdrServer::~GdrServer() {}
Status GdrServer::Init() {
RendezvousMgrCreationFunction rendezvous_mgr_func =
[this](const WorkerEnv* env) {
return new GdrRendezvousMgr(env, remote_memory_manager_.get());
};
WorkerCreationFunction worker_func = [this](WorkerEnv* env) {
return std::unique_ptr<GdrWorker>(
new GdrWorker(env, remote_memory_manager_.get()));
};
TF_RETURN_IF_ERROR(
GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func));
return remote_memory_manager_->Init();
}
Status GdrServer::Start() {
{
mutex_lock l(mu_);
gdr_thread_.reset(worker_env()->env->StartThread(
ThreadOptions(), "TF_gdr_service",
[this] { remote_memory_manager_->Run(); }));
}
return GrpcServer::Start();
}
Status GdrServer::Stop() {
TF_RETURN_IF_ERROR(GrpcServer::Stop());
remote_memory_manager_->Stop();
return Status::OK();
}
Status GdrServer::Join() {
{
mutex_lock l(mu_);
gdr_thread_.reset();
}
return GrpcServer::Join();
}
/* static */
Status GdrServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server) {
std::unique_ptr<GdrServer> ret(
new GdrServer(server_def, env == nullptr ? Env::Default() : env));
TF_RETURN_IF_ERROR(ret->Init());
*out_server = std::move(ret);
return Status::OK();
}
namespace {
class GdrServerFactory : public ServerFactory {
public:
bool AcceptsOptions(const ServerDef& server_def) override {
return server_def.protocol() == "grpc+gdr";
}
Status NewServer(const ServerDef& server_def,
std::unique_ptr<ServerInterface>* out_server) override {
return GdrServer::Create(server_def, Env::Default(), out_server);
}
};
// Registers a `ServerFactory` for `GdrServer` instances.
class GdrServerRegistrar {
public:
GdrServerRegistrar() {
gpr_allocation_functions alloc_fns;
memset(&alloc_fns, 0, sizeof(alloc_fns));
alloc_fns.malloc_fn = port::Malloc;
alloc_fns.realloc_fn = port::Realloc;
alloc_fns.free_fn = port::Free;
gpr_set_allocation_functions(alloc_fns);
ServerFactory::Register("GDR_SERVER", new GdrServerFactory());
}
};
static GdrServerRegistrar registrar;
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,52 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef GDR_SERVER_LIB_H_
#define GDR_SERVER_LIB_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
namespace tensorflow {
class GdrServer : public GrpcServer {
protected:
GdrServer(const ServerDef& server_def, Env* env);
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
virtual ~GdrServer() override;
virtual Status Start() override;
virtual Status Stop() override;
virtual Status Join() override;
protected:
Status Init();
private:
mutex mu_;
std::unique_ptr<RemoteMemoryManager> remote_memory_manager_;
std::unique_ptr<Thread> gdr_thread_ GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // GDR_SERVER_LIB_H_

View File

@ -0,0 +1,146 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_worker.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif // GOOGLE_CUDA
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
GdrWorker::GdrWorker(WorkerEnv* worker_env,
RemoteMemoryManager* remote_memory_manager)
: GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {}
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
}
if (!s.ok()) {
done(s);
return;
}
// Request the tensor associated with the rendezvous key. Any time
// while waiting for the tensor to be produced, up until the start
// of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
const bool dma_ok = request->dma_ok();
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[this, opts, response, done, src_dev, dma_ok](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args&, const Tensor& val, const bool is_dead) {
opts->ClearCancelCallback();
if (status.ok()) {
// DMA can only be used for Tensors that do not fall into
// the following three odd edge cases: 1) a zero-size
// buffer, 2) a dead tensor which has an uninit value, and
// 3) the tensor has the on_host allocation attribute,
// i.e. it's in CPU RAM *independent of its assigned
// device type*.
const bool on_host =
(src_dev->tensorflow_gpu_device_info() == nullptr) ||
send_args.alloc_attrs.on_host();
if (val.TotalBytes() > 0 && (!is_dead) &&
DMAHelper::CanUseDMA(&val) && dma_ok) {
// DMA cases.
RecvTensorResponse proto;
auto transport_options = proto.mutable_transport_options();
Status s = remote_memory_manager_->TransportOptionsFromTensor(
transport_options, val, src_dev, send_args.device_context,
on_host);
if (s.ok()) {
proto.set_is_dead(is_dead);
proto.set_send_start_micros(Env::Default()->NowMicros());
TensorProto* tensor_proto = proto.mutable_tensor();
tensor_proto->set_dtype(val.dtype());
val.shape().AsProto(tensor_proto->mutable_tensor_shape());
grpc::EncodeRecvTensorResponseToByteBuffer(proto, response);
done(Status::OK());
return;
} else {
done(s);
return;
}
} else {
// Non-DMA cases.
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
const DeviceContext* send_dev_context = send_args.device_context;
AllocatorAttributes alloc_attrs;
alloc_attrs.set_gpu_compatible(true);
alloc_attrs.set_on_host(true);
Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
CHECK(send_dev_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on a GPU. Uses GPUUtil to fill the response proto.
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
done(s);
delete copy;
};
GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
copy_ready);
#else
done(errors::Internal("No GPU device in process"));
#endif // GOOGLE_CUDA
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
done(Status::OK());
}
}
} else {
// !s.ok()
done(status);
}
});
}
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef GDR_WORKER_H_
#define GDR_WORKER_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
namespace tensorflow {
class GdrWorker : public GrpcWorker {
public:
GdrWorker(WorkerEnv* env, RemoteMemoryManager* remote_memory_manager);
// Serve the RecvTensorRequest but omit the tensor content and transmit it
// out-of-band using GPU Direct RDMA whenever possible.
// If it's not possible, it falls back to gRPC in-band tensor transport by
// encoding the tensor content into the grpc::ByteBuffer.
// The RecvTensorResponse will carry the necessary information for RDMA.
virtual void GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) override;
private:
RemoteMemoryManager* remote_memory_manager_; // Not owned
};
} // namespace tensorflow
#endif // GDR_WORKER_H_

View File

@ -3570,7 +3570,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
Returns: Returns:
the tensor after 1d conv with un-shared weights, with shape (batch_size, the tensor after 1d conv with un-shared weights, with shape (batch_size,
output_lenght, filters) output_length, filters)
Raises: Raises:
ValueError: if `data_format` is neither `channels_last` or ValueError: if `data_format` is neither `channels_last` or

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import marshal import marshal
import os
import sys import sys
import time import time
import types as python_types import types as python_types
@ -195,7 +196,10 @@ def func_dump(func):
Returns: Returns:
A tuple `(code, defaults, closure)`. A tuple `(code, defaults, closure)`.
""" """
code = marshal.dumps(func.__code__).decode('raw_unicode_escape') if os.name == 'nt':
code = marshal.dumps(func.__code__).replace(b'\\',b'/').decode('raw_unicode_escape')
else:
code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
defaults = func.__defaults__ defaults = func.__defaults__
if func.__closure__: if func.__closure__:
closure = tuple(c.cell_contents for c in func.__closure__) closure = tuple(c.cell_contents for c in func.__closure__)

View File

@ -1944,7 +1944,7 @@ def gdn(inputs,
spatial dimensions. It is similar to local response normalization, but much spatial dimensions. It is similar to local response normalization, but much
more flexible, as `beta` and `gamma` are trainable parameters. more flexible, as `beta` and `gamma` are trainable parameters.
Arguments: Args:
inputs: Tensor input. inputs: Tensor input.
inverse: If `False` (default), compute GDN response. If `True`, compute IGDN inverse: If `False` (default), compute GDN response. If `True`, compute IGDN
response (one step of fixed point iteration to invert GDN; the division response (one step of fixed point iteration to invert GDN; the division

View File

@ -66,11 +66,11 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import device_setter from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.training import summary_io
from tensorflow.python.training import training_util from tensorflow.python.training import training_util
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
@ -337,7 +337,7 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step):
""" """
logging.info('Saving dict for global step %d: %s', current_global_step, logging.info('Saving dict for global step %d: %s', current_global_step,
_dict_to_str(dictionary)) _dict_to_str(dictionary))
summary_writer = summary_io.SummaryWriterCache.get(output_dir) summary_writer = core_summary.FileWriterCache.get(output_dir)
summary_proto = summary_pb2.Summary() summary_proto = summary_pb2.Summary()
for key in dictionary: for key in dictionary:
if dictionary[key] is None: if dictionary[key] is None:
@ -1034,7 +1034,7 @@ class BaseEstimator(
loss = None loss = None
while not mon_sess.should_stop(): while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
summary_io.SummaryWriterCache.clear() core_summary.FileWriterCache.clear()
return loss return loss

View File

@ -506,7 +506,7 @@ class EstimatorModelFnTest(test.TestCase):
return input_fn_utils.InputFnOps( return input_fn_utils.InputFnOps(
features, labels, {'examples': serialized_tf_example}) features, labels, {'examples': serialized_tf_example})
est.export_savedmodel(est.model_dir + '/export', serving_input_fn) est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
self.assertTrue(self.mock_saver.restore.called) self.assertTrue(self.mock_saver.restore.called)
@ -988,10 +988,11 @@ class EstimatorTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('linear/linear/feature/matmul' in graph_ops) self.assertTrue('linear/linear/feature/matmul' in graph_ops)
self.assertSameElements( self.assertItemsEqual(
['bogus_lookup', 'feature'], ['bogus_lookup', 'feature'],
graph.get_collection( [compat.as_str_any(x) for x in graph.get_collection(
constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)) constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])
# cleanup # cleanup
gfile.DeleteRecursively(tmpdir) gfile.DeleteRecursively(tmpdir)

View File

@ -44,15 +44,16 @@ import six
from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache
from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import summary_io from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
@ -521,7 +522,7 @@ class SummarySaver(EveryN):
self._summary_op = summary_op self._summary_op = summary_op
self._summary_writer = summary_writer self._summary_writer = summary_writer
if summary_writer is None and output_dir: if summary_writer is None and output_dir:
self._summary_writer = summary_io.SummaryWriter(output_dir) self._summary_writer = core_summary.FileWriter(output_dir)
self._scaffold = scaffold self._scaffold = scaffold
# TODO(mdan): Throw an error if output_dir and summary_writer are None. # TODO(mdan): Throw an error if output_dir and summary_writer are None.
@ -529,7 +530,7 @@ class SummarySaver(EveryN):
super(SummarySaver, self).set_estimator(estimator) super(SummarySaver, self).set_estimator(estimator)
# TODO(mdan): This line looks redundant. # TODO(mdan): This line looks redundant.
if self._summary_writer is None: if self._summary_writer is None:
self._summary_writer = summary_io.SummaryWriter(estimator.model_dir) self._summary_writer = core_summary.FileWriter(estimator.model_dir)
def every_n_step_begin(self, step): def every_n_step_begin(self, step):
super(SummarySaver, self).every_n_step_begin(step) super(SummarySaver, self).every_n_step_begin(step)
@ -1029,7 +1030,7 @@ class CheckpointSaver(BaseMonitor):
logging.info("Create CheckpointSaver.") logging.info("Create CheckpointSaver.")
super(CheckpointSaver, self).__init__() super(CheckpointSaver, self).__init__()
self._saver = saver self._saver = saver
self._summary_writer = SummaryWriterCache.get(checkpoint_dir) self._summary_writer = core_summary.FileWriterCache.get(checkpoint_dir)
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold self._scaffold = scaffold
self._save_secs = save_secs self._save_secs = save_secs
@ -1098,12 +1099,12 @@ class StepCounter(EveryN):
self._last_reported_time = None self._last_reported_time = None
self._summary_writer = summary_writer self._summary_writer = summary_writer
if summary_writer is None and output_dir: if summary_writer is None and output_dir:
self._summary_writer = SummaryWriterCache.get(output_dir) self._summary_writer = core_summary.FileWriterCache.get(output_dir)
def set_estimator(self, estimator): def set_estimator(self, estimator):
super(StepCounter, self).set_estimator(estimator) super(StepCounter, self).set_estimator(estimator)
if self._summary_writer is None: if self._summary_writer is None:
self._summary_writer = SummaryWriterCache.get(estimator.model_dir) self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir)
def every_n_step_end(self, current_step, outputs): def every_n_step_end(self, current_step, outputs):
current_time = time.time() current_time = time.time()
@ -1169,7 +1170,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
def begin(self): def begin(self):
self._last_step = None self._last_step = None
self._global_step_tensor = contrib_variables.get_global_step() self._global_step_tensor = training_util.get_global_step()
for m in self._monitors: for m in self._monitors:
m.begin(max_steps=None) m.begin(max_steps=None)

View File

@ -27,7 +27,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import testing from tensorflow.contrib import testing
from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn import estimators from tensorflow.contrib.learn.python.learn import estimators
from tensorflow.python.client import session as session_lib from tensorflow.python.client import session as session_lib
@ -43,6 +42,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.training import training_util
class _MyEveryN(learn.monitors.EveryN): class _MyEveryN(learn.monitors.EveryN):
@ -616,7 +616,7 @@ class CheckpointSaverTest(test.TestCase):
self.graph = ops.Graph() self.graph = ops.Graph()
with self.graph.as_default(): with self.graph.as_default():
self.scaffold = monitored_session.Scaffold() self.scaffold = monitored_session.Scaffold()
self.global_step = variables_lib.get_or_create_global_step() self.global_step = training_util.get_or_create_global_step()
self.train_op = state_ops.assign_add(self.global_step, 1) self.train_op = state_ops.assign_add(self.global_step, 1)
def tearDown(self): def tearDown(self):
@ -780,7 +780,7 @@ class RunHookAdapterForMonitorsTest(test.TestCase):
def test_calls_and_steps(self): def test_calls_and_steps(self):
with ops.Graph().as_default(), session_lib.Session() as sess: with ops.Graph().as_default(), session_lib.Session() as sess:
global_step_tensor = variables_lib.create_global_step() global_step_tensor = training_util.create_global_step()
inc_5 = state_ops.assign_add(global_step_tensor, 5) inc_5 = state_ops.assign_add(global_step_tensor, 5)
mock_mon = FakeMonitor() mock_mon = FakeMonitor()
mock_mon2 = FakeMonitor() mock_mon2 = FakeMonitor()
@ -821,7 +821,7 @@ class RunHookAdapterForMonitorsTest(test.TestCase):
def test_requests(self): def test_requests(self):
with ops.Graph().as_default(), session_lib.Session() as sess: with ops.Graph().as_default(), session_lib.Session() as sess:
variables_lib.create_global_step() training_util.create_global_step()
mock_mon = FakeMonitor() mock_mon = FakeMonitor()
mock_mon2 = FakeMonitor() mock_mon2 = FakeMonitor()

View File

@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import manifest_pb2 from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
@ -49,9 +50,8 @@ def _training_input_fn():
class ExportTest(test.TestCase): class ExportTest(test.TestCase):
def _get_default_signature(self, export_meta_filename): def _get_default_signature(self, export_meta_filename):
"""Gets the default signature from the export.meta file.""" """ Gets the default signature from the export.meta file. """
with session.Session(): with session.Session():
save = saver.import_meta_graph(export_meta_filename) save = saver.import_meta_graph(export_meta_filename)
meta_graph_def = save.export_meta_graph() meta_graph_def = save.export_meta_graph()
@ -68,18 +68,19 @@ class ExportTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(export_dir))
# Only the written checkpoints are exported. # Only the written checkpoints are exported.
self.assertTrue( self.assertTrue(
saver.checkpoint_exists(export_dir + '00000001/export'), saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
'Exported checkpoint expected but not found: %s' % 'Exported checkpoint expected but not found: %s' %
(export_dir + '00000001/export')) os.path.join(export_dir, '00000001', 'export'))
self.assertTrue( self.assertTrue(
saver.checkpoint_exists(export_dir + '00000010/export'), saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
'Exported checkpoint expected but not found: %s' % 'Exported checkpoint expected but not found: %s' %
(export_dir + '00000010/export')) os.path.join(export_dir, '00000010', 'export'))
self.assertEquals( self.assertEquals(
six.b(os.path.join(export_dir, '00000010')), six.b(os.path.join(export_dir, '00000010')),
export_monitor.last_export_dir) export_monitor.last_export_dir)
# Validate the signature # Validate the signature
signature = self._get_default_signature(export_dir + '00000010/export.meta') signature = self._get_default_signature(
os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField(expected_signature)) self.assertTrue(signature.HasField(expected_signature))
def testExportMonitor_EstimatorProvidesSignature(self): def testExportMonitor_EstimatorProvidesSignature(self):
@ -88,7 +89,7 @@ class ExportTest(test.TestCase):
y = 2 * x + 3 y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)] cont_features = [feature_column.real_valued_column('', dimension=1)]
regressor = learn.LinearRegressor(feature_columns=cont_features) regressor = learn.LinearRegressor(feature_columns=cont_features)
export_dir = tempfile.mkdtemp() + 'export/' export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor( export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1, export_dir=export_dir, exports_to_keep=2) every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
regressor.fit(x, y, steps=10, monitors=[export_monitor]) regressor.fit(x, y, steps=10, monitors=[export_monitor])
@ -99,7 +100,7 @@ class ExportTest(test.TestCase):
x = np.random.rand(1000) x = np.random.rand(1000)
y = 2 * x + 3 y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)] cont_features = [feature_column.real_valued_column('', dimension=1)]
export_dir = tempfile.mkdtemp() + 'export/' export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor( export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=export_dir, export_dir=export_dir,
@ -122,7 +123,7 @@ class ExportTest(test.TestCase):
input_feature_key = 'my_example_key' input_feature_key = 'my_example_key'
monitor = learn.monitors.ExportMonitor( monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=tempfile.mkdtemp() + 'export/', export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn, input_fn=_serving_input_fn,
input_feature_key=input_feature_key, input_feature_key=input_feature_key,
exports_to_keep=2, exports_to_keep=2,
@ -140,7 +141,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor( monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=tempfile.mkdtemp() + 'export/', export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn, input_fn=_serving_input_fn,
input_feature_key=input_feature_key, input_feature_key=input_feature_key,
exports_to_keep=2, exports_to_keep=2,
@ -165,7 +166,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor( monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=tempfile.mkdtemp() + 'export/', export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn, input_fn=_serving_input_fn,
input_feature_key=input_feature_key, input_feature_key=input_feature_key,
exports_to_keep=2, exports_to_keep=2,
@ -187,7 +188,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor( monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=tempfile.mkdtemp() + 'export/', export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn, input_fn=_serving_input_fn,
input_feature_key=input_feature_key, input_feature_key=input_feature_key,
exports_to_keep=2, exports_to_keep=2,
@ -210,7 +211,7 @@ class ExportTest(test.TestCase):
shape=(1,), minval=0.0, maxval=1000.0) shape=(1,), minval=0.0, maxval=1000.0)
}, None }, None
export_dir = tempfile.mkdtemp() + 'export/' export_dir = os.path.join(tempfile.mkdtemp(), 'export')
monitor = learn.monitors.ExportMonitor( monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=export_dir, export_dir=export_dir,
@ -235,7 +236,7 @@ class ExportTest(test.TestCase):
y = 2 * x + 3 y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)] cont_features = [feature_column.real_valued_column('', dimension=1)]
regressor = learn.LinearRegressor(feature_columns=cont_features) regressor = learn.LinearRegressor(feature_columns=cont_features)
export_dir = tempfile.mkdtemp() + 'export/' export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor( export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1, every_n_steps=1,
export_dir=export_dir, export_dir=export_dir,
@ -244,10 +245,13 @@ class ExportTest(test.TestCase):
regressor.fit(x, y, steps=10, monitors=[export_monitor]) regressor.fit(x, y, steps=10, monitors=[export_monitor])
self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(export_dir))
self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export')) with self.assertRaises(errors.NotFoundError):
self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export')) saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
self.assertTrue(
saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
# Validate the signature # Validate the signature
signature = self._get_default_signature(export_dir + '00000010/export.meta') signature = self._get_default_signature(
os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField('regression_signature')) self.assertTrue(signature.HasField('regression_signature'))

View File

@ -33,8 +33,13 @@ from tensorflow.python.util import compat
def _create_parser(base_dir): def _create_parser(base_dir):
# create a simple parser that pulls the export_version from the directory. # create a simple parser that pulls the export_version from the directory.
def parser(path): def parser(path):
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", # Modify the path object for RegEx match for Windows Paths
compat.as_str_any(path.path)) if os.name == 'nt':
match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$",
compat.as_str_any(path.path).replace('\\','/'))
else:
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
compat.as_str_any(path.path))
if not match: if not match:
return None return None
return path._replace(export_version=int(match.group(1))) return path._replace(export_version=int(match.group(1)))
@ -48,13 +53,13 @@ class GcTest(test_util.TensorFlowTestCase):
paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
newest = gc.largest_export_versions(2) newest = gc.largest_export_versions(2)
n = newest(paths) n = newest(paths)
self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
def testLargestExportVersionsDoesNotDeleteZeroFolder(self): def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)] paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
newest = gc.largest_export_versions(2) newest = gc.largest_export_versions(2)
n = newest(paths) n = newest(paths)
self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)]) self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
def testModExportVersion(self): def testModExportVersion(self):
paths = [ paths = [
@ -62,9 +67,9 @@ class GcTest(test_util.TensorFlowTestCase):
gc.Path("/foo", 9) gc.Path("/foo", 9)
] ]
mod = gc.mod_export_version(2) mod = gc.mod_export_version(2)
self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
mod = gc.mod_export_version(3) mod = gc.mod_export_version(3)
self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
def testOneOfEveryNExportVersions(self): def testOneOfEveryNExportVersions(self):
paths = [ paths = [
@ -73,7 +78,7 @@ class GcTest(test_util.TensorFlowTestCase):
gc.Path("/foo", 8), gc.Path("/foo", 33) gc.Path("/foo", 8), gc.Path("/foo", 33)
] ]
one_of = gc.one_of_every_n_export_versions(3) one_of = gc.one_of_every_n_export_versions(3)
self.assertEquals( self.assertEqual(
one_of(paths), [ one_of(paths), [
gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8), gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
gc.Path("/foo", 33) gc.Path("/foo", 33)
@ -84,14 +89,14 @@ class GcTest(test_util.TensorFlowTestCase):
# Test that here. # Test that here.
paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
one_of = gc.one_of_every_n_export_versions(3) one_of = gc.one_of_every_n_export_versions(3)
self.assertEquals(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)]) self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
def testUnion(self): def testUnion(self):
paths = [] paths = []
for i in xrange(10): for i in xrange(10):
paths.append(gc.Path("/foo", i)) paths.append(gc.Path("/foo", i))
f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
self.assertEquals( self.assertEqual(
f(paths), [ f(paths), [
gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9) gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
@ -103,9 +108,9 @@ class GcTest(test_util.TensorFlowTestCase):
gc.Path("/foo", 9) gc.Path("/foo", 9)
] ]
mod = gc.negation(gc.mod_export_version(2)) mod = gc.negation(gc.mod_export_version(2))
self.assertEquals(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
mod = gc.negation(gc.mod_export_version(3)) mod = gc.negation(gc.mod_export_version(3))
self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
def testPathsWithParse(self): def testPathsWithParse(self):
base_dir = os.path.join(test.get_temp_dir(), "paths_parse") base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
@ -115,7 +120,7 @@ class GcTest(test_util.TensorFlowTestCase):
# add a base_directory to ignore # add a base_directory to ignore
gfile.MakeDirs(os.path.join(base_dir, "ignore")) gfile.MakeDirs(os.path.join(base_dir, "ignore"))
self.assertEquals( self.assertEqual(
gc.get_paths(base_dir, _create_parser(base_dir)), gc.get_paths(base_dir, _create_parser(base_dir)),
[ [
gc.Path(os.path.join(base_dir, "0"), 0), gc.Path(os.path.join(base_dir, "0"), 0),

View File

@ -57,6 +57,11 @@ REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_CPU), BytesLimitOp);
REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_GPU).HostMemory("out"), REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_GPU).HostMemory("out"),
BytesLimitOp); BytesLimitOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"),
BytesLimitOp);
#endif // TENSORFLOW_USE_SYCL
// Op that measures the peak memory in bytes. // Op that measures the peak memory in bytes.
class MaxBytesInUseOp : public MemoryStatsOp { class MaxBytesInUseOp : public MemoryStatsOp {
public: public:
@ -76,4 +81,10 @@ REGISTER_KERNEL_BUILDER(
Name("MaxBytesInUse").Device(DEVICE_GPU).HostMemory("out"), Name("MaxBytesInUse").Device(DEVICE_GPU).HostMemory("out"),
MaxBytesInUseOp); MaxBytesInUseOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
Name("MaxBytesInUse").Device(DEVICE_SYCL).HostMemory("out"),
MaxBytesInUseOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow } // namespace tensorflow

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#include "grpc/support/alloc.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"

View File

@ -61,7 +61,7 @@ void MPIUtils::InitMPI() {
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs)); MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
MPI_CHECK(MPI_Get_processor_name(my_host_name, &len)); MPI_CHECK(MPI_Get_processor_name(my_host_name, &len));
fprintf(stderr, fprintf(stderr,
"MPI Environment initialised. Process id: %d Total processes: %d " "MPI Environment initialized. Process id: %d Total processes: %d "
"|| Hostname: %s \n", "|| Hostname: %s \n",
proc_id, number_of_procs, my_host_name); proc_id, number_of_procs, my_host_name);
} }

View File

@ -43,7 +43,7 @@ class AllReduceTest(test.TestCase):
self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum) self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn): def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
shape = (3, 4) shape = (3, 4)
np_ans = None np_ans = None
tensors = [] tensors = []
@ -84,7 +84,7 @@ class BroadcastTest(test.TestCase):
# Create session inside outer loop to test use of # Create session inside outer loop to test use of
# same communicator across multiple sessions. # same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess: with self.test_session(use_gpu=True) as sess:
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
shape = (3, 4) shape = (3, 4)
sender = np.random.randint(0, len(devices) - 1) sender = np.random.randint(0, len(devices) - 1)
with ops.device(devices[sender]): with ops.device(devices[sender]):
@ -115,7 +115,7 @@ class CombinedTest(test.TestCase):
# Create session inside outer loop to test use of # Create session inside outer loop to test use of
# same communicator across multiple sessions. # same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess: with self.test_session(use_gpu=True) as sess:
for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]:
shape = (3, 4) shape = (3, 4)
# all-reduce # all-reduce

View File

@ -15,6 +15,7 @@ py_library(
"__init__.py", "__init__.py",
"python/__init__.py", "python/__init__.py",
"python/ops/__init__.py", "python/ops/__init__.py",
"python/ops/alpha_dropout.py",
"python/ops/cross_entropy.py", "python/ops/cross_entropy.py",
"python/ops/sampling_ops.py", "python/ops/sampling_ops.py",
], ],
@ -44,6 +45,23 @@ py_test(
], ],
) )
py_test(
name = "alpha_dropout_test",
size = "small",
srcs = ["python/ops/alpha_dropout_test.py"],
srcs_version = "PY2AND3",
deps = [
":nn_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:nn",
"//tensorflow/python:random_ops",
],
)
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(

View File

@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Module for variants of ops in tf.nn. """Module for variants of ops in tf.nn.
@@alpha_dropout
@@deprecated_flipped_softmax_cross_entropy_with_logits @@deprecated_flipped_softmax_cross_entropy_with_logits
@@deprecated_flipped_sparse_softmax_cross_entropy_with_logits @@deprecated_flipped_sparse_softmax_cross_entropy_with_logits
@@deprecated_flipped_sigmoid_cross_entropy_with_logits @@deprecated_flipped_sigmoid_cross_entropy_with_logits
@ -27,6 +28,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import # pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.nn.python.ops.cross_entropy import * from tensorflow.contrib.nn.python.ops.cross_entropy import *
from tensorflow.contrib.nn.python.ops.sampling_ops import * from tensorflow.contrib.nn.python.ops.sampling_ops import *
from tensorflow.contrib.nn.python.ops.alpha_dropout import *
# pylint: enable=unused-import,wildcard-import # pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util.all_util import remove_undocumented

View File

@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes alpha dropout.
Alpha Dropout is a dropout that maintains the self-normalizing property. For
an input with zero mean and unit standard deviation, the output of
Alpha Dropout maintains the original mean and standard deviation of the input.
See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
Args:
x: A tensor.
keep_prob: A scalar `Tensor` with the same type as x. The probability
that each element is kept.
noise_shape: A 1-D `Tensor` of type `int32`, representing the
shape for randomly generated keep/drop flags.
seed: A Python integer. Used to create random seeds. See
@{tf.set_random_seed} for behavior.
name: A name for this operation (optional).
Returns:
A Tensor of the same shape of `x`.
Raises:
ValueError: If `keep_prob` is not in `(0, 1]`.
"""
with ops.name_scope(name, "alpha_dropout", [x]) as name:
x = ops.convert_to_tensor(x, name="x")
if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1.:
raise ValueError("keep_prob must be a scalar tensor or a float in the "
"range (0, 1], got %g" % keep_prob)
keep_prob = ops.convert_to_tensor(keep_prob,
dtype=x.dtype,
name="keep_prob")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
# Do nothing if we know keep_prob == 1
if tensor_util.constant_value(keep_prob) == 1:
return x
alpha = -1.7580993408473766
noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
random_tensor = random_ops.random_uniform(noise_shape,
seed=seed,
dtype=x.dtype)
kept_idx = gen_math_ops.greater_equal(random_tensor, 1 - keep_prob)
kept_idx = math_ops.cast(kept_idx, x.dtype)
# Mask
x = x * kept_idx + alpha * (1 - kept_idx)
# Affine transformation parameters
a = (keep_prob + keep_prob * (1 - keep_prob) * alpha ** 2) ** -0.5
b = -a * alpha * (1 - keep_prob)
# Affine transformation
return a * x + b

View File

@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sampling_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.nn.python.ops.alpha_dropout import alpha_dropout
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test
class AlphaDropoutTest(test.TestCase):
def testAlphaDropout(self):
x_dim, y_dim = 40, 30
for keep_prob in [0.1, 0.5, 0.8]:
with self.test_session():
t = random_ops.random_normal([x_dim, y_dim])
output = alpha_dropout(t, keep_prob)
self.assertEqual([x_dim, y_dim], output.get_shape())
t_mean, t_std = nn_impl.moments(t, axes=[0, 1])
output_mean, output_std = nn_impl.moments(output, axes=[0, 1])
self.assertLess(abs(t_mean.eval() - output_mean.eval()), 0.1)
self.assertLess(abs(t_std.eval() - output_std.eval()), 0.1)
def testShapedDropoutShapeError(self):
# Runs shaped dropout and verifies an error is thrown on misshapen noise.
x_dim = 40
y_dim = 30
keep_prob = 0.5
t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
with self.assertRaises(ValueError):
_ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
with self.assertRaises(ValueError):
_ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
with self.assertRaises(ValueError):
_ = alpha_dropout(t, keep_prob, noise_shape=[x_dim + 3])
with self.assertRaises(ValueError):
_ = alpha_dropout(t, keep_prob, noise_shape=[x_dim])
# test that broadcasting proceeds
_ = alpha_dropout(t, keep_prob, noise_shape=[y_dim])
_ = alpha_dropout(t, keep_prob, noise_shape=[1, y_dim])
_ = alpha_dropout(t, keep_prob, noise_shape=[x_dim, 1])
_ = alpha_dropout(t, keep_prob, noise_shape=[1, 1])
def testInvalidKeepProb(self):
x_dim, y_dim = 40, 30
t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
with self.assertRaises(ValueError):
alpha_dropout(t, -1.0)
with self.assertRaises(ValueError):
alpha_dropout(t, 1.1)
with self.assertRaises(ValueError):
alpha_dropout(t, [0.0, 1.0])
with self.assertRaises(ValueError):
alpha_dropout(t, array_ops.placeholder(dtypes.float64))
with self.assertRaises(ValueError):
alpha_dropout(t, array_ops.placeholder(dtypes.float32, shape=[2]))
def testNoDropoutFast(self):
x = array_ops.zeros((5,))
for p in 1, constant_op.constant(1.0):
y = alpha_dropout(x, keep_prob=p)
self.assertTrue(x is y)
if __name__ == '__main__':
test.main()

View File

@ -50,6 +50,10 @@ See @{$python/contrib.rnn} guide.
@@UGRNNCell @@UGRNNCell
@@IntersectionRNNCell @@IntersectionRNNCell
@@PhasedLSTMCell @@PhasedLSTMCell
@@ConvLSTMCell
@@Conv1DLSTMCell
@@Conv2DLSTMCell
@@Conv3DLSTMCell
@@HighwayWrapper @@HighwayWrapper
@@GLSTMCell @@GLSTMCell

View File

@ -40,6 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.framework import test_util
# pylint: enable=protected-access # pylint: enable=protected-access
@ -445,11 +446,12 @@ class RNNCellTest(test.TestCase):
# Can't perform this test w/o a GPU # Can't perform this test w/o a GPU
return return
gpu_dev = test.gpu_device_name()
with self.test_session(use_gpu=True) as sess: with self.test_session(use_gpu=True) as sess:
with variable_scope.variable_scope( with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)): "root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1, 3]) x = array_ops.zeros([1, 1, 3])
cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/gpu:0") cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev)
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
outputs, _ = rnn.dynamic_rnn( outputs, _ = rnn.dynamic_rnn(
cell=cell, inputs=x, dtype=dtypes.float32) cell=cell, inputs=x, dtype=dtypes.float32)
@ -461,7 +463,7 @@ class RNNCellTest(test.TestCase):
_ = sess.run(outputs, options=opts, run_metadata=run_metadata) _ = sess.run(outputs, options=opts, run_metadata=run_metadata)
step_stats = run_metadata.step_stats step_stats = run_metadata.step_stats
ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1
gpu_stats = step_stats.dev_stats[ix].node_stats gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])

View File

@ -42,7 +42,6 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.framework import test_util
class Plus1RNNCell(rnn_lib.RNNCell): class Plus1RNNCell(rnn_lib.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1).""" """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@ -2208,11 +2207,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available(): if not test.is_gpu_available():
return # Test requires access to a GPU return # Test requires access to a GPU
gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on( run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device=test_util.gpu_device_name()) rnn_device="/cpu:0", cell_device=gpu_dev)
step_stats = run_metadata.step_stats step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats
@ -2233,12 +2232,12 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available(): if not test.is_gpu_available():
return # Test requires access to a GPU return # Test requires access to a GPU
gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on( run_metadata = self._execute_rnn_on(
rnn_device="/cpu:0", cell_device="/cpu:0", rnn_device="/cpu:0", cell_device="/cpu:0",
input_device=test_util.gpu_device_name()) input_device=gpu_dev)
step_stats = run_metadata.step_stats step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats
@ -2253,11 +2252,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available(): if not test.is_gpu_available():
return # Test requires access to a GPU return # Test requires access to a GPU
gpu_dev = test.gpu_device_name()
run_metadata = self._execute_rnn_on( run_metadata = self._execute_rnn_on(
input_device=test_util.gpu_device_name()) input_device=gpu_dev)
step_stats = run_metadata.step_stats step_stats = run_metadata.step_stats
ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1
("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats

View File

@ -357,7 +357,7 @@ def training_gru_block_vs_gru_cell(batch_size,
ops.reset_default_graph() ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
# Specify the device which is been used. # Specify the device which is been used.
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"): with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):
# Random initializers. # Random initializers.
seed = 1994 seed = 1994
@ -429,7 +429,7 @@ def inference_gru_block_vs_gru_cell(batch_size,
"""Benchmark inference speed between GRUBlockCell vs GRUCell.""" """Benchmark inference speed between GRUBlockCell vs GRUCell."""
ops.reset_default_graph() ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"): with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):
# Random initializers. # Random initializers.
seed = 1994 seed = 1994
@ -484,7 +484,7 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size,
"""Benchmark single bprop step speed between GRUBlockCell vs GRUCell.""" """Benchmark single bprop step speed between GRUBlockCell vs GRUCell."""
ops.reset_default_graph() ops.reset_default_graph()
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
with ops.device("/cpu:0" if not use_gpu else "/gpu:0"): with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"):
initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989) initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989)
# Inputs # Inputs
x = vs.get_variable("x", [batch_size, input_size]) x = vs.get_variable("x", [batch_size, input_size])

View File

@ -875,6 +875,152 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_state_c) self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h) self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
with self.test_session() as sess:
shape = [2,1]
filter_size = [3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
[[[1.4375670191], [1.4375670191]],
[[2.7542609292], [2.7542609292]]],
dtype=np.float32)
expected_state_h = np.array(
[[[0.6529865603], [0.6529865603]],
[[0.8736877431], [0.8736877431]]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0/2.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, 1])
cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
kernel_shape=filter_size,
output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
hidden[0].name:
np.array([[[1.],[1.]],
[[2.],[2.]]]),
x.name:
np.array([[[1.],[1.]],
[[2.],[2.]]]),
})
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
with self.test_session() as sess:
shape = [2,2,1]
filter_size = [3,3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
[[[[1.4375670191], [1.4375670191]],
[[1.4375670191], [1.4375670191]]],
[[[2.7542609292], [2.7542609292]],
[[2.7542609292], [2.7542609292]]]],
dtype=np.float32)
expected_state_h = np.array(
[[[[0.6529865603], [0.6529865603]],
[[0.6529865603], [0.6529865603]]],
[[[0.8736877431], [0.8736877431]],
[[0.8736877431], [0.8736877431]]]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0/4.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
kernel_shape=filter_size,
output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
hidden[0].name:
np.array([[[[1.],[1.]],
[[1.],[1.]]],
[[[2.],[2.]],
[[2.],[2.]]]]),
x.name:
np.array([[[[1.],[1.]],
[[1.],[1.]]],
[[[2.],[2.]],
[[2.],[2.]]]]),
})
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
with self.test_session() as sess:
shape = [2,2,2,1]
filter_size = [3,3,3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
[[[[[1.4375670191], [1.4375670191]],
[[1.4375670191], [1.4375670191]]],
[[[1.4375670191], [1.4375670191]],
[[1.4375670191], [1.4375670191]]]],
[[[[2.7542609292], [2.7542609292]],
[[2.7542609292], [2.7542609292]]],
[[[2.7542609292], [2.7542609292]],
[[2.7542609292], [2.7542609292]]]]],
dtype=np.float32)
expected_state_h = np.array(
[[[[[0.6529865603], [0.6529865603]],
[[0.6529865603], [0.6529865603]]],
[[[0.6529865603], [0.6529865603]],
[[0.6529865603], [0.6529865603]]]],
[[[[0.8736877431], [0.8736877431]],
[[0.8736877431], [0.8736877431]]],
[[[0.8736877431], [0.8736877431]],
[[0.8736877431], [0.8736877431]]]]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0/8.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
kernel_shape=filter_size,
output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
hidden[0].name:
np.array([[[[[1.],[1.]],
[[1.],[1.]]],
[[[1.],[1.]],
[[1.],[1.]]]],
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.],[2.]]]]]),
x.name:
np.array([[[[[1.],[1.]],
[[1.],[1.]]],
[[[1.],[1.]],
[[1.],[1.]]]],
[[[[2.],[2.]],
[[2.],[2.]]],
[[[2.],[2.]],
[[2.],[2.]]]]])
})
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self): def testHighwayWrapper(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope( with variable_scope.variable_scope(

View File

@ -26,6 +26,7 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
@ -1921,6 +1922,181 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
return new_h, new_state return new_h, new_state
class ConvLSTMCell(rnn_cell_impl.RNNCell):
"""Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
def __init__(self,
conv_ndims,
input_shape,
output_channels,
kernel_shape,
use_bias=True,
skip_connection=False,
forget_bias=1.0,
initializers=None,
name="conv_lstm_cell"):
"""Construct ConvLSTMCell.
Args:
conv_ndims: Convolution dimensionality (1, 2 or 3).
input_shape: Shape of the input as int tuple, excluding the batch size.
output_channels: int, number of output channels of the conv LSTM.
kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
use_bias: Use bias in convolutions.
skip_connection: If set to `True`, concatenate the input to the
output of the conv LSTM. Default: `False`.
forget_bias: Forget bias.
name: Name of the module.
Raises:
ValueError: If `skip_connection` is `True` and stride is different from 1
or if `input_shape` is incompatible with `conv_ndims`.
"""
super(ConvLSTMCell, self).__init__(name=name)
if conv_ndims != len(input_shape)-1:
raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
input_shape, conv_ndims))
self._conv_ndims = conv_ndims
self._input_shape = input_shape
self._output_channels = output_channels
self._kernel_shape = kernel_shape
self._use_bias = use_bias
self._forget_bias = forget_bias
self._skip_connection = skip_connection
self._total_output_channels = output_channels
if self._skip_connection:
self._total_output_channels += self._input_shape[-1]
state_size = tensor_shape.TensorShape(self._input_shape[:-1]
+ [self._output_channels])
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
+ [self._total_output_channels])
@property
def output_size(self):
return self._output_size
@property
def state_size(self):
return self._state_size
def call(self, inputs, state, scope=None):
cell, hidden = state
new_hidden = _conv([inputs, hidden],
self._kernel_shape,
4*self._output_channels,
self._use_bias)
gates = array_ops.split(value=new_hidden,
num_or_size_splits=4,
axis=self._conv_ndims+1)
input_gate, new_input, forget_gate, output_gate = gates
new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)
if self._skip_connection:
output = array_ops.concat([output, inputs], axis=-1)
new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
return output, new_state
class Conv1DLSTMCell(ConvLSTMCell):
"""1D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
def __init__(self, name="conv_1d_lstm_cell", **kwargs):
"""Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
class Conv2DLSTMCell(ConvLSTMCell):
"""2D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
def __init__(self, name="conv_2d_lstm_cell", **kwargs):
"""Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
class Conv3DLSTMCell(ConvLSTMCell):
"""3D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
def __init__(self, name="conv_3d_lstm_cell", **kwargs):
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
def _conv(args,
filter_size,
num_features,
bias,
bias_start=0.0):
"""convolution:
Args:
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
batch x n, Tensors.
filter_size: int tuple of filter height and width.
num_features: int, number of features.
bias_start: starting value to initialize the bias; 0 by default.
Returns:
A 3D, 4D, or 5D Tensor with shape [batch ... num_features]
Raises:
ValueError: if some of the arguments has unspecified or wrong shape.
"""
# Calculate the total size of arguments on dimension 1.
total_arg_size_depth = 0
shapes = [a.get_shape().as_list() for a in args]
shape_length = len(shapes[0])
for shape in shapes:
if len(shape) not in [3,4,5]:
raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes))
if len(shape) != len(shapes[0]):
raise ValueError("Conv Linear expects all args to be of same Dimensiton: %s" % str(shapes))
else:
total_arg_size_depth += shape[-1]
dtype = [a.dtype for a in args][0]
# determine correct conv operation
if shape_length == 3:
conv_op = nn_ops.conv1d
strides = 1
elif shape_length == 4:
conv_op = nn_ops.conv2d
strides = shape_length*[1]
elif shape_length == 5:
conv_op = nn_ops.conv3d
strides = shape_length*[1]
# Now the computation.
kernel = vs.get_variable(
"kernel",
filter_size + [total_arg_size_depth, num_features],
dtype=dtype)
if len(args) == 1:
res = conv_op(args[0],
kernel,
strides,
padding='SAME')
else:
res = conv_op(array_ops.concat(axis=shape_length-1, values=args),
kernel,
strides,
padding='SAME')
if not bias:
return res
bias_term = vs.get_variable(
"biases", [num_features],
dtype=dtype,
initializer=init_ops.constant_initializer(
bias_start, dtype=dtype))
return res + bias_term
class GLSTMCell(rnn_cell_impl.RNNCell): class GLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM). """Group LSTM cell (G-LSTM).

View File

@ -78,7 +78,7 @@ class GatherTreeTest(test.TestCase):
sequence_length = [[3, 3, 3]] sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time( expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/gpu:0"): with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree( beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids, step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length) sequence_length=sequence_length)

View File

@ -979,9 +979,9 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments,
# alignments shape is # alignments shape is
# [batch_size, 1, memory_time] # [batch_size, 1, memory_time]
# attention_mechanism.values shape is # attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units] # [batch_size, memory_time, memory_size]
# the batched matmul is over memory_time, so the output shape is # the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units]. # [batch_size, 1, memory_size].
# we then squeeze out the singleton dim. # we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values) context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1]) context = array_ops.squeeze(context, [1])

View File

@ -301,7 +301,12 @@ class Exporter(object):
if exports_to_keep: if exports_to_keep:
# create a simple parser that pulls the export_version from the directory. # create a simple parser that pulls the export_version from the directory.
def parser(path): def parser(path):
match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) if os.name == 'nt':
match = re.match("^" + export_dir_base.replace('\\','/') + "/(\\d{8})$",
path.path.replace('\\','/'))
else:
match = re.match("^" + export_dir_base + "/(\\d{8})$",
path.path)
if not match: if not match:
return None return None
return path._replace(export_version=int(match.group(1))) return path._replace(export_version=int(match.group(1)))

View File

@ -0,0 +1,187 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGDR learning rate decay function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops, control_flow_ops
def sgdr_decay(learning_rate, global_step, initial_period_steps,
t_mul=2.0, m_mul=1.0, name=None):
"""Implements Stochastic Gradient Descent with Warm Restarts (SGDR).
As described in "SGDR: Stochastic Gradient Descent
with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of
ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf
The learning rate decreases according to cosine annealing:
```python
learning_rate * 0.5 * (1 + cos(x_val * pi)) # for x_val defined in [0, 1]
```
Thus, at the beginning (when the restart index i = 0),
the learning rate decreases for `initial_period_steps` steps from the initial
learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to
0 (when `x_val=1`, we get `cos(pi)=-1`).
The decrease within the i-th period takes `t_i` steps,
where `t_0` = `initial_period_steps` is the user-defined number of batch
iterations (not epochs as in the paper) to be performed before the first
restart is launched.
Then, we perform the first restart (i=1) by setting the learning rate to
`learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default).
The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new
restart runs `t_mul` times longer than the previous one.
Importantly, when one has no access to a validation set, SGDR suggests
to report the best expected / recommended solution in the following way:
When we are within our initial run (i=0), every new solution represents
SGDR's recommended solution. Instead, when i>0, the recommended solution is
the one obtained at the end of each restart.
Note that the minimum learning rate is set to 0 for simplicity,
you can adjust the code to deal with any positive minimum learning rate
as defined in the paper.
`initial_period_steps` is the duration of the first period measured in terms
of number of minibatch updates. If one wants to use epochs, one should compute
the number of updates required for an epoch.
For example, assume the following parameters and intention:
Minibatch size: 100
Training dataset size: 10000
If the user wants the first decay period to span across 5 epochs, then
`initial_period_steps` = 5 * 10000/100 = 500
Train for 10000 batch iterations with the initial learning rate set to
0.1, then restart to run 2 times longer, i.e, for 20000 batch iterations
and with the initial learning rate 0.05, then restart again and again,
doubling the runtime of each new period and with two times smaller
initial learning rate.
To accomplish the above, one would write:
```python
...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
learning_rate = sgdr_decay(starter_learning_rate, global_step,
initial_period_steps=10000, t_mul=2, m_mul=0.5)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
tf.train.GradientDescentOptimizer(learning_rate)
.minimize(...my loss..., global_step=global_step)
)
# Step | 0 | 1000 | 5000 | 9000 | 9999 | 10000 | 11000 |
# LR | 0.1 | 0.097 | 0.05 | 0.002 | 0.00 | 0.05 | 0.0496 |
# Step | 20000 | 29000 | 29999 | 30000 |
# LR | 0.025 | 0.0003 | 0.00 | 0.025 |
```
Args:
learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation. Must not be negative.
initial_period_steps: Duration of the first period measured as the number
of minibatch updates, if one wants to use epochs, one should compute
the number of updates required for an epoch.
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
Must be positive.
Used to derive the number of iterations in the i-th period:
`initial_period_steps * (t_mul^i)`. Defaults to 2.0.
m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
Must be positive.
Used to derive the initial learning rate of the i-th period:
`learning_rate * (m_mul^i)`. Defaults to 1.0
Returns:
A scalar `Tensor` of the same type as `learning_rate`.
The learning rate for a provided global_step.
Raises:
ValueError: if `global_step` is not supplied.
"""
if global_step is None:
raise ValueError("global_step is required for sgdr_decay.")
with ops.name_scope(name, "SGDRDecay",
[learning_rate, global_step,
initial_period_steps, t_mul, m_mul]) as name:
learning_rate = ops.convert_to_tensor(learning_rate,
name="initial_learning_rate")
dtype = learning_rate.dtype
global_step = math_ops.cast(global_step, dtype)
t_0 = math_ops.cast(initial_period_steps, dtype)
t_mul = math_ops.cast(t_mul, dtype)
m_mul = math_ops.cast(m_mul, dtype)
c_one = math_ops.cast(constant_op.constant(1.0), dtype)
c_half = math_ops.cast(constant_op.constant(0.5), dtype)
c_pi = math_ops.cast(constant_op.constant(math.pi), dtype)
# Find normalized value of the current step
x_val = math_ops.div(global_step, t_0)
def compute_step(x_val, geometric=False):
if geometric:
# Consider geometric series where t_mul != 1
# 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul)
# First find how many restarts were performed for a given x_val
# Find maximal integer i_restart value for which this equation holds
# x_val >= (1 - t_mul^i_restart) / (1 - t_mul)
# x_val * (1 - t_mul) <= (1 - t_mul^i_restart)
# t_mul^i_restart <= (1 - x_val * (1 - t_mul))
# tensorflow allows only log with base e
# i_restart <= log(1 - x_val * (1 - t_mul) / log(t_mul)
# Find how many restarts were performed
i_restart = math_ops.floor(
math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul))
# Compute the sum of all restarts before the current one
sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul)
# Compute our position within the current restart
x_val = (x_val - sum_r) / t_mul ** i_restart
else:
# Find how many restarts were performed
i_restart = math_ops.floor(x_val)
# Compute our position within the current restart
x_val = x_val - i_restart
return i_restart, x_val
i_restart, x_val = control_flow_ops.cond(
math_ops.equal(t_mul, c_one),
lambda: compute_step(x_val, geometric=False),
lambda: compute_step(x_val, geometric=True))
# If m_mul < 1, then the initial learning rate of every new restart will be
# smaller, i.e., by a factor of m_mul ** i_restart at i_restart-th restart
m_fac = learning_rate * (m_mul ** i_restart)
return math_ops.multiply(c_half * m_fac,
(math_ops.cos(x_val * c_pi) + c_one), name=name)

View File

@ -0,0 +1,145 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for sgdr learning rate decay."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from sgdr_learning_rate_decay import sgdr_decay
from tensorflow.python.platform import googletest
from tensorflow.python.framework import test_util
from tensorflow.python.framework import dtypes
from tensorflow import placeholder
class SGDRDecayTest(test_util.TensorFlowTestCase):
"""Unit tests for SGDR learning rate decay."""
def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
"""Get an array with learning rate values from the consecutive steps using
the original implementation
(https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
t0 = math.pi / 2.0
tt = 0
te_next = t_e
lr_values = []
sh_lr = lr
for epoch in range(epochs):
for _ in range(iter_per_epoch):
# In the original approach training function is executed here
lr_values.append(sh_lr)
dt = 2.0 * math.pi / float(2.0 * t_e)
tt = tt + float(dt) / iter_per_epoch
if tt >= math.pi:
tt = tt - math.pi
cur_t = t0 + tt
new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr
sh_lr = new_lr
if (epoch + 1) == te_next: # time to restart
sh_lr = lr
tt = 0 # by setting to 0 we set lr to lr_max, see above
t_e = t_e * mult_factor # change the period of restarts
te_next = te_next + t_e # note the next restart's epoch
return lr_values
def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
"""Get an array with learning rate values from the consecutive steps
using current tensorflow implementation."""
with self.test_session():
step = placeholder(dtypes.int32)
decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
lr_values = []
for i in range(iters):
lr_values.append(decay.eval(feed_dict={step: i}))
return lr_values
def testCompareToOriginal(self):
"""Compare values generated by tensorflow implementation to the values
generated by the original implementation
(https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
with self.test_session():
lr = 10.0
init_steps = 2
t_mul = 3
iters = 10
epochs = 50
org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs)
sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs)
for org, sgdr in zip(org_lr, sgdr_lr):
self.assertAllClose(org, sgdr)
def testMDecay(self):
"""Test m_mul argument. Check values for learning rate at the beginning
of the first, second, third and fourth period. """
with self.test_session():
step = placeholder(dtypes.int32)
lr = 0.1
t_e = 10
t_mul = 3
m_mul = 0.9
decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)
test_step = 0
self.assertAllClose(decay.eval(feed_dict={step: test_step}),
lr)
test_step = t_e
self.assertAllClose(decay.eval(feed_dict={step: test_step}),
lr * m_mul)
test_step = t_e + t_e*t_mul
self.assertAllClose(decay.eval(feed_dict={step: test_step}),
lr * m_mul**2)
test_step = t_e + t_e*t_mul + t_e * (t_mul**2)
self.assertAllClose(decay.eval(feed_dict={step: test_step}),
lr * (m_mul**3))
def testCos(self):
"""Check learning rate values at the beginning, in the middle
and at the end of the period."""
with self.test_session():
step = placeholder(dtypes.int32)
lr = 0.2
t_e = 1000
t_mul = 1
decay = sgdr_decay(lr, step, t_e, t_mul)
test_step = 0
self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
test_step = t_e//2
self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
test_step = t_e
self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
test_step = t_e*3//2
self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
if __name__ == "__main__":
googletest.main()

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/contrib/verbs/verbs_server_lib.h" #include "tensorflow/contrib/verbs/verbs_server_lib.h"
#include "grpc/support/alloc.h"
#include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h" #include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"

View File

@ -116,6 +116,7 @@ load(
"tf_lib_proto_parsing_deps", "tf_lib_proto_parsing_deps",
"tf_additional_verbs_lib_defines", "tf_additional_verbs_lib_defines",
"tf_additional_mpi_lib_defines", "tf_additional_mpi_lib_defines",
"tf_additional_gdr_lib_defines",
"tf_additional_gpu_tracer_srcs", "tf_additional_gpu_tracer_srcs",
"tf_additional_gpu_tracer_deps", "tf_additional_gpu_tracer_deps",
"tf_additional_gpu_tracer_cuda_deps", "tf_additional_gpu_tracer_cuda_deps",
@ -1245,72 +1246,36 @@ tf_proto_library_cc(
], ],
) )
LIB_INTERNAL_WINDOWS_DEPS = glob(
[
"lib/**/*.h",
"lib/**/*.cc",
"platform/*.h",
"platform/*.cc",
"platform/profile_utils/**/*.h",
"platform/profile_utils/**/*.cc",
] + [
"framework/resource_handle.h",
"framework/resource_handle.cc",
"framework/variant_tensor_data.h",
"framework/variant_tensor_data.cc",
],
exclude = [
"**/*test*",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
"platform/gif.h",
"platform/jpeg.h",
"platform/**/env_time.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/load_library.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
)
cc_library( cc_library(
name = "lib_internal", name = "lib_internal",
srcs = select({ srcs = glob(
"//tensorflow:windows": LIB_INTERNAL_WINDOWS_DEPS, [
"//tensorflow:windows_msvc": LIB_INTERNAL_WINDOWS_DEPS, "lib/**/*.h",
"//conditions:default": glob( "lib/**/*.cc",
[ "platform/*.h",
"lib/**/*.h", "platform/*.cc",
"lib/**/*.cc", "platform/profile_utils/**/*.h",
"platform/*.h", "platform/profile_utils/**/*.cc",
"platform/*.cc", "framework/resource_handle.h",
"platform/profile_utils/**/*.h", "framework/resource_handle.cc",
"platform/profile_utils/**/*.cc", ],
"framework/resource_handle.h", exclude = [
"framework/resource_handle.cc", "**/*test*",
], "framework/variant.cc",
exclude = [ "lib/hash/crc32c_accelerate.cc",
"**/*test*", "lib/gif/**/*",
"framework/variant.cc", "lib/jpeg/**/*",
"platform/variant_coding.cc", "platform/gif.h",
"lib/hash/crc32c_accelerate.cc", "platform/jpeg.h",
"lib/gif/**/*", "platform/**/env_time.cc",
"lib/jpeg/**/*", "platform/**/cuda.h",
"platform/gif.h", "platform/**/cuda_libdevice_path.cc",
"platform/jpeg.h", "platform/**/stream_executor.h",
"platform/**/env_time.cc", "platform/**/gpu_tracer.cc",
"platform/**/cuda.h", "platform/variant_coding.cc",
"platform/**/cuda_libdevice_path.cc", "platform/**/variant_cord_coding.cc",
"platform/**/stream_executor.h", ],
"platform/**/gpu_tracer.cc", ) + tf_additional_lib_srcs(
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
),
}) + tf_additional_lib_srcs(
exclude = [ exclude = [
"**/*test*", "**/*test*",
"platform/**/cuda.h", "platform/**/cuda.h",
@ -1370,9 +1335,12 @@ cc_library(
defines = tf_additional_lib_defines() + [ defines = tf_additional_lib_defines() + [
"SNAPPY", "SNAPPY",
] + tf_additional_verbs_lib_defines() + ] + tf_additional_verbs_lib_defines() +
tf_additional_mpi_lib_defines(), tf_additional_mpi_lib_defines() +
tf_additional_gdr_lib_defines(),
linkopts = select({ linkopts = select({
"//tensorflow:freebsd": [], "//tensorflow:freebsd": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [ "//conditions:default": [
"-ldl", "-ldl",
"-lpthread", "-lpthread",
@ -1407,6 +1375,8 @@ cc_library(
copts = tf_copts(), copts = tf_copts(),
linkopts = select({ linkopts = select({
"//tensorflow:freebsd": [], "//tensorflow:freebsd": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"], "//conditions:default": ["-ldl"],
}), }),
deps = [ deps = [
@ -1430,6 +1400,8 @@ cc_library(
copts = tf_copts(), copts = tf_copts(),
linkopts = select({ linkopts = select({
"//tensorflow:freebsd": [], "//tensorflow:freebsd": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"], "//conditions:default": ["-ldl"],
}), }),
deps = [ deps = [
@ -1605,6 +1577,8 @@ tf_cuda_library(
copts = tf_copts(), copts = tf_copts(),
linkopts = select({ linkopts = select({
"//tensorflow:freebsd": [], "//tensorflow:freebsd": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"], "//conditions:default": ["-ldl"],
}) + [ }) + [
"-lm", "-lm",

View File

@ -22,7 +22,7 @@ limitations under the License.
// Device names // Device names
// * Every Device should have a unique name with the format: // * Every Device should have a unique name with the format:
// /job:___/replica:___/task:___/(gpu|cpu):___ // /job:___/replica:___/task:___/(gpu|cpu):___
// An example name would be "/job:train/replica:0/task:3/gpu:2". // An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as // * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas. // many "task zeros" as replicas.

View File

@ -471,7 +471,7 @@ Status DirectSession::Run(const RunOptions& run_options,
args.step_id = step_id_counter_.fetch_add(1); args.step_id = step_id_counter_.fetch_add(1);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes, GetOrCreateExecutors(input_tensor_names, output_names, target_nodes,
&executors_and_keys, &run_state_args)); &executors_and_keys, &run_state_args));
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
@ -711,7 +711,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
DebugOptions debug_options; DebugOptions debug_options;
RunStateArgs run_state_args(debug_options); RunStateArgs run_state_args(debug_options);
run_state_args.is_partial_run = true; run_state_args.is_partial_run = true;
TF_RETURN_IF_ERROR(GetOrCreateExecutors(pool, input_names, output_names, TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
target_nodes, &executors_and_keys, target_nodes, &executors_and_keys,
&run_state_args)); &run_state_args));
@ -1042,9 +1042,9 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
} }
Status DirectSession::GetOrCreateExecutors( Status DirectSession::GetOrCreateExecutors(
thread::ThreadPool* pool, gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes, gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) { RunStateArgs* run_state_args) {
int64 handle_name_counter_value = -1; int64 handle_name_counter_value = -1;
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
handle_name_counter_value = handle_name_counter_.fetch_add(1); handle_name_counter_value = handle_name_counter_.fetch_add(1);

View File

@ -194,8 +194,8 @@ class DirectSession : public Session {
// Retrieves an already existing set of executors to run 'inputs' and // Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use. // 'outputs', or creates and caches them for future use.
::tensorflow::Status GetOrCreateExecutors( ::tensorflow::Status GetOrCreateExecutors(
thread::ThreadPool* pool, gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes, gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
// Creates several graphs given the existing graph_def_ and the // Creates several graphs given the existing graph_def_ and the

View File

@ -476,7 +476,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
vx.scalar<float>()() = 1.0; vx.scalar<float>()() = 1.0;
Node* x = test::graph::Constant(&g, vx); Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "Darth", x); Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def; GraphDef def;
test::graph::ToGraphDef(&g, &def); test::graph::ToGraphDef(&g, &def);
@ -494,7 +494,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
vx.scalar<float>()() = 1.0; vx.scalar<float>()() = 1.0;
Node* x = test::graph::Constant(&g, vx); Node* x = test::graph::Constant(&g, vx);
Node* y = test::graph::Unary(&g, "Darth", x); Node* y = test::graph::Unary(&g, "Darth", x);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
GraphDef def; GraphDef def;
test::graph::ToGraphDef(&g, &def); test::graph::ToGraphDef(&g, &def);

View File

@ -154,14 +154,14 @@ static void TestHWAccelerator(bool enableHWTrace) {
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {1, 1}); test::FillValues<float>(&x_tensor, {1, 1});
Node* x = test::graph::Constant(&graph, x_tensor); Node* x = test::graph::Constant(&graph, x_tensor);
x->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
// y = A * x // y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false); Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL

View File

@ -114,14 +114,14 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
<< num_bytes << ". See error logs for more detailed info."; << num_bytes << ". See error logs for more detailed info.";
} }
} }
if (LogMemory::IsEnabled()) { if (LogMemory::IsEnabled() && ret != nullptr) {
LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret, LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
allocator_); allocator_);
} }
return ret; return ret;
} }
void deallocate(void* buffer) const override { void deallocate(void* buffer) const override {
if (LogMemory::IsEnabled()) { if (LogMemory::IsEnabled() && buffer != nullptr) {
LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_, LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
true); true);
} }
@ -588,7 +588,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
BaseGPUDevice* gpu_device; BaseGPUDevice* gpu_device;
TF_RETURN_IF_ERROR(CreateGPUDevice(options, TF_RETURN_IF_ERROR(CreateGPUDevice(options,
strings::StrCat(name_prefix, "/gpu:", i), strings::StrCat(name_prefix, "/device:GPU:", i),
valid_gpu_ids[i], &gpu_device)); valid_gpu_ids[i], &gpu_device));
TF_RETURN_IF_ERROR(gpu_device->Init(options)); TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device); devices->push_back(gpu_device);
@ -1049,7 +1049,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
size_t new_id = ids->size(); size_t new_id = ids->size();
ids->push_back(visible_gpu_id); ids->push_back(visible_gpu_id);
LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> " LOG(INFO) << "Creating TensorFlow device (/device:GPU:" << new_id << ") -> "
<< "(" << GetShortDeviceDescription(visible_gpu_id, desc) << ")"; << "(" << GetShortDeviceDescription(visible_gpu_id, desc) << ")";
} }

View File

@ -141,7 +141,7 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Allocator* cpu_allocator) = 0; Allocator* cpu_allocator) = 0;
// Returns into 'ids' the list of valid GPU ids, in the order that // Returns into 'ids' the list of valid GPU ids, in the order that
// they should map to logical gpu ids "/gpu:0", "/gpu:1", etc, based // they should map to logical gpu ids "/device:GPU:0", "/device:GPU:1", etc, based
// upon 'visible_device_list', a comma-separated list of 'visible // upon 'visible_device_list', a comma-separated list of 'visible
// gpu ids'. // gpu ids'.
Status GetValidDeviceIds(const string& visible_device_list, Status GetValidDeviceIds(const string& visible_device_list,

View File

@ -106,9 +106,9 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
TEST_F(GpuStreamUtilTest, StreamOverrides) { TEST_F(GpuStreamUtilTest, StreamOverrides) {
auto root = Scope::NewRootScope().ExitOnError(); auto root = Scope::NewRootScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
"/gpu:0"); "/device:GPU:0");
Output n = ops::MatMul(root, {}, {}); Output n = ops::MatMul(root, {}, {});
ops::_Send(root.WithOpName("output"), n, "output", "/gpu:0", 0, "/cpu:0"); ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0");
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g)); TF_ASSERT_OK(root.ToGraph(&g));

View File

@ -167,7 +167,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
if (!status.ok()) { if (!status.ok()) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message(); LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
} }
Allocator* allocator; VisitableAllocator* allocator;
if (use_bfc_allocator) { if (use_bfc_allocator) {
// TODO(reedwm): evaluate whether 64GB by default is the best choice. // TODO(reedwm): evaluate whether 64GB by default is the best choice.
int64 cpu_mem_limit_in_mb = -1; int64 cpu_mem_limit_in_mb = -1;
@ -192,7 +192,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
if (LogMemory::IsEnabled()) { if (LogMemory::IsEnabled()) {
// Wrap the allocator to track allocation ids for better logging // Wrap the allocator to track allocation ids for better logging
// at the cost of performance. // at the cost of performance.
allocator = new TrackingAllocator(allocator, true); allocator = new TrackingVisitableAllocator(allocator, true);
} }
cpu_allocators_.push_back(allocator); cpu_allocators_.push_back(allocator);
} }
@ -237,14 +237,14 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message(); LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
} }
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20); int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
Allocator* allocator = VisitableAllocator* allocator =
new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit, new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
true /*allow_growth*/, "cuda_host_bfc" /*name*/); true /*allow_growth*/, "cuda_host_bfc" /*name*/);
if (LogMemory::IsEnabled()) { if (LogMemory::IsEnabled()) {
// Wrap the allocator to track allocation ids for better logging // Wrap the allocator to track allocation ids for better logging
// at the cost of performance. // at the cost of performance.
allocator = new TrackingAllocator(allocator, true); allocator = new TrackingVisitableAllocator(allocator, true);
} }
cuda_host_allocators_.push_back(allocator); cuda_host_allocators_.push_back(allocator);
if (FLAGS_brain_gpu_record_mem_types) { if (FLAGS_brain_gpu_record_mem_types) {

View File

@ -53,7 +53,7 @@ TEST(MemoryTypeChecker, Int32NotOk) {
EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_GPU, g))); EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_GPU, g)));
// But we can insert _HostSend/_HostRecv to ensure the invariant. // But we can insert _HostSend/_HostRecv to ensure the invariant.
TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/gpu:0", g)); TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/device:GPU:0", g));
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g)); TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL

View File

@ -682,7 +682,7 @@ Status SimplePlacer::Run() {
int dst_root_id = colocation_graph.FindRoot(dst->id()); int dst_root_id = colocation_graph.FindRoot(dst->id());
auto& src_root = colocation_graph.members_[src_root_id]; auto& src_root = colocation_graph.members_[src_root_id];
auto& dst_root = colocation_graph.members_[dst_root_id]; auto& dst_root = colocation_graph.members_[dst_root_id];
// If both the source node and this node have paritally // If both the source node and this node have partially
// specified a device, then 'node's device should be // specified a device, then 'node's device should be
// cleared: the reference edge forces 'node' to be on the // cleared: the reference edge forces 'node' to be on the
// same device as the source node. // same device as the source node.

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
SYCLAllocator::SYCLAllocator(Eigen::QueueInterface *queue) SYCLAllocator::SYCLAllocator(Eigen::QueueInterface* queue)
: sycl_device_(new Eigen::SyclDevice(queue)) { : sycl_device_(new Eigen::SyclDevice(queue)) {
cl::sycl::queue& sycl_queue = sycl_device_->sycl_queue(); cl::sycl::queue& sycl_queue = sycl_device_->sycl_queue();
const cl::sycl::device& device = sycl_queue.get_device(); const cl::sycl::device& device = sycl_queue.get_device();
@ -28,14 +28,15 @@ SYCLAllocator::SYCLAllocator(Eigen::QueueInterface *queue)
} }
SYCLAllocator::~SYCLAllocator() { SYCLAllocator::~SYCLAllocator() {
if(sycl_device_) { if (sycl_device_) {
delete sycl_device_; delete sycl_device_;
} }
} }
string SYCLAllocator::Name() { return "device:SYCL"; } string SYCLAllocator::Name() { return "device:SYCL"; }
void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { void* SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
mutex_lock lock(mu_);
assert(sycl_device_); assert(sycl_device_);
if (num_bytes == 0) { if (num_bytes == 0) {
// Cannot allocate no bytes in SYCL, so instead allocate a single byte // Cannot allocate no bytes in SYCL, so instead allocate a single byte
@ -45,7 +46,6 @@ void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
const auto& allocated_buffer = sycl_device_->get_sycl_buffer(p); const auto& allocated_buffer = sycl_device_->get_sycl_buffer(p);
const std::size_t bytes_allocated = allocated_buffer.get_range().size(); const std::size_t bytes_allocated = allocated_buffer.get_range().size();
mutex_lock lock(mu_);
++stats_.num_allocs; ++stats_.num_allocs;
stats_.bytes_in_use += bytes_allocated; stats_.bytes_in_use += bytes_allocated;
stats_.max_bytes_in_use = stats_.max_bytes_in_use =
@ -56,12 +56,12 @@ void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return p; return p;
} }
void SYCLAllocator::DeallocateRaw(void *ptr) { void SYCLAllocator::DeallocateRaw(void* ptr) {
const auto& buffer_to_delete = sycl_device_->get_sycl_buffer(ptr);
const std::size_t dealloc_size = buffer_to_delete.get_range().size();
mutex_lock lock(mu_); mutex_lock lock(mu_);
stats_.bytes_in_use -= dealloc_size;
if (sycl_device_) { if (sycl_device_) {
const auto& buffer_to_delete = sycl_device_->get_sycl_buffer(ptr);
const std::size_t dealloc_size = buffer_to_delete.get_range().size();
stats_.bytes_in_use -= dealloc_size;
sycl_device_->deallocate(ptr); sycl_device_->deallocate(ptr);
} }
} }
@ -72,6 +72,10 @@ void SYCLAllocator::GetStats(AllocatorStats* stats) {
} }
size_t SYCLAllocator::RequestedSize(void* ptr) { size_t SYCLAllocator::RequestedSize(void* ptr) {
mutex_lock lock(mu_);
if(!sycl_device_) {
return 0;
}
const auto& buffer = sycl_device_->get_sycl_buffer(ptr); const auto& buffer = sycl_device_->get_sycl_buffer(ptr);
return buffer.get_size(); return buffer.get_size();
} }

View File

@ -29,15 +29,20 @@ namespace tensorflow {
class SYCLAllocator : public Allocator { class SYCLAllocator : public Allocator {
public: public:
SYCLAllocator(Eigen::QueueInterface *queue); SYCLAllocator(Eigen::QueueInterface* queue);
virtual ~SYCLAllocator() override; virtual ~SYCLAllocator() override;
string Name() override; string Name() override;
void *AllocateRaw(size_t alignment, size_t num_bytes) override; void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void *ptr) override; void DeallocateRaw(void* ptr) override;
virtual bool ShouldAllocateEmptyTensors() override final { return true; } virtual bool ShouldAllocateEmptyTensors() override final { return true; }
void Synchronize() { sycl_device_->synchronize(); } void Synchronize() {
bool Ok() { return sycl_device_->ok(); } mutex_lock lock(mu_);
if (sycl_device_) {
sycl_device_->synchronize();
}
}
bool Ok() { return sycl_device_ && sycl_device_->ok(); }
void GetStats(AllocatorStats* stats) override; void GetStats(AllocatorStats* stats) override;
// The SYCL buffers keep track of their size, so we already have tracking. // The SYCL buffers keep track of their size, so we already have tracking.
bool TracksAllocationSizes() override { return true; } bool TracksAllocationSizes() override { return true; }
@ -46,10 +51,19 @@ class SYCLAllocator : public Allocator {
// AllocatedSize(void* ptr) by default. // AllocatedSize(void* ptr) by default.
size_t RequestedSize(void* ptr) override; size_t RequestedSize(void* ptr) override;
Eigen::SyclDevice* getSyclDevice() { return sycl_device_; } Eigen::SyclDevice* getSyclDevice() { return sycl_device_; }
// Clear the SYCL device used by the Allocator
void ClearSYCLDevice() {
mutex_lock lock(mu_);
if(sycl_device_) {
delete sycl_device_;
sycl_device_ = nullptr;
}
}
private: private:
Eigen::SyclDevice *sycl_device_; // owned
mutable mutex mu_; mutable mutex mu_;
Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_); // owned
AllocatorStats stats_ GUARDED_BY(mu_); AllocatorStats stats_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator); TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator);

View File

@ -22,20 +22,10 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/tracing.h"
namespace tensorflow { namespace tensorflow {
std::mutex GSYCLInterface::mutex_;
GSYCLInterface *GSYCLInterface::s_instance = 0;
void ShutdownSycl() {
GSYCLInterface::Reset();
}
void SYCLDevice::RegisterDevice() {
atexit(ShutdownSycl);
}
SYCLDevice::~SYCLDevice() {} SYCLDevice::~SYCLDevice() {}
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
assert(context); assert(context);
if (port::Tracing::IsActive()) { if (port::Tracing::IsActive()) {
// TODO(pbar) We really need a useful identifier of the graph node. // TODO(pbar) We really need a useful identifier of the graph node.
@ -46,16 +36,16 @@ void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
op_kernel->Compute(context); op_kernel->Compute(context);
} }
Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) { Allocator* SYCLDevice::GetAllocator(AllocatorAttributes attr) {
if (attr.on_host()) if (attr.on_host())
return cpu_allocator_; return cpu_allocator_;
else else
return sycl_allocator_; return sycl_allocator_;
} }
Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, Status SYCLDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor *tensor) { Tensor* tensor) {
AllocatorAttributes attr; AllocatorAttributes attr;
attr.set_on_host(true); attr.set_on_host(true);
Allocator* host_alloc = GetAllocator(attr); Allocator* host_alloc = GetAllocator(attr);
@ -79,18 +69,18 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
} }
device_context_->CopyCPUTensorToDevice( device_context_->CopyCPUTensorToDevice(
&parsed, this, &copy, [&status](const Status &s) { status = s; }); &parsed, this, &copy, [&status](const Status& s) { status = s; });
*tensor = copy; *tensor = copy;
} }
return status; return status;
} }
Status SYCLDevice::FillContextMap(const Graph *graph, Status SYCLDevice::FillContextMap(const Graph* graph,
DeviceContextMap *device_context_map) { DeviceContextMap* device_context_map) {
// Fill in the context map. It is OK for this map to contain // Fill in the context map. It is OK for this map to contain
// duplicate DeviceContexts so long as we increment the refcount. // duplicate DeviceContexts so long as we increment the refcount.
device_context_map->resize(graph->num_node_ids()); device_context_map->resize(graph->num_node_ids());
for (Node *n : graph->nodes()) { for (Node* n : graph->nodes()) {
device_context_->Ref(); device_context_->Ref();
(*device_context_map)[n->id()] = device_context_; (*device_context_map)[n->id()] = device_context_;
} }

View File

@ -27,201 +27,190 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
class GSYCLInterface {
class GSYCLInterface std::vector<Eigen::QueueInterface*> m_queue_interface_; // owned
{ std::vector<Allocator*> m_cpu_allocator_; // not owned
std::vector<Eigen::QueueInterface*> m_queue_interface_; // owned std::vector<SYCLAllocator*> m_sycl_allocator_; // owned
std::vector<Allocator*> m_cpu_allocator_; // not owned std::vector<SYCLDeviceContext*> m_sycl_context_; // ref counted
std::vector<SYCLAllocator*> m_sycl_allocator_; // owned GSYCLInterface() {
std::vector<SYCLDeviceContext*> m_sycl_context_; // owned bool found_device = false;
auto device_list = Eigen::get_sycl_supported_devices();
static std::mutex mutex_; // Obtain list of supported devices from Eigen
static GSYCLInterface* s_instance; for (const auto& device : device_list) {
GSYCLInterface() { if (device.is_gpu()) {
bool found_device =false; // returns first found GPU
auto device_list = Eigen::get_sycl_supported_devices(); AddDevice(device);
// Obtain list of supported devices from Eigen found_device = true;
for (const auto& device : device_list) {
if(device.is_gpu()) {
// returns first found GPU
AddDevice(device);
found_device = true;
}
}
if(!found_device) {
// Currently Intel GPU is not supported
LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, trying OpenCL CPU";
}
for (const auto& device : device_list) {
if(device.is_cpu()) {
// returns first found CPU
AddDevice(device);
found_device = true;
}
}
if(!found_device) {
// Currently Intel GPU is not supported
LOG(FATAL) << "No OpenCL GPU nor CPU found that is supported by ComputeCpp";
} else {
LOG(INFO) << "Found following OpenCL devices:";
for (int i = 0; i < device_list.size(); i++) {
LOG(INFO) << GetShortDeviceDescription(i);
}
} }
} }
~GSYCLInterface() { if (!found_device) {
m_cpu_allocator_.clear(); // Currently Intel GPU is not supported
LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, "
for (auto p : m_sycl_allocator_) { "trying OpenCL CPU";
p->Synchronize();
delete p;
}
m_sycl_allocator_.clear();
for(auto p : m_sycl_context_) {
p->Unref();
}
m_sycl_context_.clear();
for (auto p : m_queue_interface_) {
p->deallocate_all();
delete p;
p = nullptr;
}
m_queue_interface_.clear();
} }
void AddDevice(const cl::sycl::device & d) { for (const auto& device : device_list) {
m_queue_interface_.push_back(new Eigen::QueueInterface(d)); if (device.is_cpu()) {
m_cpu_allocator_.push_back(cpu_allocator()); // returns first found CPU
m_sycl_allocator_.push_back(new SYCLAllocator(m_queue_interface_.back())); AddDevice(device);
m_sycl_context_.push_back(new SYCLDeviceContext()); found_device = true;
}
public:
static GSYCLInterface *instance()
{
std::lock_guard<std::mutex> lock(mutex_);
if (!s_instance) {
s_instance = new GSYCLInterface();
}
return s_instance;
}
static void Reset()
{
std::lock_guard<std::mutex> lock(mutex_);
if(s_instance) {
delete s_instance;
s_instance = NULL;
} }
} }
Eigen::QueueInterface * GetQueueInterface(size_t i = 0) { if (!found_device) {
if(!m_queue_interface_.empty()) { // Currently Intel GPU is not supported
return m_queue_interface_[i]; LOG(FATAL)
} else { << "No OpenCL GPU nor CPU found that is supported by ComputeCpp";
std::cerr << "No cl::sycl::device has been added" << std::endl; } else {
return nullptr; LOG(INFO) << "Found following OpenCL devices:";
for (int i = 0; i < device_list.size(); i++) {
LOG(INFO) << GetShortDeviceDescription(i);
} }
} }
}
SYCLAllocator * GetSYCLAllocator(size_t i = 0) { ~GSYCLInterface() {
if(!m_sycl_allocator_.empty()) { m_cpu_allocator_.clear();
return m_sycl_allocator_[i];
} else { for (auto p : m_sycl_allocator_) {
std::cerr << "No cl::sycl::device has been added" << std::endl; p->Synchronize();
return nullptr; p->ClearSYCLDevice();
} // Cannot delete the Allocator instances, as the Allocator lifetime
// needs to exceed any Tensor created by it. There is no way of
// knowing when all Tensors have been deallocated, as they are
// RefCounted and wait until all instances of a Tensor have been
// destroyed before calling Allocator.Deallocate. This could happen at
// program exit, which can set up a race condition between destroying
// Tensors and Allocators when the program is cleaning up.
}
m_sycl_allocator_.clear();
for (auto p : m_sycl_context_) {
p->Unref();
}
m_sycl_context_.clear();
for (auto p : m_queue_interface_) {
p->deallocate_all();
delete p;
}
m_queue_interface_.clear();
}
void AddDevice(const cl::sycl::device& d) {
m_queue_interface_.push_back(new Eigen::QueueInterface(d));
m_cpu_allocator_.push_back(cpu_allocator());
m_sycl_allocator_.push_back(new SYCLAllocator(m_queue_interface_.back()));
m_sycl_context_.push_back(new SYCLDeviceContext());
}
public:
static const GSYCLInterface* instance() {
// c++11 guarantees that this will be constructed in a thread safe way
static const GSYCLInterface instance;
return &instance;
}
Eigen::QueueInterface* GetQueueInterface(size_t i = 0) const {
if (!m_queue_interface_.empty()) {
return m_queue_interface_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
SYCLAllocator* GetSYCLAllocator(size_t i = 0) const {
if (!m_sycl_allocator_.empty()) {
return m_sycl_allocator_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
Allocator* GetCPUAllocator(size_t i = 0) const {
if (!m_cpu_allocator_.empty()) {
return m_cpu_allocator_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
SYCLDeviceContext* GetSYCLContext(size_t i = 0) const {
if (!m_sycl_context_.empty()) {
return m_sycl_context_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
string GetShortDeviceDescription(int device_id = 0) const {
Eigen::QueueInterface* queue_ptr = GetQueueInterface(device_id);
if (!queue_ptr) {
LOG(ERROR)
<< "Device name cannot be given after Eigen QueueInterface destroyed";
return "";
}
auto device = queue_ptr->sycl_queue().get_device();
auto name = device.get_info<cl::sycl::info::device::name>();
auto vendor = device.get_info<cl::sycl::info::device::vendor>();
auto profile = device.get_info<cl::sycl::info::device::profile>();
std::string type;
if (device.is_host()) {
type = "Host";
} else if (device.is_cpu()) {
type = "CPU";
} else if (device.is_gpu()) {
type = "GPU";
} else if (device.is_accelerator()) {
type = "Accelerator";
} else {
type = "Unknown";
} }
Allocator * GetCPUAllocator(size_t i = 0) { return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ",
if(!m_cpu_allocator_.empty()) { name.c_str(), ", vendor: ", vendor.c_str(),
return m_cpu_allocator_[i]; ", profile: ", profile.c_str());
} else { }
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
SYCLDeviceContext * GetSYCLContext(size_t i = 0) {
if(!m_sycl_context_.empty()) {
return m_sycl_context_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
string GetShortDeviceDescription(int device_id = 0) {
auto _device = GetSYCLAllocator(device_id)
->getSyclDevice()
->sycl_queue()
.get_device();
auto _name = _device.get_info<cl::sycl::info::device::name>();
auto _vendor = _device.get_info<cl::sycl::info::device::vendor>();
auto _profile = _device.get_info<cl::sycl::info::device::profile>();
std::string _type;
if (_device.is_host()) {
_type = "Host";
} else if (_device.is_cpu()) {
_type = "CPU";
} else if (_device.is_gpu()) {
_type = "GPU";
} else if (_device.is_accelerator()) {
_type = "Accelerator";
} else {
_type = "Unknown";
}
return strings::StrCat("id: ", device_id, " ,type: ", _type, " ,name: ",
_name.c_str(), " ,vendor: ", _vendor.c_str(),
" ,profile: ", _profile.c_str());
}
}; };
class SYCLDevice : public LocalDevice { class SYCLDevice : public LocalDevice {
public: public:
SYCLDevice(const SessionOptions &options, const string &name, SYCLDevice(const SessionOptions& options, const string& name,
Bytes memory_limit, const DeviceLocality &locality, Bytes memory_limit, const DeviceLocality& locality,
const string &physical_device_desc, SYCLAllocator * sycl_allocator, const string& physical_device_desc, SYCLAllocator* sycl_allocator,
Allocator *cpu_allocator, SYCLDeviceContext* ctx) Allocator* cpu_allocator, SYCLDeviceContext* ctx)
: LocalDevice( : LocalDevice(options, Device::BuildDeviceAttributes(
options, name, DEVICE_SYCL, memory_limit, locality,
Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit, physical_device_desc)),
locality, physical_device_desc)),
cpu_allocator_(cpu_allocator), cpu_allocator_(cpu_allocator),
sycl_allocator_(sycl_allocator), sycl_allocator_(sycl_allocator),
device_context_(ctx) { device_context_(ctx) {
RegisterDevice();
set_eigen_sycl_device(sycl_allocator->getSyclDevice()); set_eigen_sycl_device(sycl_allocator->getSyclDevice());
} }
~SYCLDevice() override; ~SYCLDevice() override;
void Compute(OpKernel *op_kernel, OpKernelContext *context) override; void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator *GetAllocator(AllocatorAttributes attr) override; Allocator* GetAllocator(AllocatorAttributes attr) override;
Status MakeTensorFromProto(const TensorProto &tensor_proto, Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor *tensor) override; Tensor* tensor) override;
Status FillContextMap(const Graph *graph, Status FillContextMap(const Graph* graph,
DeviceContextMap *device_context_map) override; DeviceContextMap* device_context_map) override;
Status Sync() override; Status Sync() override;
private: private:
void RegisterDevice(); Allocator* cpu_allocator_; // not owned
SYCLAllocator* sycl_allocator_; // not owned
Allocator *cpu_allocator_; // not owned SYCLDeviceContext* device_context_; // not owned
SYCLAllocator *sycl_allocator_; // not owned
SYCLDeviceContext *device_context_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -21,17 +21,60 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// For DMA helper // For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
namespace tensorflow { namespace tensorflow {
inline void* GetBase(const Tensor* src) { inline void const* GetBase(const Tensor* src) { return DMAHelper::base(src); }
return const_cast<void*>(DMAHelper::base(src)); inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
inline void SYCLmemcpy(Eigen::SyclDevice const& device,
Tensor const& src_tensor, Tensor* dst_tensor) {
const size_t size = src_tensor.TotalBytes();
void* dst_ptr = GetBase(dst_tensor);
void const* src_ptr = GetBase(&src_tensor);
#define COPY_WITH_TYPE(T) \
device.memcpy(dst_ptr, static_cast<T const*>(src_ptr), size);
switch (src_tensor.dtype()) {
case DT_COMPLEX128:
COPY_WITH_TYPE(cl::sycl::cl_ulong2);
break;
case DT_DOUBLE:
case DT_COMPLEX64:
case DT_INT64:
COPY_WITH_TYPE(cl::sycl::cl_ulong);
break;
case DT_FLOAT:
case DT_INT32:
case DT_QINT32:
COPY_WITH_TYPE(cl::sycl::cl_uint);
break;
case DT_INT16:
case DT_UINT16:
case DT_BFLOAT16:
case DT_QINT16:
case DT_QUINT16:
case DT_HALF:
COPY_WITH_TYPE(cl::sycl::cl_ushort);
break;
case DT_BOOL:
COPY_WITH_TYPE(bool);
break;
case DT_UINT8:
case DT_INT8:
case DT_QINT8:
case DT_QUINT8:
COPY_WITH_TYPE(cl::sycl::cl_uchar);
break;
default:
LOG(FATAL) << "Unknown data type " << src_tensor.dtype();
break;
} }
#undef COPY_WITH_TYPE
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
} }
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_

View File

@ -86,7 +86,7 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
// Determine if the tensor is on device (GPU) or host (CPU). // Determine if the tensor is on device (GPU) or host (CPU).
// The second part of the check is necessary because even an OpKernel on // The second part of the check is necessary because even an OpKernel on
// may have output tensors allocated on CPU. // may have output tensors allocated on CPU.
if ((device->name().find("gpu:") != string::npos || device->name().find("SYCL:") != string::npos) && if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) &&
!ctx->output_alloc_attr(output_slot).on_host()) { !ctx->output_alloc_attr(output_slot).on_host()) {
// GPU tensors: Copy it to host (CPU). // GPU tensors: Copy it to host (CPU).
DeviceContext* device_ctxt = ctx->op_device_context(); DeviceContext* device_ctxt = ctx->op_device_context();

View File

@ -47,7 +47,7 @@ class SessionDebugMinusAXTest : public ::testing::Test {
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
#if GOOGLE_CUDA #if GOOGLE_CUDA
const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL) #elif defined(TENSORFLOW_USE_SYCL)
const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else #else
@ -505,7 +505,7 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
#if GOOGLE_CUDA #if GOOGLE_CUDA
const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL) #elif defined(TENSORFLOW_USE_SYCL)
const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else #else
@ -607,7 +607,7 @@ class SessionDebugVariableTest : public ::testing::Test {
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
#if GOOGLE_CUDA #if GOOGLE_CUDA
const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif defined(TENSORFLOW_USE_SYCL) #elif defined(TENSORFLOW_USE_SYCL)
const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#else #else
@ -879,7 +879,7 @@ class SessionDebugGPUSwitchTest : public ::testing::Test {
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
#ifdef GOOGLE_CUDA #ifdef GOOGLE_CUDA
const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
#elif TENSORFLOW_USE_SYCL #elif TENSORFLOW_USE_SYCL
const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
#endif #endif

View File

@ -53,14 +53,14 @@ class DebugIOUtilsTest : public ::testing::Test {
}; };
TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/gpu:2", DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2",
"hidden_1/MatMul", 0, "DebugIdentity"); "hidden_1/MatMul", 0, "DebugIdentity");
EXPECT_EQ("/job:worker/replica:1/task:0/gpu:2", debug_node_key.device_name); EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name);
EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name); EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name);
EXPECT_EQ(0, debug_node_key.output_slot); EXPECT_EQ(0, debug_node_key.output_slot);
EXPECT_EQ("DebugIdentity", debug_node_key.debug_op); EXPECT_EQ("DebugIdentity", debug_node_key.debug_op);
EXPECT_EQ("hidden_1/MatMul:0:DebugIdentity", debug_node_key.debug_node_name); EXPECT_EQ("hidden_1/MatMul:0:DebugIdentity", debug_node_key.debug_node_name);
EXPECT_EQ("_tfdbg_device_,job_worker,replica_1,task_0,gpu_2", EXPECT_EQ("_tfdbg_device_,job_worker,replica_1,task_0,device_GPU_2",
debug_node_key.device_path); debug_node_key.device_path);
} }

View File

@ -140,7 +140,7 @@ Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
} }
#define ALICE "/job:j/replica:0/task:0/cpu:0" #define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/gpu:0" #define BOB "/job:j/replica:0/task:0/device:GPU:0"
TEST_F(ExecutorTest, SimpleAdd) { TEST_F(ExecutorTest, SimpleAdd) {
// c = a + b // c = a + b

View File

@ -31,9 +31,9 @@ TEST(GrpcChannelTest, IsSameAddressSpace) {
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0", EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
"/job:mnist/replica:10/task:10/cpu:1")); "/job:mnist/replica:10/task:10/cpu:1"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0", EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
"/job:mnist/replica:10/task:10/gpu:2")); "/job:mnist/replica:10/task:10/device:GPU:2"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10", EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10",
"/job:mnist/replica:10/task:10/gpu:2")); "/job:mnist/replica:10/task:10/device:GPU:2"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1", EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1",
"/job:mnist/replica:10/task:10")); "/job:mnist/replica:10/task:10"));

View File

@ -129,28 +129,14 @@ class GrpcRemoteWorker : public WorkerInterface {
TensorResponse* response, StatusCallback done) override { TensorResponse* response, StatusCallback done) override {
VLOG(1) << "RecvTensorAsync req: " << request->DebugString(); VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
int64 start_usec = Env::Default()->NowMicros(); int64 start_usec = Env::Default()->NowMicros();
// Don't propagate dma_ok over gRPC.
RecvTensorRequest* req_copy = nullptr;
if (request->dma_ok()) {
req_copy = new RecvTensorRequest;
*req_copy = *request;
req_copy->set_dma_ok(false);
}
// Type-specialized logging for this method. // Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done; StatusCallback wrapper_done;
const StatusCallback* cb_to_use; const StatusCallback* cb_to_use;
if (!logging_active && req_copy == nullptr) { if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly cb_to_use = &done; // No additional work to do, so just use done directly
} else if (!logging_active) {
wrapper_done = [req_copy, done](Status s) {
delete req_copy;
done(s);
};
cb_to_use = &wrapper_done;
} else { } else {
wrapper_done = [this, request, req_copy, response, done, wrapper_done = [this, request, response, done, start_usec](Status s) {
start_usec](Status s) {
if (logger_->LoggingActive()) { if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros(); int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id(); int64 step_id = request->step_id();
@ -189,14 +175,12 @@ class GrpcRemoteWorker : public WorkerInterface {
} }
VLOG(2) << "done callback, req: " << request->DebugString() VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->metadata().DebugString(); << " response " << response->metadata().DebugString();
delete req_copy;
done(s); done(s);
}; };
cb_to_use = &wrapper_done; cb_to_use = &wrapper_done;
} }
IssueRequest(req_copy ? req_copy : request, response, recvtensor_, IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
*cb_to_use, call_opts);
} }
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,

View File

@ -105,7 +105,8 @@ GrpcServer::~GrpcServer() {
Status GrpcServer::Init( Status GrpcServer::Init(
ServiceInitFunction service_func, ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) { const RendezvousMgrCreationFunction& rendezvous_mgr_func,
const WorkerCreationFunction& worker_func) {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_EQ(state_, NEW); CHECK_EQ(state_, NEW);
master_env_.env = env_; master_env_.env = env_;
@ -183,7 +184,8 @@ Status GrpcServer::Init(
master_impl_ = CreateMaster(&master_env_); master_impl_ = CreateMaster(&master_env_);
master_service_ = NewGrpcMasterService( master_service_ = NewGrpcMasterService(
master_impl_.get(), config.operation_timeout_in_ms(), &builder); master_impl_.get(), config.operation_timeout_in_ms(), &builder);
worker_impl_ = NewGrpcWorker(&worker_env_); worker_impl_ =
worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ = worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release(); NewGrpcWorkerService(worker_impl_.get(), &builder).release();
// extra service: // extra service:
@ -239,7 +241,13 @@ Status GrpcServer::Init(
return Status::OK(); return Status::OK();
} }
Status GrpcServer::Init() { return Init(nullptr, nullptr); } Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
return Init(service_func, rendezvous_mgr_func, nullptr);
}
Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) { GrpcChannelSpec* channel_spec) {

View File

@ -45,6 +45,10 @@ typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
ServiceInitFunction; ServiceInitFunction;
// function that creates a grpc based worker implementation.
typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)>
WorkerCreationFunction;
class GrpcServer : public ServerInterface { class GrpcServer : public ServerInterface {
protected: protected:
GrpcServer(const ServerDef& server_def, Env* env); GrpcServer(const ServerDef& server_def, Env* env);
@ -64,6 +68,10 @@ class GrpcServer : public ServerInterface {
const string target() const override; const string target() const override;
protected: protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
const WorkerCreationFunction& worker_func);
Status Init(ServiceInitFunction service_func, Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func); const RendezvousMgrCreationFunction& rendezvous_mgr_func);

View File

@ -347,32 +347,25 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA #if GOOGLE_CUDA
const DeviceContext* send_dev_context = send_args.device_context; const DeviceContext* send_dev_context = send_args.device_context;
RecvTensorResponse* tmp = new RecvTensorResponse; AllocatorAttributes alloc_attrs;
tmp->set_is_dead(is_dead); alloc_attrs.set_gpu_compatible(true);
alloc_attrs.set_on_host(true);
Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
CHECK(send_dev_context) CHECK(send_dev_context)
<< "send dev name: " << src_dev->name() << "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info(); << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on a GPU. Uses GPUUtil to fill the response proto. // "val" is on a GPU. Uses GPUUtil to fill the copy on host.
StatusCallback response_ready = [response, done, StatusCallback copy_ready = [response, done, copy,
tmp](const Status& s) { is_dead](const Status& s) {
// The value is now ready to be returned on the wire. // The value is now ready to be returned on the wire.
tmp->set_send_start_micros(Env::Default()->NowMicros()); grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
grpc::EncodeRecvTensorResponseToByteBuffer(*tmp, response);
done(s); done(s);
delete tmp; delete copy;
}; };
// TODO (jeff,sanjay,mrry): Avoid copy on GPU path by GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
// modifying GPUUtil::SetProtoFromGPU to accept a copy_ready);
// ::grpc::ByteBuffer to serialize to, rather than
// encoding into a protocol buffer and then
// serializing that (i.e. figure out how to use
// EncodeTensorToByteBuffer on this path rather than
// EncodeRecvTensorResponseToByteBuffer)
GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
tmp->mutable_tensor(), is_dead,
response_ready);
#else #else
done(errors::Internal("No GPU device in process")); done(errors::Internal("No GPU device in process"));
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA

View File

@ -34,8 +34,10 @@ class GrpcWorker : public Worker {
GrpcWorker(WorkerEnv* env); GrpcWorker(WorkerEnv* env);
// Specialized version of RecvTensor for gRPC, which avoids a copy. // Specialized version of RecvTensor for gRPC, which avoids a copy.
void GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, virtual void GrpcRecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response, StatusCallback done); const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done);
WorkerEnv* env(); WorkerEnv* env();
}; };

Some files were not shown because too many files have changed in this diff Show More