mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge changes from github.
PiperOrigin-RevId: 180746153
This commit is contained in:
parent
2eef71c3f9
commit
71896cc7e5
2
LICENSE
2
LICENSE
|
|
@ -1,4 +1,4 @@
|
||||||
Copyright 2017 The TensorFlow Authors. All rights reserved.
|
Copyright 2018 The TensorFlow Authors. All rights reserved.
|
||||||
|
|
||||||
Apache License
|
Apache License
|
||||||
Version 2.0, January 2004
|
Version 2.0, January 2004
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,8 @@ packages on Linux, Mac, and Windows.
|
||||||
|
|
||||||
|
|
||||||
**Individual whl files**
|
**Individual whl files**
|
||||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/))
|
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/))
|
||||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/))
|
||||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
||||||
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
|
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
|
||||||
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
|
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
|
||||||
|
|
|
||||||
|
|
@ -763,24 +763,6 @@ Status LgammaGrad(const Scope& scope, const Operation& op,
|
||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
|
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);
|
||||||
|
|
||||||
Status SelectGrad(const Scope& scope, const Operation& op,
|
|
||||||
const std::vector<Output>& grad_inputs,
|
|
||||||
std::vector<Output>* grad_outputs) {
|
|
||||||
auto comparator = op.input(0);
|
|
||||||
auto x = op.input(1);
|
|
||||||
auto zeros = ZerosLike(scope, x);
|
|
||||||
auto grad = grad_inputs[0];
|
|
||||||
|
|
||||||
auto gx_1 = Where3(scope, comparator, grad, zeros);
|
|
||||||
auto gx_2 = Where3(scope, comparator, zeros, grad);
|
|
||||||
|
|
||||||
grad_outputs->push_back(NoGradient());
|
|
||||||
grad_outputs->push_back(gx_1);
|
|
||||||
grad_outputs->push_back(gx_2);
|
|
||||||
return scope.status();
|
|
||||||
}
|
|
||||||
REGISTER_GRADIENT_OP("Select", SelectGrad);
|
|
||||||
|
|
||||||
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
|
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
|
||||||
const std::vector<Output>& grad_inputs,
|
const std::vector<Output>& grad_inputs,
|
||||||
std::vector<Output>* grad_outputs) {
|
std::vector<Output>* grad_outputs) {
|
||||||
|
|
|
||||||
|
|
@ -904,13 +904,5 @@ TEST_F(NaryGradTest, Prod) {
|
||||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NaryGradTest, Select) {
|
|
||||||
TensorShape shape({3, 4});
|
|
||||||
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
|
||||||
auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
|
||||||
auto y = Where3(scope_, Greater(scope_, x1, x2), x1, x2);
|
|
||||||
RunTest({x1, x2}, {shape, shape}, {y}, {shape});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and
|
// Create the body subgraph defined by `body`. `outputs` must be non-null and
|
||||||
// empty.
|
// empty.
|
||||||
Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
|
Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
|
||||||
const std::vector<Output>& inputs,
|
const std::vector<Output>& inputs,
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ package(
|
||||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
|
|
@ -167,6 +168,7 @@ tf_kernel_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "index_ops_kernel_argmax_float_1d",
|
name = "index_ops_kernel_argmax_float_1d",
|
||||||
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
|
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
|
||||||
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||||
|
|
@ -179,6 +181,7 @@ cc_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "index_ops_kernel_argmax_float_2d",
|
name = "index_ops_kernel_argmax_float_2d",
|
||||||
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
|
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
|
||||||
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
|
||||||
#ifdef __AVX__
|
#ifdef TF_XLA_HAS_AVX
|
||||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
|
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
|
||||||
xla::cpu::runtime::V8F32AVX x) {
|
xla::cpu::runtime::V8F32AVX x) {
|
||||||
return Eigen::internal::pexp(x);
|
return Eigen::internal::pexp(x);
|
||||||
|
|
@ -29,7 +29,7 @@ xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
|
||||||
xla::cpu::runtime::V8F32AVX x) {
|
xla::cpu::runtime::V8F32AVX x) {
|
||||||
return Eigen::internal::plog(x);
|
return Eigen::internal::plog(x);
|
||||||
}
|
}
|
||||||
#endif // __AVX__
|
#endif // TF_XLA_HAS_AVX
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
#if defined(__AVX__)
|
||||||
|
#include <immintrin.h>
|
||||||
|
#define TF_XLA_HAS_AVX
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
|
@ -31,14 +36,16 @@ namespace runtime {
|
||||||
extern const char *const kExpV8F32AVXSymbolName;
|
extern const char *const kExpV8F32AVXSymbolName;
|
||||||
extern const char *const kLogV8F32AVXSymbolName;
|
extern const char *const kLogV8F32AVXSymbolName;
|
||||||
|
|
||||||
typedef float V8F32AVX __attribute__((__vector_size__(32)));
|
#ifdef TF_XLA_HAS_AVX
|
||||||
|
typedef __m256 V8F32AVX;
|
||||||
|
#endif
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
#ifdef __AVX__
|
#ifdef TF_XLA_HAS_AVX
|
||||||
// The following functions are vectorized versions of a selection of libm
|
// The following functions are vectorized versions of a selection of libm
|
||||||
// library functions.
|
// library functions.
|
||||||
// References to these functions are created by the LLVM vectorizer.
|
// References to these functions are created by the LLVM vectorizer.
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
|
||||||
#ifdef __ARM_NEON__
|
#ifdef TF_XLA_HAS_NEON
|
||||||
|
|
||||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||||
xla::cpu::runtime::V4F32NEON x) {
|
xla::cpu::runtime::V4F32NEON x) {
|
||||||
|
|
@ -32,7 +32,7 @@ xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||||
return Eigen::internal::plog(p);
|
return Eigen::internal::plog(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // __ARM_NEON__
|
#endif // TF_XLA_HAS_NEON
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
|
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
|
||||||
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
|
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
|
#define TF_XLA_HAS_NEON
|
||||||
#endif // __ARM_NEON__
|
#endif // __ARM_NEON__
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
@ -36,12 +37,9 @@ namespace runtime {
|
||||||
extern const char *const kExpV4F32NEONSymbolName;
|
extern const char *const kExpV4F32NEONSymbolName;
|
||||||
extern const char *const kLogV4F32NEONSymbolName;
|
extern const char *const kLogV4F32NEONSymbolName;
|
||||||
|
|
||||||
#ifdef __ARM_NEON__
|
#ifdef TF_XLA_HAS_NEON
|
||||||
typedef float32x4_t V4F32NEON;
|
typedef float32x4_t V4F32NEON;
|
||||||
#else
|
#endif // TF_XLA_HAS_NEON
|
||||||
// On non-ARM platforms ensure the declaration is present
|
|
||||||
struct V4F32NEON;
|
|
||||||
#endif // __ARM_NEON__
|
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
|
@ -49,7 +47,7 @@ struct V4F32NEON;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
#ifdef __ARM_NEON__
|
#ifdef TF_XLA_HAS_NEON
|
||||||
// The following functions are vectorized versions of a selection of libm
|
// The following functions are vectorized versions of a selection of libm
|
||||||
// library functions.
|
// library functions.
|
||||||
// References to these functions are created by the LLVM vectorizer.
|
// References to these functions are created by the LLVM vectorizer.
|
||||||
|
|
@ -58,7 +56,7 @@ xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||||
|
|
||||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||||
xla::cpu::runtime::V4F32NEON x);
|
xla::cpu::runtime::V4F32NEON x);
|
||||||
#endif // __ARM_NEON__
|
#endif // TF_XLA_HAS_NEON
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
|
||||||
#ifdef __SSE4_1__
|
#ifdef TF_XLA_HAS_SSE4_1
|
||||||
|
|
||||||
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
|
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE(
|
||||||
xla::cpu::runtime::V4F32SSE x) {
|
xla::cpu::runtime::V4F32SSE x) {
|
||||||
|
|
@ -33,7 +33,7 @@ xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
|
||||||
return Eigen::internal::plog(p);
|
return Eigen::internal::plog(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // __SSE4_1__
|
#endif // TF_XLA_HAS_SSE4_1
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,13 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1
|
||||||
|
// when __AVX__ is defined, we should do the same.
|
||||||
|
#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
|
||||||
|
#include <smmintrin.h>
|
||||||
|
#define TF_XLA_HAS_SSE4_1
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
|
@ -31,7 +38,9 @@ namespace runtime {
|
||||||
extern const char *const kExpV4F32SSESymbolName;
|
extern const char *const kExpV4F32SSESymbolName;
|
||||||
extern const char *const kLogV4F32SSESymbolName;
|
extern const char *const kLogV4F32SSESymbolName;
|
||||||
|
|
||||||
typedef float V4F32SSE __attribute__((__vector_size__(16)));
|
#ifdef TF_XLA_HAS_SSE4_1
|
||||||
|
typedef __m128 V4F32SSE;
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
|
|
@ -39,7 +48,7 @@ typedef float V4F32SSE __attribute__((__vector_size__(16)));
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
#ifdef __SSE4_1__
|
#ifdef TF_XLA_HAS_SSE4_1
|
||||||
// The following functions are vectorized versions of a selection of libm
|
// The following functions are vectorized versions of a selection of libm
|
||||||
// library functions.
|
// library functions.
|
||||||
// References to these functions are created by the LLVM vectorizer.
|
// References to these functions are created by the LLVM vectorizer.
|
||||||
|
|
|
||||||
|
|
@ -64,14 +64,14 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||||
&ir_builder),
|
&ir_builder),
|
||||||
llvm::ConstantFP::get(vector_type, 9.0), &ir_builder);
|
llvm::ConstantFP::get(vector_type, 9.0), &ir_builder);
|
||||||
|
|
||||||
std::array<float, 7> numerator_coeffs(
|
std::array<float, 7> numerator_coeffs{
|
||||||
{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
||||||
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
|
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
|
||||||
4.89352455891786e-03f});
|
4.89352455891786e-03f};
|
||||||
|
|
||||||
std::array<float, 4> denominator_coeffs(
|
std::array<float, 4> denominator_coeffs{
|
||||||
{1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
|
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
|
||||||
4.89352518554385e-03f});
|
4.89352518554385e-03f};
|
||||||
|
|
||||||
llvm::Value* input_squared =
|
llvm::Value* input_squared =
|
||||||
ir_builder.CreateFMul(input_clamped, input_clamped);
|
ir_builder.CreateFMul(input_clamped, input_clamped);
|
||||||
|
|
|
||||||
|
|
@ -103,17 +103,17 @@ llvm::StringRef GetHostCpuName() {
|
||||||
|
|
||||||
CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
||||||
CompilerFunctor::VectorIntrinsics intrinsics;
|
CompilerFunctor::VectorIntrinsics intrinsics;
|
||||||
#ifdef __SSE4_1__
|
#ifdef TF_XLA_HAS_SSE4_1
|
||||||
intrinsics.sse_intrinsics = true;
|
intrinsics.sse_intrinsics = true;
|
||||||
#else
|
#else
|
||||||
intrinsics.sse_intrinsics = false;
|
intrinsics.sse_intrinsics = false;
|
||||||
#endif
|
#endif
|
||||||
#ifdef __AVX__
|
#ifdef TF_XLA_HAS_AVX
|
||||||
intrinsics.avx_intrinsics = true;
|
intrinsics.avx_intrinsics = true;
|
||||||
#else
|
#else
|
||||||
intrinsics.avx_intrinsics = false;
|
intrinsics.avx_intrinsics = false;
|
||||||
#endif
|
#endif
|
||||||
#ifdef __ARM_NEON__
|
#ifdef TF_XLA_HAS_NEON
|
||||||
intrinsics.neon_intrinsics = true;
|
intrinsics.neon_intrinsics = true;
|
||||||
#else
|
#else
|
||||||
intrinsics.neon_intrinsics = false;
|
intrinsics.neon_intrinsics = false;
|
||||||
|
|
@ -215,15 +215,15 @@ bool RegisterKnownJITSymbols() {
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
|
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
||||||
#ifdef __ARM_NEON__
|
#ifdef TF_XLA_HAS_NEON
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON);
|
REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
|
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
|
||||||
#endif
|
#endif
|
||||||
#ifdef __SSE4_1__
|
#ifdef TF_XLA_HAS_SSE4_1
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE);
|
REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
|
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
|
||||||
#endif
|
#endif
|
||||||
#ifdef __AVX__
|
#ifdef TF_XLA_HAS_AVX
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX);
|
REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
|
REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -407,7 +407,9 @@ endif()
|
||||||
|
|
||||||
# Let's get to work!
|
# Let's get to work!
|
||||||
include(tf_core_framework.cmake)
|
include(tf_core_framework.cmake)
|
||||||
|
if (tensorflow_ENABLE_GPU)
|
||||||
include(tf_stream_executor.cmake)
|
include(tf_stream_executor.cmake)
|
||||||
|
endif()
|
||||||
|
|
||||||
include(tf_core_cpu.cmake)
|
include(tf_core_cpu.cmake)
|
||||||
include(tf_core_ops.cmake)
|
include(tf_core_ops.cmake)
|
||||||
|
|
|
||||||
26
tensorflow/contrib/cmake/make.sh
Executable file
26
tensorflow/contrib/cmake/make.sh
Executable file
|
|
@ -0,0 +1,26 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
(
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
mkdir -p _build
|
||||||
|
|
||||||
|
(
|
||||||
|
cd _build
|
||||||
|
rm -rf -- *
|
||||||
|
cmake ..
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
@ -86,9 +86,7 @@ tensorflow/python/ops/distributions
|
||||||
tensorflow/python/ops/linalg
|
tensorflow/python/ops/linalg
|
||||||
tensorflow/python/ops/losses
|
tensorflow/python/ops/losses
|
||||||
tensorflow/python/platform
|
tensorflow/python/platform
|
||||||
tensorflow/python/platform/default
|
tensorflow/python/profiler
|
||||||
tensorflow/python/platform/summary
|
|
||||||
tensorflow/python/profiler/
|
|
||||||
tensorflow/python/profiler/internal
|
tensorflow/python/profiler/internal
|
||||||
tensorflow/python/saved_model
|
tensorflow/python/saved_model
|
||||||
tensorflow/python/summary
|
tensorflow/python/summary
|
||||||
|
|
@ -115,8 +113,6 @@ tensorflow/contrib/batching/kernels
|
||||||
tensorflow/contrib/batching/python
|
tensorflow/contrib/batching/python
|
||||||
tensorflow/contrib/batching/python/ops
|
tensorflow/contrib/batching/python/ops
|
||||||
tensorflow/contrib/bayesflow
|
tensorflow/contrib/bayesflow
|
||||||
tensorflow/contrib/bayesflow/examples
|
|
||||||
tensorflow/contrib/bayesflow/examples/reinforce_simple
|
|
||||||
tensorflow/contrib/bayesflow/python
|
tensorflow/contrib/bayesflow/python
|
||||||
tensorflow/contrib/bayesflow/python/ops
|
tensorflow/contrib/bayesflow/python/ops
|
||||||
tensorflow/contrib/boosted_trees
|
tensorflow/contrib/boosted_trees
|
||||||
|
|
@ -212,16 +208,6 @@ tensorflow/contrib/input_pipeline/python/ops
|
||||||
tensorflow/contrib/integrate
|
tensorflow/contrib/integrate
|
||||||
tensorflow/contrib/integrate/python
|
tensorflow/contrib/integrate/python
|
||||||
tensorflow/contrib/integrate/python/ops
|
tensorflow/contrib/integrate/python/ops
|
||||||
tensorflow/contrib/ios_examples
|
|
||||||
tensorflow/contrib/ios_examples/benchmark
|
|
||||||
tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj
|
|
||||||
tensorflow/contrib/ios_examples/benchmark/data
|
|
||||||
tensorflow/contrib/ios_examples/camera
|
|
||||||
tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj
|
|
||||||
tensorflow/contrib/ios_examples/camera/en.lproj
|
|
||||||
tensorflow/contrib/ios_examples/simple
|
|
||||||
tensorflow/contrib/ios_examples/simple/data
|
|
||||||
tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj
|
|
||||||
tensorflow/contrib/keras
|
tensorflow/contrib/keras
|
||||||
tensorflow/contrib/keras/api
|
tensorflow/contrib/keras/api
|
||||||
tensorflow/contrib/keras/api/keras
|
tensorflow/contrib/keras/api/keras
|
||||||
|
|
@ -276,9 +262,6 @@ tensorflow/contrib/layers/python/ops
|
||||||
tensorflow/contrib/learn
|
tensorflow/contrib/learn
|
||||||
tensorflow/contrib/learn/python
|
tensorflow/contrib/learn/python
|
||||||
tensorflow/contrib/learn/python/learn
|
tensorflow/contrib/learn/python/learn
|
||||||
tensorflow/contrib/learn/python/learn/dataframe
|
|
||||||
tensorflow/contrib/learn/python/learn/dataframe/queues
|
|
||||||
tensorflow/contrib/learn/python/learn/dataframe/transforms
|
|
||||||
tensorflow/contrib/learn/python/learn/datasets
|
tensorflow/contrib/learn/python/learn/datasets
|
||||||
tensorflow/contrib/learn/python/learn/datasets/data
|
tensorflow/contrib/learn/python/learn/datasets/data
|
||||||
tensorflow/contrib/learn/python/learn/estimators
|
tensorflow/contrib/learn/python/learn/estimators
|
||||||
|
|
@ -301,6 +284,9 @@ tensorflow/contrib/linear_optimizer/kernels
|
||||||
tensorflow/contrib/linear_optimizer/kernels/g3doc
|
tensorflow/contrib/linear_optimizer/kernels/g3doc
|
||||||
tensorflow/contrib/linear_optimizer/python
|
tensorflow/contrib/linear_optimizer/python
|
||||||
tensorflow/contrib/linear_optimizer/python/ops
|
tensorflow/contrib/linear_optimizer/python/ops
|
||||||
|
# TODO(drpngx): Fix failing imports
|
||||||
|
# tensorflow/contrib/lite/python
|
||||||
|
# tensorflow/contrib/lite/toco/python
|
||||||
tensorflow/contrib/lookup
|
tensorflow/contrib/lookup
|
||||||
tensorflow/contrib/losses
|
tensorflow/contrib/losses
|
||||||
tensorflow/contrib/losses/python
|
tensorflow/contrib/losses/python
|
||||||
|
|
@ -314,7 +300,6 @@ tensorflow/contrib/memory_stats/python
|
||||||
tensorflow/contrib/memory_stats/python/ops
|
tensorflow/contrib/memory_stats/python/ops
|
||||||
tensorflow/contrib/meta_graph_transform
|
tensorflow/contrib/meta_graph_transform
|
||||||
tensorflow/contrib/metrics
|
tensorflow/contrib/metrics
|
||||||
tensorflow/contrib/metrics/ops
|
|
||||||
tensorflow/contrib/metrics/python
|
tensorflow/contrib/metrics/python
|
||||||
tensorflow/contrib/metrics/python/metrics
|
tensorflow/contrib/metrics/python/metrics
|
||||||
tensorflow/contrib/metrics/python/ops
|
tensorflow/contrib/metrics/python/ops
|
||||||
|
|
@ -346,7 +331,6 @@ tensorflow/contrib/pi_examples/label_image
|
||||||
tensorflow/contrib/pi_examples/label_image/data
|
tensorflow/contrib/pi_examples/label_image/data
|
||||||
tensorflow/contrib/periodic_resample
|
tensorflow/contrib/periodic_resample
|
||||||
tensorflow/contrib/periodic_resample/python
|
tensorflow/contrib/periodic_resample/python
|
||||||
tensorflow/contrib/periodic_resample/python/kernels
|
|
||||||
tensorflow/contrib/periodic_resample/python/ops
|
tensorflow/contrib/periodic_resample/python/ops
|
||||||
tensorflow/contrib/predictor
|
tensorflow/contrib/predictor
|
||||||
tensorflow/contrib/quantization
|
tensorflow/contrib/quantization
|
||||||
|
|
@ -411,13 +395,9 @@ tensorflow/contrib/tensorboard/plugins
|
||||||
tensorflow/contrib/tensorboard/plugins/projector
|
tensorflow/contrib/tensorboard/plugins/projector
|
||||||
tensorflow/contrib/tensor_forest
|
tensorflow/contrib/tensor_forest
|
||||||
tensorflow/contrib/tensor_forest/client
|
tensorflow/contrib/tensor_forest/client
|
||||||
tensorflow/contrib/tensor_forest/core
|
|
||||||
tensorflow/contrib/tensor_forest/core/ops
|
|
||||||
tensorflow/contrib/tensor_forest/data
|
|
||||||
tensorflow/contrib/tensor_forest/hybrid
|
tensorflow/contrib/tensor_forest/hybrid
|
||||||
tensorflow/contrib/tensor_forest/hybrid/core
|
tensorflow/contrib/tensor_forest/hybrid/core
|
||||||
tensorflow/contrib/tensor_forest/hybrid/core/ops
|
tensorflow/contrib/tensor_forest/hybrid/core/ops
|
||||||
tensorflow/contrib/tensor_forest/hybrid/ops
|
|
||||||
tensorflow/contrib/tensor_forest/hybrid/python
|
tensorflow/contrib/tensor_forest/hybrid/python
|
||||||
tensorflow/contrib/tensor_forest/hybrid/python/layers
|
tensorflow/contrib/tensor_forest/hybrid/python/layers
|
||||||
tensorflow/contrib/tensor_forest/hybrid/python/models
|
tensorflow/contrib/tensor_forest/hybrid/python/models
|
||||||
|
|
|
||||||
|
|
@ -126,10 +126,15 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}")
|
||||||
STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}")
|
STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}")
|
||||||
|
|
||||||
foreach(python_proto ${python_protos})
|
foreach(python_proto ${python_protos})
|
||||||
|
if(NOT python_proto MATCHES "\#")
|
||||||
|
if(NOT EXISTS "${tensorflow_source_dir}/${python_proto}")
|
||||||
|
message(SEND_ERROR "Python proto directory not found: ${python_proto}")
|
||||||
|
endif()
|
||||||
file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir}
|
file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir}
|
||||||
"${tensorflow_source_dir}/${python_proto}/*.proto"
|
"${tensorflow_source_dir}/${python_proto}/*.proto"
|
||||||
)
|
)
|
||||||
list(APPEND tf_python_protos_srcs ${tf_python_protos_src})
|
list(APPEND tf_python_protos_srcs ${tf_python_protos_src})
|
||||||
|
endif()
|
||||||
endforeach(python_proto)
|
endforeach(python_proto)
|
||||||
|
|
||||||
RELATIVE_PROTOBUF_GENERATE_PYTHON(
|
RELATIVE_PROTOBUF_GENERATE_PYTHON(
|
||||||
|
|
@ -142,10 +147,15 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}")
|
||||||
STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}")
|
STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}")
|
||||||
|
|
||||||
foreach(python_proto_cc ${python_protos_cc})
|
foreach(python_proto_cc ${python_protos_cc})
|
||||||
|
if(NOT python_proto_cc MATCHES "\#")
|
||||||
|
if(NOT EXISTS "${tensorflow_source_dir}/${python_proto_cc}")
|
||||||
|
message(SEND_ERROR "Python proto CC directory not found: ${python_proto_cc}")
|
||||||
|
endif()
|
||||||
file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir}
|
file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir}
|
||||||
"${tensorflow_source_dir}/${python_proto_cc}/*.proto"
|
"${tensorflow_source_dir}/${python_proto_cc}/*.proto"
|
||||||
)
|
)
|
||||||
list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src})
|
list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src})
|
||||||
|
endif()
|
||||||
endforeach(python_proto_cc)
|
endforeach(python_proto_cc)
|
||||||
|
|
||||||
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
|
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
|
||||||
|
|
@ -199,7 +209,12 @@ STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}")
|
||||||
STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}")
|
STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}")
|
||||||
|
|
||||||
foreach(python_module ${python_modules})
|
foreach(python_module ${python_modules})
|
||||||
|
if(NOT python_module MATCHES "\#")
|
||||||
|
if(NOT EXISTS "${tensorflow_source_dir}/${python_module}")
|
||||||
|
message(SEND_ERROR "Python module not found: ${python_module}")
|
||||||
|
endif()
|
||||||
add_python_module(${python_module})
|
add_python_module(${python_module})
|
||||||
|
endif()
|
||||||
endforeach(python_module)
|
endforeach(python_module)
|
||||||
|
|
||||||
add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
|
add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD
|
||||||
|
|
|
||||||
|
|
@ -2674,15 +2674,24 @@ def spatial_softmax(features,
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
pos_x = array_ops.reshape(pos_x, [height * width])
|
pos_x = array_ops.reshape(pos_x, [height * width])
|
||||||
pos_y = array_ops.reshape(pos_y, [height * width])
|
pos_y = array_ops.reshape(pos_y, [height * width])
|
||||||
|
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature_collections = utils.get_variable_collections(
|
temp_initializer = init_ops.ones_initializer()
|
||||||
|
else:
|
||||||
|
temp_initializer = init_ops.constant_initializer(temperature)
|
||||||
|
|
||||||
|
if not trainable:
|
||||||
|
temp_collections = None
|
||||||
|
else:
|
||||||
|
temp_collections = utils.get_variable_collections(
|
||||||
variables_collections, 'temperature')
|
variables_collections, 'temperature')
|
||||||
|
|
||||||
temperature = variables.model_variable(
|
temperature = variables.model_variable(
|
||||||
'temperature',
|
'temperature',
|
||||||
shape=(),
|
shape=(),
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
initializer=init_ops.ones_initializer(),
|
initializer=temp_initializer,
|
||||||
collections=temperature_collections,
|
collections=temp_collections,
|
||||||
trainable=trainable)
|
trainable=trainable)
|
||||||
if data_format == 'NCHW':
|
if data_format == 'NCHW':
|
||||||
features = array_ops.reshape(features, [-1, height * width])
|
features = array_ops.reshape(features, [-1, height * width])
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,6 @@ limitations under the License.
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
|
||||||
template <typename T>
|
|
||||||
bool ConvertHelper(const string& s, T* value);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename Tlabel>
|
template <typename T, typename Tlabel>
|
||||||
class DecodeLibsvmOp : public OpKernel {
|
class DecodeLibsvmOp : public OpKernel {
|
||||||
|
|
@ -57,7 +53,7 @@ class DecodeLibsvmOp : public OpKernel {
|
||||||
"]: \"", input_flat(i), "\""));
|
"]: \"", input_flat(i), "\""));
|
||||||
Tlabel label_value;
|
Tlabel label_value;
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, ConvertHelper<Tlabel>(entries[0], &label_value),
|
ctx, strings::SafeStringToNumeric<Tlabel>(entries[0], &label_value),
|
||||||
errors::InvalidArgument("Label format incorrect: ", entries[0]));
|
errors::InvalidArgument("Label format incorrect: ", entries[0]));
|
||||||
label(i) = label_value;
|
label(i) = label_value;
|
||||||
for (int j = 1; j < entries.size(); j++) {
|
for (int j = 1; j < entries.size(); j++) {
|
||||||
|
|
@ -74,7 +70,7 @@ class DecodeLibsvmOp : public OpKernel {
|
||||||
"Feature index should be >= 0, got ", feature_index));
|
"Feature index should be >= 0, got ", feature_index));
|
||||||
T feature_value;
|
T feature_value;
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, ConvertHelper<T>(pair[1], &feature_value),
|
ctx, strings::SafeStringToNumeric<T>(pair[1], &feature_value),
|
||||||
errors::InvalidArgument("Feature format incorrect: ", entries[j]));
|
errors::InvalidArgument("Feature format incorrect: ", entries[j]));
|
||||||
out_values.emplace_back(feature_value);
|
out_values.emplace_back(feature_value);
|
||||||
out_indices.emplace_back(std::pair<int64, int64>(i, feature_index));
|
out_indices.emplace_back(std::pair<int64, int64>(i, feature_index));
|
||||||
|
|
@ -128,25 +124,6 @@ class DecodeLibsvmOp : public OpKernel {
|
||||||
int64 num_features_;
|
int64 num_features_;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
|
||||||
template <>
|
|
||||||
bool ConvertHelper<float>(const string& s, float* value) {
|
|
||||||
return strings::safe_strtof(s.c_str(), value);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
bool ConvertHelper<double>(const string& s, double* value) {
|
|
||||||
return strings::safe_strtod(s.c_str(), value);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
bool ConvertHelper<int32>(const string& s, int32* value) {
|
|
||||||
return strings::safe_strto32(s.c_str(), value);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
bool ConvertHelper<int64>(const string& s, int64* value) {
|
|
||||||
return strings::safe_strto64(s.c_str(), value);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
#define REGISTER_KERNEL(type) \
|
#define REGISTER_KERNEL(type) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
|
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
|
||||||
.Device(DEVICE_CPU) \
|
.Device(DEVICE_CPU) \
|
||||||
|
|
|
||||||
|
|
@ -80,8 +80,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
|
||||||
} else {
|
} else {
|
||||||
allocation_ = new FileCopyAllocation(filename, error_reporter);
|
allocation_ = new FileCopyAllocation(filename, error_reporter);
|
||||||
}
|
}
|
||||||
if (!allocation_->valid()) return;
|
if (!allocation_->valid() || !CheckModelIdentifier()) return;
|
||||||
if (!CheckModelIdentifier()) return;
|
|
||||||
|
|
||||||
model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
|
model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -137,8 +137,8 @@ Following are the ops supported for using On-Device Smart Reply model:
|
||||||
|
|
||||||
* **HASHTABLE_LOOKUP**
|
* **HASHTABLE_LOOKUP**
|
||||||
|
|
||||||
This is a custom op that uses label id from predict op and looks up the
|
This is an op inside TensorFlow Lite that uses label id from predict op and
|
||||||
response text from the given label id.
|
looks up the response text from the given label id.
|
||||||
|
|
||||||
## Further Information
|
## Further Information
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,7 @@ py_test(
|
||||||
tags = ["no_pip"],
|
tags = ["no_pip"],
|
||||||
deps = [
|
deps = [
|
||||||
":predictor_factories",
|
":predictor_factories",
|
||||||
|
":testing_common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,10 +68,10 @@ class CoreEstimatorPredictor(predictor.Predictor):
|
||||||
serving_input_receiver = serving_input_receiver_fn()
|
serving_input_receiver = serving_input_receiver_fn()
|
||||||
signature_def = _get_signature_def(
|
signature_def = _get_signature_def(
|
||||||
serving_input_receiver, estimator, output_key)
|
serving_input_receiver, estimator, output_key)
|
||||||
checkpoint_path = estimator.model_dir
|
checkpoint_dir = estimator.model_dir
|
||||||
self._session = monitored_session.MonitoredSession(
|
self._session = monitored_session.MonitoredSession(
|
||||||
session_creator=monitored_session.ChiefSessionCreator(
|
session_creator=monitored_session.ChiefSessionCreator(
|
||||||
checkpoint_filename_with_path=checkpoint_path))
|
checkpoint_dir=checkpoint_dir))
|
||||||
|
|
||||||
feed_tensor_info = signature_def.inputs
|
feed_tensor_info = signature_def.inputs
|
||||||
self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name)
|
self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||||
from tensorflow.contrib.predictor import contrib_estimator_predictor
|
from tensorflow.contrib.predictor import contrib_estimator_predictor
|
||||||
from tensorflow.contrib.predictor import core_estimator_predictor
|
from tensorflow.contrib.predictor import core_estimator_predictor
|
||||||
from tensorflow.contrib.predictor import saved_model_predictor
|
from tensorflow.contrib.predictor import saved_model_predictor
|
||||||
|
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator
|
||||||
from tensorflow.python.estimator import estimator as core_estimator
|
from tensorflow.python.estimator import estimator as core_estimator
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,7 +87,7 @@ def from_estimator(estimator,
|
||||||
TypeError: if `estimator` is a contrib `Estimator` instead of a core
|
TypeError: if `estimator` is a contrib `Estimator` instead of a core
|
||||||
`Estimator`.
|
`Estimator`.
|
||||||
"""
|
"""
|
||||||
if isinstance(estimator, estimator.Estimator):
|
if isinstance(estimator, contrib_estimator.Estimator):
|
||||||
raise TypeError('Espected estimator to be of type '
|
raise TypeError('Espected estimator to be of type '
|
||||||
'tf.python.estimator.Estimator, but got type '
|
'tf.python.estimator.Estimator, but got type '
|
||||||
'tf.contrib.learn.Estimator. You likely want to call '
|
'tf.contrib.learn.Estimator. You likely want to call '
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.predictor import predictor_factories
|
from tensorflow.contrib.predictor import predictor_factories
|
||||||
|
from tensorflow.contrib.predictor import testing_common
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
|
MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
|
||||||
|
|
@ -46,6 +47,29 @@ class PredictorFactoriesTest(test.TestCase):
|
||||||
with self.assertRaisesRegexp(RuntimeError, bad_tags_regex):
|
with self.assertRaisesRegexp(RuntimeError, bad_tags_regex):
|
||||||
predictor_factories.from_saved_model(self._export_dir, tags='bad_tag')
|
predictor_factories.from_saved_model(self._export_dir, tags='bad_tag')
|
||||||
|
|
||||||
|
def testFromContribEstimator(self):
|
||||||
|
estimator = testing_common.get_arithmetic_estimator(core=False)
|
||||||
|
input_fn = testing_common.get_arithmetic_input_fn(core=False)
|
||||||
|
predictor_factories.from_contrib_estimator(estimator, input_fn,
|
||||||
|
output_alternative_key='sum')
|
||||||
|
|
||||||
|
def testFromContribEstimatorWithCoreEstimatorRaises(self):
|
||||||
|
estimator = testing_common.get_arithmetic_estimator(core=True)
|
||||||
|
input_fn = testing_common.get_arithmetic_input_fn(core=True)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
predictor_factories.from_contrib_estimator(estimator, input_fn)
|
||||||
|
|
||||||
|
def testFromCoreEstimator(self):
|
||||||
|
estimator = testing_common.get_arithmetic_estimator(core=True)
|
||||||
|
input_fn = testing_common.get_arithmetic_input_fn(core=True)
|
||||||
|
predictor_factories.from_estimator(estimator, input_fn)
|
||||||
|
|
||||||
|
def testFromCoreEstimatorWithContribEstimatorRaises(self):
|
||||||
|
estimator = testing_common.get_arithmetic_estimator(core=False)
|
||||||
|
input_fn = testing_common.get_arithmetic_input_fn(core=False)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
predictor_factories.from_estimator(estimator, input_fn)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
|
||||||
# A 1024-point STFT with frames of 64 ms and 75% overlap.
|
# A 1024-point STFT with frames of 64 ms and 75% overlap.
|
||||||
stfts = tf.contrib.signal.stft(pcm, frame_length=1024, frame_step=256,
|
stfts = tf.contrib.signal.stft(pcm, frame_length=1024, frame_step=256,
|
||||||
fft_length=1024)
|
fft_length=1024)
|
||||||
spectrograms = tf.abs(stft)
|
spectrograms = tf.abs(stfts)
|
||||||
|
|
||||||
# Warp the linear scale spectrograms into the mel-scale.
|
# Warp the linear scale spectrograms into the mel-scale.
|
||||||
num_spectrogram_bins = stfts.shape[-1].value
|
num_spectrogram_bins = stfts.shape[-1].value
|
||||||
|
|
|
||||||
|
|
@ -676,7 +676,7 @@ file were implicitly obtained from each provided variable's `var.op.name`.
|
||||||
|
|
||||||
This works well when the variable names in the checkpoint file match those in
|
This works well when the variable names in the checkpoint file match those in
|
||||||
the graph. However, sometimes, we want to restore a model from a checkpoint
|
the graph. However, sometimes, we want to restore a model from a checkpoint
|
||||||
whose variables have different names those in the current graph. In this case,
|
whose variables have different names to those in the current graph. In this case,
|
||||||
we must provide the `Saver` a dictionary that maps from each checkpoint variable
|
we must provide the `Saver` a dictionary that maps from each checkpoint variable
|
||||||
name to each graph variable. Consider the following example where the checkpoint
|
name to each graph variable. Consider the following example where the checkpoint
|
||||||
variables names are obtained via a simple function:
|
variables names are obtained via a simple function:
|
||||||
|
|
|
||||||
|
|
@ -753,9 +753,10 @@ def train(train_op,
|
||||||
if logdir:
|
if logdir:
|
||||||
sv.start_standard_services(sess)
|
sv.start_standard_services(sess)
|
||||||
elif startup_delay_steps > 0:
|
elif startup_delay_steps > 0:
|
||||||
|
# (use sys.maxsize because sys.maxint doesn't exist in Python 3)
|
||||||
_wait_for_step(sess, global_step,
|
_wait_for_step(sess, global_step,
|
||||||
min(startup_delay_steps, number_of_steps or
|
min(startup_delay_steps, number_of_steps or
|
||||||
sys.maxint))
|
sys.maxsize))
|
||||||
threads = sv.start_queue_runners(sess)
|
threads = sv.start_queue_runners(sess)
|
||||||
logging.info('Starting Queues.')
|
logging.info('Starting Queues.')
|
||||||
if is_chief and sync_optimizer is not None:
|
if is_chief and sync_optimizer is not None:
|
||||||
|
|
|
||||||
|
|
@ -530,7 +530,6 @@ py_library(
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":client_lib",
|
":client_lib",
|
||||||
"//tensorflow/contrib/framework:framework_py",
|
|
||||||
"//tensorflow/contrib/layers:layers_py",
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
"//tensorflow/contrib/learn",
|
"//tensorflow/contrib/learn",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib import framework as contrib_framework
|
|
||||||
from tensorflow.contrib import layers
|
from tensorflow.contrib import layers
|
||||||
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
|
|
@ -190,7 +189,7 @@ def get_model_fn(params,
|
||||||
features, labels, input_weights=weights,
|
features, labels, input_weights=weights,
|
||||||
num_trainers=num_trainers,
|
num_trainers=num_trainers,
|
||||||
trainer_id=trainer_id),
|
trainer_id=trainer_id),
|
||||||
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
state_ops.assign_add(training_util.get_global_step(), 1))
|
||||||
|
|
||||||
# Put weights back in
|
# Put weights back in
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ py_library(
|
||||||
"python/training/resample.py",
|
"python/training/resample.py",
|
||||||
"python/training/sampling_ops.py",
|
"python/training/sampling_ops.py",
|
||||||
"python/training/sequence_queueing_state_saver.py",
|
"python/training/sequence_queueing_state_saver.py",
|
||||||
"python/training/sgdr_learning_rate_decay.py",
|
|
||||||
"python/training/training.py",
|
"python/training/training.py",
|
||||||
"python/training/tuner.py",
|
"python/training/tuner.py",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,7 @@ load(
|
||||||
"if_linux_x86_64",
|
"if_linux_x86_64",
|
||||||
"if_mobile",
|
"if_mobile",
|
||||||
"if_not_mobile",
|
"if_not_mobile",
|
||||||
|
"if_windows",
|
||||||
"if_not_windows",
|
"if_not_windows",
|
||||||
"tf_copts",
|
"tf_copts",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
|
|
@ -562,7 +563,7 @@ cc_library(
|
||||||
"platform/prefetch.h",
|
"platform/prefetch.h",
|
||||||
"platform/thread_annotations.h",
|
"platform/thread_annotations.h",
|
||||||
"platform/types.h",
|
"platform/types.h",
|
||||||
],
|
] + if_windows(["platform/windows/integral_types.h"]),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps =
|
deps =
|
||||||
[
|
[
|
||||||
|
|
|
||||||
|
|
@ -761,6 +761,10 @@ int64 MinSystemMemory(int64 available_memory) {
|
||||||
// builds in windows); because in non-opt builds more system memory
|
// builds in windows); because in non-opt builds more system memory
|
||||||
// is necessary.
|
// is necessary.
|
||||||
min_system_memory *= 2;
|
min_system_memory *= 2;
|
||||||
|
#endif
|
||||||
|
#if defined(NVIDIA_TEGRA)
|
||||||
|
// 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM
|
||||||
|
min_system_memory = 1<<30;
|
||||||
#endif
|
#endif
|
||||||
return min_system_memory;
|
return min_system_memory;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -194,10 +194,9 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
tf_kernel_library(
|
||||||
name = "fill_functor",
|
name = "fill_functor",
|
||||||
srcs = ["fill_functor.cc"],
|
prefix = "fill_functor",
|
||||||
hdrs = ["fill_functor.h"],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
|
@ -3043,6 +3042,7 @@ tf_kernel_library(
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"fill_functor.h",
|
||||||
"conv_grad_ops.h",
|
"conv_grad_ops.h",
|
||||||
"deep_conv2d.h",
|
"deep_conv2d.h",
|
||||||
"gemm_functors.h",
|
"gemm_functors.h",
|
||||||
|
|
@ -3067,6 +3067,7 @@ tf_kernel_library(
|
||||||
":conv_2d",
|
":conv_2d",
|
||||||
":conv_3d",
|
":conv_3d",
|
||||||
":image_resizer_state",
|
":image_resizer_state",
|
||||||
|
":fill_functor",
|
||||||
":ops_util",
|
":ops_util",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
|
@ -3173,7 +3174,9 @@ tf_kernel_library(
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "fused_batch_norm_op",
|
name = "fused_batch_norm_op",
|
||||||
prefix = "fused_batch_norm_op",
|
prefix = "fused_batch_norm_op",
|
||||||
deps = NN_DEPS,
|
deps = NN_DEPS + [
|
||||||
|
":fill_functor",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
|
|
|
||||||
|
|
@ -151,18 +151,6 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||||
typedef Eigen::SyclDevice SYCLDevice;
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
namespace functor {
|
|
||||||
|
|
||||||
// Partial specialization of FillFunctor<Device=CPUDevice, T>.
|
|
||||||
template <typename T>
|
|
||||||
struct FillFunctor<CPUDevice, T> {
|
|
||||||
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
|
|
||||||
typename TTypes<T>::ConstScalar in) {
|
|
||||||
out.device(d) = out.constant(in());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // end namespace functor
|
|
||||||
|
|
||||||
template <typename Device, typename T, typename Index>
|
template <typename Device, typename T, typename Index>
|
||||||
class FillOp : public OpKernel {
|
class FillOp : public OpKernel {
|
||||||
|
|
@ -191,28 +179,6 @@ class FillOp : public OpKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
|
||||||
|
|
||||||
namespace functor {
|
|
||||||
// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
|
|
||||||
template <typename T>
|
|
||||||
struct FillFunctor<SYCLDevice, T> {
|
|
||||||
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
|
|
||||||
typename TTypes<T>::ConstScalar in) {
|
|
||||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
|
||||||
Eigen::array<int, 1> rank1{1};
|
|
||||||
#else
|
|
||||||
Eigen::IndexList<Eigen::type2index<1> > rank1;
|
|
||||||
#endif
|
|
||||||
const int size = out.dimension(0);
|
|
||||||
Eigen::array<int, 1> broadcast_dims{size};
|
|
||||||
|
|
||||||
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace functor
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
|
||||||
|
|
||||||
#define REGISTER_KERNEL(D, TYPE) \
|
#define REGISTER_KERNEL(D, TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("Fill") \
|
REGISTER_KERNEL_BUILDER(Name("Fill") \
|
||||||
.Device(DEVICE_##D) \
|
.Device(DEVICE_##D) \
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#ifdef TENSORFLOW_USE_LIBXSMM
|
#ifdef TENSORFLOW_USE_LIBXSMM
|
||||||
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -595,6 +596,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||||
if (filter_shape.num_elements() == 0) {
|
if (filter_shape.num_elements() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// If input is empty, set gradients to zero.
|
||||||
|
if (input.shape().num_elements() == 0) {
|
||||||
|
functor::SetZeroFunctor<Device, T> f;
|
||||||
|
f(context->eigen_device<Device>(), filter_backprop->flat<T>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// For now we take the stride from the second and third dimensions only (we
|
// For now we take the stride from the second and third dimensions only (we
|
||||||
// do not support striding on the batch or depth dimension).
|
// do not support striding on the batch or depth dimension).
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,15 @@ limitations under the License.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <limits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "cuda/include/cuda.h"
|
#include "cuda/include/cuda.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
#include "tensorflow/core/lib/math/math_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
|
@ -223,186 +226,137 @@ __global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
|
||||||
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
||||||
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
||||||
//
|
//
|
||||||
// Each thread block operates on a single tile, a square of dimensions TileSize
|
// Each thread block operates on a single tile, a rectangle of dimensions
|
||||||
// x TileSize. We require that the thread block's X dimension equals TileSize,
|
// TileSizeI x TileSizeJ.
|
||||||
// and its Y dimension equals NumSubTiles.
|
|
||||||
//
|
//
|
||||||
// For best performance, you should probably set TileSize equal to the number of
|
// In general, for best performance, you should probably set TileSizeI,
|
||||||
// threads in a warp (32 in nvidia GPUs). With a TileSize of 32, NumSubTiles ==
|
// TileSizeJ equal to the number of threads in a warp (32 in nvidia GPUs).
|
||||||
// 4 or 8 seems to get the best performance on K40 GPUs.
|
// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get
|
||||||
template <typename T, int TileSize, int NumSubTiles, bool conjugate = false>
|
// the best performance on K40 GPUs.
|
||||||
__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
|
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
|
||||||
Dimension<3> input_dims,
|
bool conjugate = false>
|
||||||
T* output) {
|
__global__ void SwapDimension1And2InTensor3UsingTiles(
|
||||||
// One extra line in the inner dimension to avoid share memory bank conflict.
|
const T* __restrict__ input, Dimension<3> input_dims,
|
||||||
__shared__ T shared_memory_tile[TileSize][TileSize + 1];
|
T* __restrict__ output) {
|
||||||
|
eigen_assert(blockDim.x == NumThreads);
|
||||||
static_assert(TileSize % NumSubTiles == 0,
|
eigen_assert(blockDim.y == 1);
|
||||||
"TileSize must be divisible by NumSubTiles");
|
|
||||||
eigen_assert(blockDim.x == TileSize);
|
|
||||||
eigen_assert(blockDim.y == NumSubTiles);
|
|
||||||
eigen_assert(blockDim.z == 1);
|
eigen_assert(blockDim.z == 1);
|
||||||
eigen_assert(gridDim.y == 1);
|
eigen_assert(gridDim.y == 1);
|
||||||
eigen_assert(gridDim.z == 1);
|
eigen_assert(gridDim.z == 1);
|
||||||
|
|
||||||
// We break down the tile into NumSubTiles groups, so each thread processes
|
constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
|
||||||
// kSubTileSize elements (except at the edges of the input).
|
constexpr int WriteRowPerPass = NumThreads / TileSizeI;
|
||||||
const int kSubTileSize = TileSize / NumSubTiles;
|
// One extra line in the inner dimension to avoid share memory bank conflict.
|
||||||
|
__shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
|
||||||
|
|
||||||
int x = threadIdx.x;
|
int x = threadIdx.x;
|
||||||
|
|
||||||
Dimension<3> output_dims = {
|
Dimension<3> output_dims = {
|
||||||
input_dims[0],
|
input_dims[0], input_dims[2], input_dims[1],
|
||||||
input_dims[2],
|
|
||||||
input_dims[1],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Dimension<3> input_dims_in_tiles = {
|
Dimension<3> input_dims_in_tiles = {
|
||||||
input_dims[0],
|
input_dims[0], (input_dims[1] + TileSizeI - 1) / TileSizeI,
|
||||||
(input_dims[1] + TileSize - 1) / TileSize,
|
(input_dims[2] + TileSizeJ - 1) / TileSizeJ,
|
||||||
(input_dims[2] + TileSize - 1) / TileSize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Index<3> input_tile_index =
|
Index<3> input_tile_index =
|
||||||
FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);
|
FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);
|
||||||
|
|
||||||
Index<3> input_tile_origin = {
|
Index<3> input_tile_origin = {
|
||||||
input_tile_index[0],
|
input_tile_index[0], input_tile_index[1] * TileSizeI,
|
||||||
input_tile_index[1] * TileSize,
|
input_tile_index[2] * TileSizeJ,
|
||||||
input_tile_index[2] * TileSize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int input_origin_flat_index =
|
int input_origin_flat_index =
|
||||||
TensorIndexToFlat(input_tile_origin, input_dims);
|
TensorIndexToFlat(input_tile_origin, input_dims);
|
||||||
|
|
||||||
int tile_width = TileSize;
|
bool full_tile = true;
|
||||||
|
int tile_width = TileSizeJ;
|
||||||
|
|
||||||
// Only the last row or column may not have the full size.
|
// Only the last row or column may not have the full size.
|
||||||
if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
|
if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
|
||||||
tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSize;
|
tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
|
||||||
|
full_tile &= false;
|
||||||
}
|
}
|
||||||
int tile_height = TileSize;
|
|
||||||
|
int tile_height = TileSizeI;
|
||||||
|
|
||||||
if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
|
if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
|
||||||
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize;
|
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
|
||||||
|
full_tile &= false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int input_flat_index = input_origin_flat_index + x;
|
// Calculate effective thread number. This ensures that we use the largest
|
||||||
int y_start = static_cast<int>(threadIdx.y) * kSubTileSize;
|
// number of threads available to form a regular thread block with no
|
||||||
|
// trailing incomplete lines.
|
||||||
|
constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;
|
||||||
|
|
||||||
// Load the data from input memory to the shared memory tile.
|
if (x < in_effective_thread_num) {
|
||||||
if (x < tile_width) {
|
// Orient the logical thread block with respect to the input array.
|
||||||
int y_end = min(y_start + kSubTileSize, tile_height);
|
// ie. align the contiguous dimension of thread blocks with the contiguous
|
||||||
for (int y = y_start; y < y_end; y++) {
|
// dimension of the input array.
|
||||||
shared_memory_tile[y][x] = maybe_conj<T, conjugate>::run(
|
int ti = x / TileSizeJ;
|
||||||
input[input_flat_index + y * input_dims[2]]);
|
int tj = x % TileSizeJ;
|
||||||
|
int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
|
||||||
|
int input_increment = ReadRowPerPass * input_dims[2];
|
||||||
|
|
||||||
|
if (full_tile) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
|
||||||
|
shared_memory_tile[i_loc][tj] =
|
||||||
|
maybe_conj<T, conjugate>::run(input[input_index]);
|
||||||
|
input_index += input_increment;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (tj < tile_width) {
|
||||||
|
for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
|
||||||
|
shared_memory_tile[i_loc][tj] =
|
||||||
|
maybe_conj<T, conjugate>::run(input[input_index]);
|
||||||
|
input_index += input_increment;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
Index<3> output_tile_index = {
|
Index<3> output_tile_index = {
|
||||||
input_tile_index[0],
|
input_tile_index[0], input_tile_index[2], input_tile_index[1],
|
||||||
input_tile_index[2],
|
|
||||||
input_tile_index[1],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Index<3> output_tile_origin = {
|
Index<3> output_tile_origin = {
|
||||||
output_tile_index[0],
|
output_tile_index[0], output_tile_index[1] * TileSizeJ,
|
||||||
output_tile_index[1] * TileSize,
|
output_tile_index[2] * TileSizeI,
|
||||||
output_tile_index[2] * TileSize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int output_origin_flat_index =
|
int output_origin_flat_index =
|
||||||
TensorIndexToFlat(output_tile_origin, output_dims);
|
TensorIndexToFlat(output_tile_origin, output_dims);
|
||||||
|
|
||||||
int output_flat_index = output_origin_flat_index + x;
|
constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;
|
||||||
|
|
||||||
// Load the data from the shared memory tile to the output memory.
|
if (x < out_effective_thread_num) {
|
||||||
if (x < tile_height) {
|
// Re-orient the logical thread block with respect to the output array.
|
||||||
int y_end = min(y_start + kSubTileSize, tile_width);
|
// ie. align the contiguous dimension of thread blocks with contiguous
|
||||||
for (int y = y_start; y < y_end; y++) {
|
// dimension of the output array.
|
||||||
output[output_flat_index + y * output_dims[2]] = shared_memory_tile[x][y];
|
int ti = x / TileSizeI;
|
||||||
|
int tj = x % TileSizeI;
|
||||||
|
int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
|
||||||
|
int output_increment = WriteRowPerPass * output_dims[2];
|
||||||
|
|
||||||
|
if (full_tile) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
|
||||||
|
output[output_index] = shared_memory_tile[tj][i_loc];
|
||||||
|
output_index += output_increment;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if (tj < tile_height) {
|
||||||
|
for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
|
||||||
|
output[output_index] = shared_memory_tile[tj][i_loc];
|
||||||
|
output_index += output_increment;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
|
|
||||||
// when only one of the dimension sizes is smaller than 16,
|
|
||||||
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
|
||||||
//
|
|
||||||
// small_dim = the_smaller_dimension_size
|
|
||||||
// large_dim = the_larger_dimension_size
|
|
||||||
// tile_num_per_block = blockDim.x
|
|
||||||
// kTileLength = small_dim
|
|
||||||
//
|
|
||||||
// Each thread block operates on a single rectangle tile, where its width is
|
|
||||||
// kTileLength (we currently set it to 64) and its height is small_dim,
|
|
||||||
// We set the thread block's X dimension to be tile_num_per_block, and its Y
|
|
||||||
// and Z to be one.
|
|
||||||
template <typename T, int ShmemSize, bool SmallDim2, bool conjugate = false>
|
|
||||||
__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
|
|
||||||
int batch_per_block,
|
|
||||||
Dimension<3> input_dims,
|
|
||||||
T* output) {
|
|
||||||
// TODO(yangzihao) avoid share memory bank conflict.
|
|
||||||
__shared__ T shared_memory_tile[ShmemSize];
|
|
||||||
|
|
||||||
eigen_assert(blockDim.y == 1);
|
|
||||||
eigen_assert(blockDim.z == 1);
|
|
||||||
eigen_assert(gridDim.z == 1);
|
|
||||||
|
|
||||||
int block_offset = blockIdx.x * blockDim.x;
|
|
||||||
|
|
||||||
int x = threadIdx.x;
|
|
||||||
int tile_height = blockDim.x;
|
|
||||||
|
|
||||||
// Get tile height, width, and thread/block origin indices.
|
|
||||||
int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
|
|
||||||
int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
|
|
||||||
|
|
||||||
int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
|
|
||||||
(SmallDim2 ? block_offset * small_dim : block_offset);
|
|
||||||
if (global_offset >= (input_dims[0] * input_dims[1] * input_dims[2])) return;
|
|
||||||
|
|
||||||
for (int batch = 0; batch < batch_per_block; ++batch) {
|
|
||||||
int block_origin_idx =
|
|
||||||
small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
|
|
||||||
int thread_origin_idx =
|
|
||||||
block_origin_idx +
|
|
||||||
(SmallDim2 ? block_offset * small_dim : block_offset) + x;
|
|
||||||
|
|
||||||
if (block_offset + blockDim.x > large_dim) {
|
|
||||||
tile_height = large_dim - block_offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Load a continuous memory region to shared memory tile.
|
|
||||||
if (x < tile_height) {
|
|
||||||
for (int y = 0; y < small_dim; y++) {
|
|
||||||
int shmem_index =
|
|
||||||
SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
|
|
||||||
shared_memory_tile[shmem_index] = maybe_conj<T, conjugate>::run(
|
|
||||||
ldg(input + thread_origin_idx +
|
|
||||||
y * (SmallDim2 ? tile_height : large_dim)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Get block origin index for output array.
|
|
||||||
int output_block_offset = block_origin_idx;
|
|
||||||
int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
|
|
||||||
int output_block_origin_idx = output_block_offset + output_block_idx;
|
|
||||||
|
|
||||||
// Store the transposed memory region in shared memory to device.
|
|
||||||
if (x < tile_height) {
|
|
||||||
for (int y = 0; y < small_dim; y++) {
|
|
||||||
int output_idx = output_block_origin_idx + x +
|
|
||||||
y * (SmallDim2 ? large_dim : tile_height);
|
|
||||||
int shmem_index =
|
|
||||||
SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
|
|
||||||
output[output_idx] = shared_memory_tile[shmem_index];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -548,6 +502,380 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// We want std::equal_to and std::greater, but they're not constexpr until
|
||||||
|
// C++14.
|
||||||
|
struct EqualTo {
|
||||||
|
constexpr bool operator()(int a, int b) const { return a == b; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GreaterThan {
|
||||||
|
constexpr bool operator()(int a, int b) const { return a > b; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// For each data type, the tile size posibility frontier denotes the tile size
|
||||||
|
// combinations that consume the most computational resources constrained by
|
||||||
|
// - number of threads per SM limit,
|
||||||
|
// - limit on size of the short dimension (<=15) due to the definition of
|
||||||
|
// narrow matrix,
|
||||||
|
// - shared memory limit and
|
||||||
|
// - some experimentally determined, type-specific constraint on the product of
|
||||||
|
// two side lengths to increase grid-level parallelism.
|
||||||
|
//
|
||||||
|
// A tile size combination lies on the frontier if and only if one or more
|
||||||
|
// constraint mentioned above is hit. Tile size combinations lying outside this
|
||||||
|
// frontier are either not possible, or are slower than the alternatives.
|
||||||
|
//
|
||||||
|
// It is instrumental to consider, for each data type, two subsets of the
|
||||||
|
// corresponding frontier:
|
||||||
|
// - long side frontier: the union of the biggest tile size combination for
|
||||||
|
// each legal long side len.
|
||||||
|
// - non long side frontier: the frontier set minus the long side frontier.
|
||||||
|
//
|
||||||
|
// TileSizePossibilityFrontierCheck defines the frontier using only the long
|
||||||
|
// side frontier tile size combinations (since one can easily extrapolate
|
||||||
|
// the entire frontier from this subset). It serves as a utility function
|
||||||
|
// to help us determine where a tile size combination of interest lies with
|
||||||
|
// resepect to the frontier.
|
||||||
|
template <typename Op>
|
||||||
|
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
|
||||||
|
int TileShortSide,
|
||||||
|
int size_of_t, Op op) {
|
||||||
|
// clang-format off
|
||||||
|
return (size_of_t == 16 && ((TileLongSide == 32 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 64 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 128 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 256 && op(TileShortSide, 2)))) ||
|
||||||
|
(size_of_t == 8 && ((TileLongSide == 32 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 64 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 128 && op(TileShortSide, 8)) ||
|
||||||
|
(TileLongSide == 256 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 512 && op(TileShortSide, 2)))) ||
|
||||||
|
(size_of_t == 4 && ((TileLongSide == 32 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 64 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 128 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 256 && op(TileShortSide, 8)) ||
|
||||||
|
(TileLongSide == 512 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 1024 && op(TileShortSide, 2)))) ||
|
||||||
|
(size_of_t == 2 && ((TileLongSide == 32 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 64 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 128 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 256 && op(TileShortSide, 8)) ||
|
||||||
|
(TileLongSide == 512 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 1024 && op(TileShortSide, 2)))) ||
|
||||||
|
(size_of_t == 1 && ((TileLongSide == 32 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 64 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 128 && op(TileShortSide, 15)) ||
|
||||||
|
(TileLongSide == 256 && op(TileShortSide, 8)) ||
|
||||||
|
(TileLongSide == 512 && op(TileShortSide, 4)) ||
|
||||||
|
(TileLongSide == 1024 && op(TileShortSide, 2))));
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
|
||||||
|
int size_of_t) {
|
||||||
|
return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
|
||||||
|
size_of_t, EqualTo());
|
||||||
|
}
|
||||||
|
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
|
||||||
|
int size_of_t) {
|
||||||
|
return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
|
||||||
|
size_of_t, GreaterThan());
|
||||||
|
}
|
||||||
|
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
|
||||||
|
int TileShortSide, int size_of_t) {
|
||||||
|
// For a tile size combination (longside, shortside), lying on the frontier
|
||||||
|
// implies that (longside, shortside) is on or within the frontier but
|
||||||
|
// (longside*2, shortside) or (longside, shortside+1) is not. With the above
|
||||||
|
// critereon, we simply need to use !TileSizeOnLongSideFrontier to ensure that
|
||||||
|
// it is not on the long side frontier.
|
||||||
|
return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
|
||||||
|
(TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
|
||||||
|
TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
|
||||||
|
size_of_t)) &&
|
||||||
|
!TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to launch a batch narrow matirx transpose kernel.
|
||||||
|
template <typename T, int TileLongSide, int TileShortSide>
|
||||||
|
void LaunchBatchNarrowMatrixTransposeKernel(
|
||||||
|
const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
|
||||||
|
const T* input, const Dimension<3>& input_dims, T* output) {
|
||||||
|
constexpr int NumThreads = TileLongSide;
|
||||||
|
if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
|
||||||
|
SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
|
||||||
|
TileShortSide>
|
||||||
|
<<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
|
||||||
|
output);
|
||||||
|
} else {
|
||||||
|
SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
|
||||||
|
TileLongSide>
|
||||||
|
<<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
|
||||||
|
output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursive template function to search, in a trial-and-error manner, for the
|
||||||
|
// minimum tile size configuration satisfying the requested tile side lengths.
|
||||||
|
// An important invariant of this search procedure is that for an unsatisfied
|
||||||
|
// request, we always try doubling the long side len first, and only after
|
||||||
|
// the request is satisfied for the long side len do we begin incrementing
|
||||||
|
// the short side len.
|
||||||
|
//
|
||||||
|
// We have three specializations of this search function depending on where the
|
||||||
|
// current tile size combination lies with respect to the frontier.
|
||||||
|
// - It lies within the frontier. If request is not satisfied, for the next tile
|
||||||
|
// size combination, we first try doubling the long side len and if that does
|
||||||
|
// not work, we then increment the short side len.
|
||||||
|
// - It lies on the non long side frontier. If the request is not satisfied, we
|
||||||
|
// can only increment the short side len.
|
||||||
|
// - It lies on the long side frontier. We launch the kernel without checking if
|
||||||
|
// the request is satisfied or not.
|
||||||
|
template <typename T, int TileLongSide, int TileShortSide,
|
||||||
|
typename dummy = void>
|
||||||
|
struct BatchNarrowMatrixTransposeDispatcher {
|
||||||
|
static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
|
||||||
|
int total_tiles_count, const T* input,
|
||||||
|
const Dimension<3>& input_dims, T* output) {
|
||||||
|
static_assert(
|
||||||
|
(TileLongSide & (TileLongSide - 1)) == 0,
|
||||||
|
"The length of the longer side of the tile is always a power of 2.");
|
||||||
|
bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide &&
|
||||||
|
min(tile_size_i, tile_size_j) <= TileShortSide;
|
||||||
|
|
||||||
|
if (request_satisfied) {
|
||||||
|
LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
|
||||||
|
d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
|
||||||
|
output);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the execution reaches here, then the kernel was not launched; we then
|
||||||
|
// determine whether it is the long side or the short side that falls short
|
||||||
|
// of the request and increase that parameter accordingly.
|
||||||
|
const bool long_side_request_not_satisfied =
|
||||||
|
max(tile_size_i, tile_size_j) > TileLongSide;
|
||||||
|
|
||||||
|
if (long_side_request_not_satisfied) {
|
||||||
|
BatchNarrowMatrixTransposeDispatcher<
|
||||||
|
T, TileLongSide * 2, TileShortSide>::DoIt(d, tile_size_i, tile_size_j,
|
||||||
|
total_tiles_count, input,
|
||||||
|
input_dims, output);
|
||||||
|
} else {
|
||||||
|
BatchNarrowMatrixTransposeDispatcher<
|
||||||
|
T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
|
||||||
|
total_tiles_count, input,
|
||||||
|
input_dims, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int TileLongSide, int TileShortSide>
|
||||||
|
struct BatchNarrowMatrixTransposeDispatcher<
|
||||||
|
T, TileLongSide, TileShortSide,
|
||||||
|
typename std::enable_if<TileSizeOnNonLongSideFrontier(
|
||||||
|
TileLongSide, TileShortSide, sizeof(T)),
|
||||||
|
void>::type> {
|
||||||
|
static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
|
||||||
|
int total_tiles_count, const T* input,
|
||||||
|
const Dimension<3>& input_dims, T* output) {
|
||||||
|
static_assert(
|
||||||
|
(TileLongSide & (TileLongSide - 1)) == 0,
|
||||||
|
"The length of the longer side of the tile is always a power of 2.");
|
||||||
|
bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide &&
|
||||||
|
min(tile_size_i, tile_size_j) <= TileShortSide;
|
||||||
|
|
||||||
|
if (request_satisfied) {
|
||||||
|
LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
|
||||||
|
d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
|
||||||
|
output);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the execution reaches here, then the kernel was not launched; since
|
||||||
|
// we are on the non long side frontier, we increment the short dimension
|
||||||
|
// and try again.
|
||||||
|
BatchNarrowMatrixTransposeDispatcher<
|
||||||
|
T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
|
||||||
|
total_tiles_count, input,
|
||||||
|
input_dims, output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int TileLongSide, int TileShortSide>
|
||||||
|
struct BatchNarrowMatrixTransposeDispatcher<
|
||||||
|
T, TileLongSide, TileShortSide,
|
||||||
|
typename std::enable_if<TileSizeOnLongSideFrontier(
|
||||||
|
TileLongSide, TileShortSide, sizeof(T)),
|
||||||
|
void>::type> {
|
||||||
|
static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
|
||||||
|
int total_tiles_count, const T* input,
|
||||||
|
const Dimension<3>& input_dims, T* output) {
|
||||||
|
static_assert(
|
||||||
|
(TileLongSide & (TileLongSide - 1)) == 0,
|
||||||
|
"The length of the longer side of the tile is always a power of 2.");
|
||||||
|
|
||||||
|
LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
|
||||||
|
d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
|
||||||
|
output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// This function tries to recover, in a brute force way, the frontier defined in
|
||||||
|
// TileSizePossibilityFrontierCheck as a vector of tile size combinations lying
|
||||||
|
// on the long side frontier. This vector is sufficient to determine the entire
|
||||||
|
// frontier.
|
||||||
|
//
|
||||||
|
// Note that if one changes the frontier definition in
|
||||||
|
// TileSizePossibilityFrontierCheck and forgets to set the largest short
|
||||||
|
// side len of the largest legal long side len to 2, this function will fail
|
||||||
|
// and crash the program.
|
||||||
|
template <int SizeOfT>
|
||||||
|
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
|
||||||
|
static_assert(
|
||||||
|
SizeOfT <= 16,
|
||||||
|
"Currently, only data types of sizes 16 bytes or less are supported.");
|
||||||
|
static_assert((SizeOfT & (SizeOfT - 1)) == 0,
|
||||||
|
"Data types must have sizes that are powers of 2.");
|
||||||
|
|
||||||
|
// Expensive work to populate sizes, lazily run in a thread-safe
|
||||||
|
// manner the first time GetTileSizesFrontier<N> is called.
|
||||||
|
static auto* frontier = [] {
|
||||||
|
auto* frontier = new std::vector<std::pair<int, int>>();
|
||||||
|
const int kMaxLongSideLen = 1024;
|
||||||
|
const int kMaxShortSideLen = 15;
|
||||||
|
for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
|
||||||
|
for (int short_side = 2; short_side <= kMaxShortSideLen;
|
||||||
|
short_side += 1) {
|
||||||
|
if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
|
||||||
|
// The current combination lies on the frontier, thus we
|
||||||
|
// add it to the frontier definition.
|
||||||
|
frontier->push_back(std::make_pair(long_side, short_side));
|
||||||
|
|
||||||
|
// The long side length is the largest one allowed iff its
|
||||||
|
// corresponding short side length is 2.
|
||||||
|
if (short_side == 2) return frontier;
|
||||||
|
|
||||||
|
// We have exhausted all the possibilities in the frontier
|
||||||
|
// with the given long side length.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG(FATAL)
|
||||||
|
<< "The corresponding short side length of the largest long side "
|
||||||
|
"length has to be 2.";
|
||||||
|
}();
|
||||||
|
return *frontier;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper structs to help determine which data type to use given the size of
|
||||||
|
// the matrix data type. A transpose of elements of size N will use a kernel
|
||||||
|
// which operates on an array of TransposeElemType<N>::type.
|
||||||
|
template <int ElemBytes>
|
||||||
|
struct TransposeElemType;
|
||||||
|
template <>
|
||||||
|
struct TransposeElemType<1> {
|
||||||
|
using type = uint8;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct TransposeElemType<2> {
|
||||||
|
using type = uint16;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct TransposeElemType<4> {
|
||||||
|
using type = uint32;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct TransposeElemType<8> {
|
||||||
|
using type = uint64;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct TransposeElemType<16> {
|
||||||
|
using type = float4;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A helper function to make RunSwapDimension1And2InTensor3 concise. This
|
||||||
|
// helper function looks at the data type and input matrix sizes and decides
|
||||||
|
// the thread numbers and tile sizes to use.
|
||||||
|
template <typename T, bool conjugate = false >
|
||||||
|
void SwapDimension1And2InTensor3WithNarrowMatrices(
|
||||||
|
const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
|
||||||
|
T* output, const int kMinDimensionToUseTiles) {
|
||||||
|
// Get available tile sizes here for the data type requested:
|
||||||
|
const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();
|
||||||
|
|
||||||
|
int tile_long_side_len = 0;
|
||||||
|
int tile_short_side_len = 0;
|
||||||
|
float lowest_cost = std::numeric_limits<float>::max();
|
||||||
|
int data_long_side = max(input_dims[1], input_dims[2]);
|
||||||
|
|
||||||
|
for (auto tile_size_pair : tile_spec) {
|
||||||
|
int proposed_tile_long_side_len = tile_size_pair.first;
|
||||||
|
|
||||||
|
// Number of threads that will not be doing anything useful when reading
|
||||||
|
// the matrix because the thread block size is bigger than the data block
|
||||||
|
// size.
|
||||||
|
int num_wasted_threads =
|
||||||
|
data_long_side - MathUtil::FloorOfRatio<int>(
|
||||||
|
data_long_side, proposed_tile_long_side_len) *
|
||||||
|
proposed_tile_long_side_len;
|
||||||
|
|
||||||
|
int num_full_tiles = MathUtil::FloorOfRatio<int>(
|
||||||
|
data_long_side, proposed_tile_long_side_len);
|
||||||
|
|
||||||
|
float cost = 0;
|
||||||
|
|
||||||
|
// However, if we can execute two or more full tiles, then we gladly
|
||||||
|
// accept any number of wasted threads and ignore its cost.
|
||||||
|
if (num_full_tiles <= 1) cost = num_wasted_threads;
|
||||||
|
|
||||||
|
// Using less than or equal to here because given the same cost, we
|
||||||
|
// would like to launch as many threads as possible.
|
||||||
|
if (cost <= lowest_cost) {
|
||||||
|
tile_long_side_len = proposed_tile_long_side_len;
|
||||||
|
tile_short_side_len = tile_size_pair.second;
|
||||||
|
lowest_cost = cost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request tile sizes such that the longer side of threadblock aligns with
|
||||||
|
// the longer side of input data block to maximize read throughput.
|
||||||
|
// The ideal tile shape is one where the length of the shorter side of the
|
||||||
|
// tile is equal to the length of the shorter side of the input matrix.
|
||||||
|
int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
|
||||||
|
? tile_long_side_len
|
||||||
|
: input_dims[1];
|
||||||
|
int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
|
||||||
|
? input_dims[2]
|
||||||
|
: tile_long_side_len;
|
||||||
|
|
||||||
|
// Truncate the shorter size requested according to the manual limit set in
|
||||||
|
// tile_spec to make sure that we do not launch configurations violating
|
||||||
|
// hardware limits.
|
||||||
|
requested_tile_size_i = requested_tile_size_i == tile_long_side_len
|
||||||
|
? tile_long_side_len
|
||||||
|
: min(requested_tile_size_i, tile_short_side_len);
|
||||||
|
requested_tile_size_j = requested_tile_size_j == tile_long_side_len
|
||||||
|
? tile_long_side_len
|
||||||
|
: min(requested_tile_size_j, tile_short_side_len);
|
||||||
|
|
||||||
|
Dimension<3> input_dims_in_tiles = {
|
||||||
|
input_dims[0],
|
||||||
|
MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
|
||||||
|
MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
|
||||||
|
};
|
||||||
|
|
||||||
|
int total_tiles_count =
|
||||||
|
input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
|
||||||
|
|
||||||
|
using ElemType = typename TransposeElemType<sizeof(T)>::type;
|
||||||
|
static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
|
||||||
|
BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2>::DoIt(
|
||||||
|
d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
|
||||||
|
reinterpret_cast<const ElemType*>(input), input_dims,
|
||||||
|
reinterpret_cast<ElemType*>(output));
|
||||||
|
}
|
||||||
|
|
||||||
// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
|
// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
|
||||||
// 3D tensor. It looks at the shape of the incoming data, and decides the best
|
// 3D tensor. It looks at the shape of the incoming data, and decides the best
|
||||||
// strategy to launch.
|
// strategy to launch.
|
||||||
|
|
@ -558,60 +886,33 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
|
||||||
// If one dimension is trivial, use SmallDim kernel for swapping.
|
// If one dimension is trivial, use SmallDim kernel for swapping.
|
||||||
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
|
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
|
||||||
static const int kMinDimensionToUseTiles = 16;
|
static const int kMinDimensionToUseTiles = 16;
|
||||||
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
|
static const int kMinDimensionToUseRectTiles = 96;
|
||||||
input_dims[2] >= kMinDimensionToUseTiles);
|
|
||||||
bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
|
|
||||||
input_dims[2] < kMinDimensionToUseTiles)) ||
|
|
||||||
((input_dims[1] < kMinDimensionToUseTiles &&
|
|
||||||
input_dims[2] >= kMinDimensionToUseTiles));
|
|
||||||
static const int NumSubTiles = 8;
|
|
||||||
|
|
||||||
if (use_tiles) {
|
bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
|
||||||
static const int TileSize = 32;
|
input_dims[2] >= kMinDimensionToUseTiles;
|
||||||
Dimension<3> input_dims_in_tiles = {
|
bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
|
||||||
input_dims[0],
|
input_dims[2] >= kMinDimensionToUseRectTiles;
|
||||||
(input_dims[1] + TileSize - 1) / TileSize,
|
if (large_matrix) {
|
||||||
(input_dims[2] + TileSize - 1) / TileSize,
|
// We get best performance when kTileSize is the number of threads in a warp
|
||||||
};
|
|
||||||
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
|
|
||||||
input_dims_in_tiles[2];
|
|
||||||
// We get best performance when TileSize is the number of threads in a warp
|
|
||||||
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
|
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
|
||||||
// threads.
|
// threads.
|
||||||
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles, conjugate>
|
constexpr int kTileSize = 32;
|
||||||
<<<total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
|
constexpr int kNumThreads = 256;
|
||||||
input, input_dims, output);
|
|
||||||
} else if (use_small_dim) {
|
Dimension<3> input_dims_in_tiles = {
|
||||||
// When only one of the dimensions is smaller than kMinDimensionToUseTiles,
|
input_dims[0], MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
|
||||||
// we use one block to process a rectangle region with the size of
|
MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
|
||||||
// kTileLength * small_dim. We found that when set kTileLength to 64 on
|
};
|
||||||
// TitanX Maxwell GPU, it achieves the best performance.
|
|
||||||
// large_dim
|
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
|
||||||
// +---------------...--------+
|
input_dims_in_tiles[2];
|
||||||
// | | | |
|
SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize, conjugate>
|
||||||
// small_dim | | ... | |
|
<<<total_tiles_count, kNumThreads, 0, d.stream()>>>(input, input_dims,
|
||||||
// | | | |
|
output);
|
||||||
// +--------------...---------+
|
|
||||||
// \----- ------/ \- -/
|
} else if (narrow_matrix) {
|
||||||
// V V
|
SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(d, input, input_dims, output,
|
||||||
// kTileLength(tile_height) tile_height
|
kMinDimensionToUseTiles);
|
||||||
static const int kTileLength = 64;
|
|
||||||
static const int kGridDimY = 65535;
|
|
||||||
int large_dim = std::max(input_dims[2], input_dims[1]);
|
|
||||||
int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
|
|
||||||
int grid_dim_y = std::min(input_dims[0], kGridDimY);
|
|
||||||
int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
|
|
||||||
if (input_dims[2] < input_dims[1]) {
|
|
||||||
SwapDimension1And2InTensor3SmallDim<
|
|
||||||
T, kTileLength * kMinDimensionToUseTiles, true, conjugate>
|
|
||||||
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
|
|
||||||
d.stream()>>>(input, batch_per_block, input_dims, output);
|
|
||||||
} else {
|
|
||||||
SwapDimension1And2InTensor3SmallDim<
|
|
||||||
T, kTileLength * kMinDimensionToUseTiles, false, conjugate>
|
|
||||||
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
|
|
||||||
d.stream()>>>(input, batch_per_block, input_dims, output);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
|
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
|
||||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
|
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class GraphDefBuilderWrapper {
|
||||||
Status AddTensor(const Tensor& val, Node** output) {
|
Status AddTensor(const Tensor& val, Node** output) {
|
||||||
AddTensorInternal(val, output);
|
AddTensorInternal(val, output);
|
||||||
if (*output == nullptr) {
|
if (*output == nullptr) {
|
||||||
return errors::Internal("AddTesor: Failed to build Const op.");
|
return errors::Internal("AddTensor: Failed to build Const op.");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
|
|
||||||
|
|
@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
|
||||||
DEFINE_SETZERO_SYCL(int64);
|
DEFINE_SETZERO_SYCL(int64);
|
||||||
#undef DEFINE_SETZERO_SYCL
|
#undef DEFINE_SETZERO_SYCL
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
|
void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
|
||||||
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
|
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
|
||||||
|
|
@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
|
||||||
#undef DEFINE_SETONE_SYCL
|
#undef DEFINE_SETONE_SYCL
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct FillFunctor<Eigen::ThreadPoolDevice, T> {
|
||||||
|
void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
|
||||||
|
typename TTypes<T>::ConstScalar in) {
|
||||||
|
out.device(d) = out.constant(in());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Explicit instantiations.
|
||||||
|
#define DEFINE_FILL_CPU(T) \
|
||||||
|
template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
|
||||||
|
|
||||||
|
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
|
||||||
|
DEFINE_FILL_CPU(quint8);
|
||||||
|
DEFINE_FILL_CPU(quint16);
|
||||||
|
#undef DEFINE_FILL_CPU
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
template <typename T>
|
||||||
|
struct FillFunctor<Eigen::SyclDevice, T> {
|
||||||
|
void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
|
||||||
|
typename TTypes<T>::ConstScalar in) {
|
||||||
|
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||||
|
Eigen::array<int, 1> rank1{1};
|
||||||
|
#else
|
||||||
|
Eigen::IndexList<Eigen::type2index<1> > rank1;
|
||||||
|
#endif
|
||||||
|
const int size = out.dimension(0);
|
||||||
|
Eigen::array<int, 1> broadcast_dims{size};
|
||||||
|
|
||||||
|
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DEFINE_FILL_SYCL(T) \
|
||||||
|
template struct FillFunctor<Eigen::SyclDevice, T>;
|
||||||
|
DEFINE_FILL_SYCL(float);
|
||||||
|
DEFINE_FILL_SYCL(double);
|
||||||
|
TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
|
||||||
|
#undef DEFINE_FILL_SYCL
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
||||||
115
tensorflow/core/kernels/fill_functor.cu.cc
Normal file
115
tensorflow/core/kernels/fill_functor.cu.cc
Normal file
|
|
@ -0,0 +1,115 @@
|
||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct scalar_const_op {
|
||||||
|
typedef typename packet_traits<T>::type Packet;
|
||||||
|
|
||||||
|
const T* val;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
scalar_const_op(const scalar_const_op& x)
|
||||||
|
: val(x.val) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()() const {
|
||||||
|
return *val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PacketType = Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const {
|
||||||
|
return internal::pset1<PacketType>(*val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct functor_traits<scalar_const_op<T> > {
|
||||||
|
enum {
|
||||||
|
Cost = 1,
|
||||||
|
PacketAccess = packet_traits<T>::Vectorizable,
|
||||||
|
IsRepeatable = true
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// Partial specialization FillFunctor<Device=GPUDevice, T>
|
||||||
|
template <typename T>
|
||||||
|
struct FillFunctor<GPUDevice, T> {
|
||||||
|
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
|
||||||
|
typename TTypes<T>::ConstScalar in) {
|
||||||
|
Eigen::internal::scalar_const_op<T> f(in.data());
|
||||||
|
To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>;
|
||||||
|
TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU);
|
||||||
|
TF_CALL_bfloat16(DEFINE_FILL_GPU);
|
||||||
|
TF_CALL_bool(DEFINE_FILL_GPU);
|
||||||
|
#undef DEFINE_FILL_GPU
|
||||||
|
|
||||||
|
// Partial specialization of FillFunctor<Device=GPUDevice, T>.
|
||||||
|
template <typename T>
|
||||||
|
struct SetZeroFunctor<GPUDevice, T> {
|
||||||
|
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
|
||||||
|
To32Bit(out).device(d) = To32Bit(out).constant(T(0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>;
|
||||||
|
TF_CALL_NUMBER_TYPES(DEFINE_SETZERO_GPU);
|
||||||
|
TF_CALL_bfloat16(DEFINE_SETZERO_GPU);
|
||||||
|
TF_CALL_bool(DEFINE_SETZERO_GPU);
|
||||||
|
#undef DEFINE_SETZERO_GPU
|
||||||
|
|
||||||
|
// Partial specialization of FillFunctor<Device=GPUDevice, T>.
|
||||||
|
template <typename T>
|
||||||
|
struct SetOneFunctor<GPUDevice, T> {
|
||||||
|
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
|
||||||
|
To32Bit(out).device(d) = To32Bit(out).constant(T(1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>;
|
||||||
|
TF_CALL_NUMBER_TYPES(DEFINE_SETONE_GPU);
|
||||||
|
TF_CALL_bfloat16(DEFINE_SETONE_GPU);
|
||||||
|
TF_CALL_bool(DEFINE_SETONE_GPU);
|
||||||
|
#undef DEFINE_SETONE_GPU
|
||||||
|
|
||||||
|
} // end namespace functor
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
|
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
|
|
@ -239,6 +240,14 @@ struct FusedBatchNorm<GPUDevice, T, U> {
|
||||||
<< " offset shape: " << offset.shape().DebugString()
|
<< " offset shape: " << offset.shape().DebugString()
|
||||||
<< " tensor format: " << tensor_format;
|
<< " tensor format: " << tensor_format;
|
||||||
|
|
||||||
|
// If input is empty, return NaN mean/variance
|
||||||
|
if (x.shape().num_elements() == 0) {
|
||||||
|
functor::SetNanFunctor<U> f;
|
||||||
|
f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
|
||||||
|
f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
Tensor x_maybe_transformed = x;
|
Tensor x_maybe_transformed = x;
|
||||||
Tensor x_transformed;
|
Tensor x_transformed;
|
||||||
Tensor y_transformed;
|
Tensor y_transformed;
|
||||||
|
|
@ -656,6 +665,14 @@ class FusedBatchNormGradOp : public OpKernel {
|
||||||
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
|
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
|
||||||
FillZeros<Device>(placeholder_2);
|
FillZeros<Device>(placeholder_2);
|
||||||
|
|
||||||
|
// If input is empty, set gradients w.r.t scale/offset to zero.
|
||||||
|
if (x.shape().num_elements() == 0) {
|
||||||
|
functor::SetZeroFunctor<Device, U> f;
|
||||||
|
f(context->eigen_device<Device>(), scale_backprop->flat<U>());
|
||||||
|
f(context->eigen_device<Device>(), offset_backprop->flat<U>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (is_training_) {
|
if (is_training_) {
|
||||||
functor::FusedBatchNormGrad<Device, T, U>()(
|
functor::FusedBatchNormGrad<Device, T, U>()(
|
||||||
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,15 @@ void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d,
|
||||||
epsilon, sample_size, variance);
|
epsilon, sample_size, variance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
|
||||||
|
typename TTypes<T>::Flat out) {
|
||||||
|
To32Bit(out).device(d) = To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
|
||||||
|
}
|
||||||
|
|
||||||
template class VarianceToInvVariance<float>;
|
template class VarianceToInvVariance<float>;
|
||||||
template class InvVarianceToVariance<float>;
|
template class InvVarianceToVariance<float>;
|
||||||
|
template class SetNanFunctor<float>;
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,12 @@ struct InvVarianceToVariance {
|
||||||
int channels, T* variance);
|
int channels, T* variance);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This function sets a GPU tensor to NaNs.
|
||||||
|
template <class T>
|
||||||
|
struct SetNanFunctor {
|
||||||
|
void operator()(const Eigen::GpuDevice& d, typename TTypes<T>::Flat out);
|
||||||
|
};
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
// Functor used by FusedBatchNormGradOp to do the computations when
|
// Functor used by FusedBatchNormGradOp to do the computations when
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
|
||||||
done);
|
done);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For large matrices, we wompute the inverse of each matrix in the batch
|
// For large matrices, we compute the inverse of each matrix in the batch
|
||||||
// sequentially. Here we use the cuSolver methods GETRF/GETRS because they
|
// sequentially. Here we use the cuSolver methods GETRF/GETRS because they
|
||||||
// are MUCH faster than their batched cuBlas equivalents for large
|
// are MUCH faster than their batched cuBlas equivalents for large
|
||||||
// matrices.
|
// matrices.
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,9 @@ void DnnPoolingOp<T>::Compute(
|
||||||
Tensor* tensor_out = nullptr;
|
Tensor* tensor_out = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, tensor_out_shape, &tensor_out));
|
context->allocate_output(0, tensor_out_shape, &tensor_out));
|
||||||
|
if (tensor_in.shape().num_elements() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
PoolParameters params{context, size, stride,
|
PoolParameters params{context, size, stride,
|
||||||
padding, data_format, tensor_in.shape()};
|
padding, data_format, tensor_in.shape()};
|
||||||
|
|
@ -247,6 +250,9 @@ void DnnPoolingGradOp<T>::Compute(
|
||||||
Tensor* input_backprop = nullptr;
|
Tensor* input_backprop = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, tensor_in_shape, &input_backprop));
|
context->allocate_output(0, tensor_in_shape, &input_backprop));
|
||||||
|
if (tensor_in_shape.num_elements() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
PoolParameters params{context, size, stride,
|
PoolParameters params{context, size, stride,
|
||||||
padding, data_format, tensor_in_shape};
|
padding, data_format, tensor_in_shape};
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ class RecordInputOp : public OpKernel {
|
||||||
GETATTR(int64, batch_size);
|
GETATTR(int64, batch_size);
|
||||||
GETATTR(string, compression_type);
|
GETATTR(string, compression_type);
|
||||||
#undef GETATTR
|
#undef GETATTR
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("compression_type", &compression_type));
|
||||||
|
|
||||||
RecordYielder::Options yopts;
|
RecordYielder::Options yopts;
|
||||||
yopts.file_pattern = file_pattern;
|
yopts.file_pattern = file_pattern;
|
||||||
|
|
|
||||||
|
|
@ -131,7 +131,7 @@ class WhereCPUOp : public OpKernel {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input.dtype() != DT_HALF,
|
context, input.dtype() != DT_HALF,
|
||||||
errors::Unimplemented("No WhereOp available for float16/half type on "
|
errors::Unimplemented("No WhereOp available for float16/half type on "
|
||||||
"GPU; dying in CPU WhereOp to avoid silently "
|
"CPU; dying in CPU WhereOp to avoid silently "
|
||||||
"creating costly copies from device."));
|
"creating costly copies from device."));
|
||||||
|
|
||||||
const int input_dims = input.dims();
|
const int input_dims = input.dims();
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ computation is performed on the underlying representation of x.
|
||||||
.Output("z: T") \
|
.Output("z: T") \
|
||||||
.SetIsCommutative() \
|
.SetIsCommutative() \
|
||||||
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \
|
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \
|
||||||
.SetShapeFn(shape_inference::UnchangedShape)
|
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||||
|
|
||||||
REGISTER_OP("PopulationCount")
|
REGISTER_OP("PopulationCount")
|
||||||
.Input("x: T")
|
.Input("x: T")
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,8 @@ int64 LogLevelStrToInt(const char* tf_env_var_val) {
|
||||||
return level;
|
return level;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
int64 MinLogLevelFromEnv() {
|
int64 MinLogLevelFromEnv() {
|
||||||
const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
|
const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
|
||||||
return LogLevelStrToInt(tf_env_var_val);
|
return LogLevelStrToInt(tf_env_var_val);
|
||||||
|
|
@ -124,8 +126,6 @@ int64 MinVLogLevelFromEnv() {
|
||||||
return LogLevelStrToInt(tf_env_var_val);
|
return LogLevelStrToInt(tf_env_var_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
LogMessage::~LogMessage() {
|
LogMessage::~LogMessage() {
|
||||||
// Read the min log level once during the first call to logging.
|
// Read the min log level once during the first call to logging.
|
||||||
static int64 min_log_level = MinLogLevelFromEnv();
|
static int64 min_log_level = MinLogLevelFromEnv();
|
||||||
|
|
|
||||||
|
|
@ -305,6 +305,10 @@ T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
|
||||||
return std::forward<T>(t);
|
return std::forward<T>(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 MinLogLevelFromEnv();
|
||||||
|
|
||||||
|
int64 MinVLogLevelFromEnv();
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,8 @@ filegroup(
|
||||||
tf_cc_binary(
|
tf_cc_binary(
|
||||||
name = "s3_file_system.so",
|
name = "s3_file_system.so",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"aws_logging.cc",
|
||||||
|
"aws_logging.h",
|
||||||
"s3_crypto.cc",
|
"s3_crypto.cc",
|
||||||
"s3_crypto.h",
|
"s3_crypto.h",
|
||||||
"s3_file_system.cc",
|
"s3_file_system.cc",
|
||||||
|
|
@ -66,6 +68,22 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "aws_logging",
|
||||||
|
srcs = [
|
||||||
|
"aws_logging.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"aws_logging.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@aws//:aws",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "s3_file_system",
|
name = "s3_file_system",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
@ -75,6 +93,7 @@ cc_library(
|
||||||
"s3_file_system.h",
|
"s3_file_system.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":aws_logging",
|
||||||
":s3_crypto",
|
":s3_crypto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
|
|
||||||
121
tensorflow/core/platform/s3/aws_logging.cc
Normal file
121
tensorflow/core/platform/s3/aws_logging.cc
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/platform/s3/aws_logging.h"
|
||||||
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
|
||||||
|
#include <aws/core/Aws.h>
|
||||||
|
#include <aws/core/utils/logging/AWSLogging.h>
|
||||||
|
#include <aws/core/utils/logging/LogSystemInterface.h>
|
||||||
|
|
||||||
|
#include <cstdarg>
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)
|
||||||
|
: log_level_(log_level) {}
|
||||||
|
|
||||||
|
void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
|
||||||
|
const char* format, ...) {
|
||||||
|
std::va_list args;
|
||||||
|
va_start(args, format);
|
||||||
|
|
||||||
|
const string s = strings::Printf(format, args);
|
||||||
|
|
||||||
|
va_end(args);
|
||||||
|
|
||||||
|
LogMessage(log_level, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level,
|
||||||
|
const char* tag,
|
||||||
|
const Aws::OStringStream& message_stream) {
|
||||||
|
LogMessage(log_level, message_stream.rdbuf()->str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level,
|
||||||
|
const std::string& message) {
|
||||||
|
switch (log_level) {
|
||||||
|
case Aws::Utils::Logging::LogLevel::Info:
|
||||||
|
LOG(INFO) << message;
|
||||||
|
break;
|
||||||
|
case Aws::Utils::Logging::LogLevel::Warn:
|
||||||
|
LOG(WARNING) << message;
|
||||||
|
break;
|
||||||
|
case Aws::Utils::Logging::LogLevel::Error:
|
||||||
|
LOG(ERROR) << message;
|
||||||
|
break;
|
||||||
|
case Aws::Utils::Logging::LogLevel::Fatal:
|
||||||
|
LOG(FATAL) << message;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(ERROR) << message;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
static const char* kAWSLoggingTag = "AWSLogging";
|
||||||
|
|
||||||
|
Aws::Utils::Logging::LogLevel ParseLogLevelFromEnv() {
|
||||||
|
Aws::Utils::Logging::LogLevel log_level = Aws::Utils::Logging::LogLevel::Info;
|
||||||
|
|
||||||
|
const int64_t level = tensorflow::internal::MinLogLevelFromEnv();
|
||||||
|
|
||||||
|
switch (level) {
|
||||||
|
case INFO:
|
||||||
|
log_level = Aws::Utils::Logging::LogLevel::Info;
|
||||||
|
break;
|
||||||
|
case WARNING:
|
||||||
|
log_level = Aws::Utils::Logging::LogLevel::Warn;
|
||||||
|
break;
|
||||||
|
case ERROR:
|
||||||
|
log_level = Aws::Utils::Logging::LogLevel::Error;
|
||||||
|
break;
|
||||||
|
case FATAL:
|
||||||
|
log_level = Aws::Utils::Logging::LogLevel::Fatal;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
log_level = Aws::Utils::Logging::LogLevel::Info;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return log_level;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool initialized = false;
|
||||||
|
static mutex s3_logging_mutex(LINKER_INITIALIZED);
|
||||||
|
void AWSLogSystem::InitializeAWSLogging() {
|
||||||
|
std::lock_guard<mutex> s3_logging_lock(s3_logging_mutex);
|
||||||
|
if (!initialized) {
|
||||||
|
Aws::Utils::Logging::InitializeAWSLogging(
|
||||||
|
Aws::MakeShared<AWSLogSystem>(kAWSLoggingTag, ParseLogLevelFromEnv()));
|
||||||
|
initialized = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AWSLogSystem::ShutdownAWSLogging() {
|
||||||
|
std::lock_guard<mutex> s3_logging_lock(s3_logging_mutex);
|
||||||
|
if (initialized) {
|
||||||
|
Aws::Utils::Logging::ShutdownAWSLogging();
|
||||||
|
initialized = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
68
tensorflow/core/platform/s3/aws_logging.h
Normal file
68
tensorflow/core/platform/s3/aws_logging.h
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_S3_S3_LOGGING_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_S3_S3_LOGGING_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <aws/core/utils/logging/LogLevel.h>
|
||||||
|
#include <aws/core/utils/logging/LogSystemInterface.h>
|
||||||
|
#include "tensorflow/core/platform/default/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface {
|
||||||
|
public:
|
||||||
|
static void InitializeAWSLogging();
|
||||||
|
static void ShutdownAWSLogging();
|
||||||
|
|
||||||
|
explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level);
|
||||||
|
virtual ~AWSLogSystem() = default;
|
||||||
|
|
||||||
|
// Gets the currently configured log level.
|
||||||
|
virtual Aws::Utils::Logging::LogLevel GetLogLevel(void) const override {
|
||||||
|
return log_level_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a new log level. This has the immediate effect of changing the log.
|
||||||
|
void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) {
|
||||||
|
log_level_.store(log_level);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Does a printf style output to ProcessFormattedStatement. Don't use this,
|
||||||
|
// it's unsafe. See LogStream.
|
||||||
|
// Since non-static C++ methods have an implicit this argument,
|
||||||
|
// TF_PRINTF_ATTRIBUTE should be counted from two (vs. one).
|
||||||
|
virtual void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
|
||||||
|
const char* format, ...) override TF_PRINTF_ATTRIBUTE(4, 5);
|
||||||
|
|
||||||
|
// Writes the stream to ProcessFormattedStatement.
|
||||||
|
virtual void LogStream(Aws::Utils::Logging::LogLevel log_level,
|
||||||
|
const char* tag,
|
||||||
|
const Aws::OStringStream& messageStream) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void LogMessage(Aws::Utils::Logging::LogLevel log_level,
|
||||||
|
const string& message);
|
||||||
|
std::atomic<Aws::Utils::Logging::LogLevel> log_level_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(AWSLogSystem);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_S3_S3_LOGGING_H_
|
||||||
|
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||||
#include "tensorflow/core/platform/s3/s3_file_system.h"
|
#include "tensorflow/core/platform/s3/s3_file_system.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/s3/aws_logging.h"
|
||||||
#include "tensorflow/core/platform/s3/s3_crypto.h"
|
#include "tensorflow/core/platform/s3/s3_crypto.h"
|
||||||
|
|
||||||
#include <aws/core/Aws.h>
|
#include <aws/core/Aws.h>
|
||||||
#include <aws/core/utils/FileSystemUtils.h>
|
#include <aws/core/utils/FileSystemUtils.h>
|
||||||
|
#include <aws/core/utils/logging/AWSLogging.h>
|
||||||
|
#include <aws/core/utils/logging/LogSystemInterface.h>
|
||||||
#include <aws/s3/S3Client.h>
|
#include <aws/s3/S3Client.h>
|
||||||
#include <aws/s3/S3Errors.h>
|
#include <aws/s3/S3Errors.h>
|
||||||
#include <aws/s3/model/CopyObjectRequest.h>
|
#include <aws/s3/model/CopyObjectRequest.h>
|
||||||
|
|
@ -33,6 +36,7 @@ limitations under the License.
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
|
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
|
||||||
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
|
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
|
||||||
static const int kS3GetChildrenMaxKeys = 100;
|
static const int kS3GetChildrenMaxKeys = 100;
|
||||||
|
|
@ -226,7 +230,11 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
|
||||||
uint64 length_;
|
uint64 length_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
S3FileSystem::S3FileSystem() {
|
S3FileSystem::S3FileSystem() {
|
||||||
|
AWSLogSystem::InitializeAWSLogging();
|
||||||
|
|
||||||
Aws::SDKOptions options;
|
Aws::SDKOptions options;
|
||||||
options.cryptoOptions.sha256Factory_create_fn = []() {
|
options.cryptoOptions.sha256Factory_create_fn = []() {
|
||||||
return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
|
return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
|
||||||
|
|
@ -240,6 +248,8 @@ S3FileSystem::S3FileSystem() {
|
||||||
S3FileSystem::~S3FileSystem() {
|
S3FileSystem::~S3FileSystem() {
|
||||||
Aws::SDKOptions options;
|
Aws::SDKOptions options;
|
||||||
Aws::ShutdownAPI(options);
|
Aws::ShutdownAPI(options);
|
||||||
|
|
||||||
|
AWSLogSystem::ShutdownAWSLogging();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status S3FileSystem::NewRandomAccessFile(
|
Status S3FileSystem::NewRandomAccessFile(
|
||||||
|
|
|
||||||
|
|
@ -232,10 +232,10 @@ data and target values for the training set, respectively, and `test_set.data`
|
||||||
and `test_set.target` contain feature data and target values for the test set.
|
and `test_set.target` contain feature data and target values for the test set.
|
||||||
|
|
||||||
Later on, in
|
Later on, in
|
||||||
["Fit the DNNClassifier to the Iris Training Data,"](#fit-dnnclassifier)
|
["Fit the DNNClassifier to the Iris Training Data,"](#fit_the_dnnclassifier_to_the_iris_training_data)
|
||||||
you'll use `training_set.data` and
|
you'll use `training_set.data` and
|
||||||
`training_set.target` to train your model, and in
|
`training_set.target` to train your model, and in
|
||||||
["Evaluate Model Accuracy,"](#evaluate-accuracy) you'll use `test_set.data` and
|
["Evaluate Model Accuracy,"](#evaluate_model_accuracy) you'll use `test_set.data` and
|
||||||
`test_set.target`. But first, you'll construct your model in the next section.
|
`test_set.target`. But first, you'll construct your model in the next section.
|
||||||
|
|
||||||
## Construct a Deep Neural Network Classifier
|
## Construct a Deep Neural Network Classifier
|
||||||
|
|
|
||||||
|
|
@ -177,7 +177,7 @@ dataset = dataset.batch(batch_size=FLAGS.batch_size)
|
||||||
to:
|
to:
|
||||||
|
|
||||||
```
|
```
|
||||||
dataset = dataset.apply(tf.data.contrib.map_and_batch(
|
dataset = dataset.apply(tf.contrib.data.map_and_batch(
|
||||||
map_func=parse_fn, batch_size=FLAGS.batch_size))
|
map_func=parse_fn, batch_size=FLAGS.batch_size))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -300,7 +300,7 @@ functions, methods, and properties. We also adhere to the
|
||||||
[Google Python style guide](https://google.github.io/styleguide/pyguide.html).
|
[Google Python style guide](https://google.github.io/styleguide/pyguide.html).
|
||||||
|
|
||||||
The TensorFlow C++ code base adheres to the
|
The TensorFlow C++ code base adheres to the
|
||||||
[Google C++ style guide](http://google.github.io/styleguide/cppguide.html).
|
[Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||||
|
|
||||||
(<sup>*</sup> With one exception: we use 2-space indentation instead of 4-space
|
(<sup>*</sup> With one exception: we use 2-space indentation instead of 4-space
|
||||||
indentation.)
|
indentation.)
|
||||||
|
|
|
||||||
|
|
@ -487,7 +487,7 @@ subgraph inside.
|
||||||

|

|
||||||
|
|
||||||
For more information about visualizing your TensorFlow application with
|
For more information about visualizing your TensorFlow application with
|
||||||
TensorBoard, see the [TensorBoard tutorial](TODO).
|
TensorBoard, see the [TensorBoard tutorial](../get_started/summaries_and_tensorboard.md).
|
||||||
|
|
||||||
## Programming with multiple graphs
|
## Programming with multiple graphs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -246,7 +246,7 @@ results as in your server testing.
|
||||||
The demo app updates its UI list of results automatically based on the labels
|
The demo app updates its UI list of results automatically based on the labels
|
||||||
text file you copy into assets alongside your frozen graph, which means you can
|
text file you copy into assets alongside your frozen graph, which means you can
|
||||||
easily try out different models without needing to make any code changes. You
|
easily try out different models without needing to make any code changes. You
|
||||||
will need to updaye `LABEL_FILENAME` and `MODEL_FILENAME` to point to the files
|
will need to update `LABEL_FILENAME` and `MODEL_FILENAME` to point to the files
|
||||||
you've added if you change the paths though.
|
you've added if you change the paths though.
|
||||||
|
|
||||||
## How does this Model Work?
|
## How does this Model Work?
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
|
||||||
image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
|
image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
|
||||||
name='jpeg_reader')
|
name='jpeg_reader')
|
||||||
float_caster = tf.cast(image_reader, tf.float32)
|
float_caster = tf.cast(image_reader, tf.float32)
|
||||||
dims_expander = tf.expand_dims(float_caster, 0);
|
dims_expander = tf.expand_dims(float_caster, 0)
|
||||||
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
|
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
|
||||||
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
|
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
|
||||||
sess = tf.Session()
|
sess = tf.Session()
|
||||||
|
|
@ -118,8 +118,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
input_name = "import/" + input_layer
|
input_name = "import/" + input_layer
|
||||||
output_name = "import/" + output_layer
|
output_name = "import/" + output_layer
|
||||||
input_operation = graph.get_operation_by_name(input_name);
|
input_operation = graph.get_operation_by_name(input_name)
|
||||||
output_operation = graph.get_operation_by_name(output_name);
|
output_operation = graph.get_operation_by_name(output_name)
|
||||||
|
|
||||||
with tf.Session(graph=graph) as sess:
|
with tf.Session(graph=graph) as sess:
|
||||||
results = sess.run(output_operation.outputs[0],
|
results = sess.run(output_operation.outputs[0],
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,27 @@ tf_java_op_gen_srcjar(
|
||||||
# file before making it an executable. See tf_java_op_gen_srcjar().
|
# file before making it an executable. See tf_java_op_gen_srcjar().
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "java_op_gen_tool",
|
name = "java_op_gen_tool",
|
||||||
srcs = glob([
|
srcs = [
|
||||||
"src/gen/cc/*.h",
|
"src/gen/cc/op_gen_main.cc",
|
||||||
"src/gen/cc/*.cc",
|
],
|
||||||
]),
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
":java_op_gen_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "java_op_gen_lib",
|
||||||
|
srcs = [
|
||||||
|
"src/gen/cc/op_generator.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"src/gen/cc/java_defs.h",
|
||||||
|
"src/gen/cc/op_generator.h",
|
||||||
|
],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
|
@ -280,21 +297,6 @@ tf_java_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
#java_test(
|
|
||||||
# name = "OperatorProcessorTest",
|
|
||||||
# size = "small",
|
|
||||||
# srcs = ["src/test/java/org/tensorflow/processor/OperatorProcessorTest.java"],
|
|
||||||
# javacopts = JAVACOPTS,
|
|
||||||
# resources = [":processor_test_resources"],
|
|
||||||
# test_class = "org.tensorflow.processor.OperatorProcessorTest",
|
|
||||||
# deps = [
|
|
||||||
# ":processor_library",
|
|
||||||
# "//third_party/java/junit",
|
|
||||||
# "@com_google_testing_compile",
|
|
||||||
# "@com_google_truth",
|
|
||||||
# ],
|
|
||||||
#)
|
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "processor_test_resources",
|
name = "processor_test_resources",
|
||||||
srcs = glob([
|
srcs = glob([
|
||||||
|
|
|
||||||
273
tensorflow/java/src/gen/cc/java_defs.h
Normal file
273
tensorflow/java/src/gen/cc/java_defs.h
Normal file
|
|
@ -0,0 +1,273 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
|
||||||
|
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <deque>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace java {
|
||||||
|
|
||||||
|
// An enumeration of different modifiers commonly used in Java
|
||||||
|
enum Modifier {
|
||||||
|
PUBLIC = (1 << 0),
|
||||||
|
PROTECTED = (1 << 1),
|
||||||
|
PRIVATE = (1 << 2),
|
||||||
|
STATIC = (1 << 3),
|
||||||
|
FINAL = (1 << 4),
|
||||||
|
};
|
||||||
|
|
||||||
|
class Annotation;
|
||||||
|
|
||||||
|
// A definition of any kind of Java type (classes, interfaces...)
|
||||||
|
//
|
||||||
|
// Note that most of the data fields of this class are only useful in specific
|
||||||
|
// contexts and are not required in many cases. For example, annotations and
|
||||||
|
// supertypes are only useful when declaring a type.
|
||||||
|
class Type {
|
||||||
|
public:
|
||||||
|
enum Kind {
|
||||||
|
PRIMITIVE, CLASS, INTERFACE, ENUM, GENERIC, ANNOTATION
|
||||||
|
};
|
||||||
|
static const Type Byte() {
|
||||||
|
return Type(Type::PRIMITIVE, "byte");
|
||||||
|
}
|
||||||
|
static const Type Char() {
|
||||||
|
return Type(Type::PRIMITIVE, "char");
|
||||||
|
}
|
||||||
|
static const Type Short() {
|
||||||
|
return Type(Type::PRIMITIVE, "short");
|
||||||
|
}
|
||||||
|
static const Type Int() {
|
||||||
|
return Type(Type::PRIMITIVE, "int");
|
||||||
|
}
|
||||||
|
static const Type Long() {
|
||||||
|
return Type(Type::PRIMITIVE, "long");
|
||||||
|
}
|
||||||
|
static const Type Float() {
|
||||||
|
return Type(Type::PRIMITIVE, "float");
|
||||||
|
}
|
||||||
|
static const Type Double() {
|
||||||
|
return Type(Type::PRIMITIVE, "double");
|
||||||
|
}
|
||||||
|
static const Type Boolean() {
|
||||||
|
return Type(Type::PRIMITIVE, "boolean");
|
||||||
|
}
|
||||||
|
static const Type Void() {
|
||||||
|
// For simplicity, we consider 'void' as a primitive type, like the Java
|
||||||
|
// Reflection API does
|
||||||
|
return Type(Type::PRIMITIVE, "void");
|
||||||
|
}
|
||||||
|
static Type Class(const string& name, const string& package = "") {
|
||||||
|
return Type(Type::CLASS, name, package);
|
||||||
|
}
|
||||||
|
static Type Interface(const string& name, const string& package = "") {
|
||||||
|
return Type(Type::INTERFACE, name, package);
|
||||||
|
}
|
||||||
|
static Type Enum(const string& name, const string& package = "") {
|
||||||
|
return Type(Type::ENUM, name, package);
|
||||||
|
}
|
||||||
|
static Type Generic(const string& name = "") {
|
||||||
|
return Type(Type::GENERIC, name);
|
||||||
|
}
|
||||||
|
static Type ClassOf(const Type& type) {
|
||||||
|
return Class("Class").add_parameter(type);
|
||||||
|
}
|
||||||
|
static Type ListOf(const Type& type) {
|
||||||
|
return Interface("List", "java.util").add_parameter(type);
|
||||||
|
}
|
||||||
|
static Type IterableOf(const Type& type) {
|
||||||
|
return Interface("Iterable").add_parameter(type);
|
||||||
|
}
|
||||||
|
const Kind& kind() const { return kind_; }
|
||||||
|
const string& name() const { return name_; }
|
||||||
|
const string& package() const { return package_; }
|
||||||
|
const string& description() const { return description_; }
|
||||||
|
Type& description(const string& description) {
|
||||||
|
description_ = description;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const std::vector<Type>& parameters() const { return parameters_; }
|
||||||
|
Type& add_parameter(const Type& parameter) {
|
||||||
|
parameters_.push_back(parameter);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const std::vector<Annotation>& annotations() const { return annotations_; }
|
||||||
|
Type& add_annotation(const Annotation& annotation) {
|
||||||
|
annotations_.push_back(annotation);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const std::deque<Type>& supertypes() const { return supertypes_; }
|
||||||
|
Type& add_supertype(const Type& type) {
|
||||||
|
if (type.kind_ == CLASS) {
|
||||||
|
supertypes_.push_front(type); // keep superclass at the front of the list
|
||||||
|
} else if (type.kind_ == INTERFACE) {
|
||||||
|
supertypes_.push_back(type);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
// Returns true if "type" is of a known collection type (only a few for now)
|
||||||
|
bool IsCollection() const {
|
||||||
|
return name_ == "List" || name_ == "Iterable";
|
||||||
|
}
|
||||||
|
// Returns true if this instance is a wildcard (<?>)
|
||||||
|
bool IsWildcard() const {
|
||||||
|
return kind_ == GENERIC && name_.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Type(Kind kind, const string& name, const string& package = "")
|
||||||
|
: kind_(kind), name_(name), package_(package) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Kind kind_;
|
||||||
|
string name_;
|
||||||
|
string package_;
|
||||||
|
string description_;
|
||||||
|
std::vector<Type> parameters_;
|
||||||
|
std::vector<Annotation> annotations_;
|
||||||
|
std::deque<Type> supertypes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Definition of a Java annotation
|
||||||
|
//
|
||||||
|
// This class only defines the usage of an annotation in a specific context,
|
||||||
|
// giving optionally a set of attributes to initialize.
|
||||||
|
class Annotation : public Type {
|
||||||
|
public:
|
||||||
|
static Annotation Create(const string& type_name, const string& pkg = "") {
|
||||||
|
return Annotation(type_name, pkg);
|
||||||
|
}
|
||||||
|
const string& attributes() const { return attributes_; }
|
||||||
|
Annotation& attributes(const string& attributes) {
|
||||||
|
attributes_ = attributes;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
string attributes_;
|
||||||
|
|
||||||
|
Annotation(const string& name, const string& package)
|
||||||
|
: Type(Kind::ANNOTATION, name, package) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A definition of a Java variable
|
||||||
|
//
|
||||||
|
// This class declares an instance of a type, such as a class field or a
|
||||||
|
// method argument, which can be documented.
|
||||||
|
class Variable {
|
||||||
|
public:
|
||||||
|
static Variable Create(const string& name, const Type& type) {
|
||||||
|
return Variable(name, type, false);
|
||||||
|
}
|
||||||
|
static Variable Varargs(const string& name, const Type& type) {
|
||||||
|
return Variable(name, type, true);
|
||||||
|
}
|
||||||
|
const string& name() const { return name_; }
|
||||||
|
const Type& type() const { return type_; }
|
||||||
|
bool variadic() const { return variadic_; }
|
||||||
|
const string& description() const { return description_; }
|
||||||
|
Variable& description(const string& description) {
|
||||||
|
description_ = description;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
string name_;
|
||||||
|
Type type_;
|
||||||
|
bool variadic_;
|
||||||
|
string description_;
|
||||||
|
|
||||||
|
Variable(const string& name, const Type& type, bool variadic)
|
||||||
|
: name_(name), type_(type), variadic_(variadic) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A definition of a Java class method
|
||||||
|
//
|
||||||
|
// This class defines the signature of a method, including its name, return
|
||||||
|
// type and arguments.
|
||||||
|
class Method {
|
||||||
|
public:
|
||||||
|
static Method Create(const string& name, const Type& return_type) {
|
||||||
|
return Method(name, return_type, false);
|
||||||
|
}
|
||||||
|
static Method ConstructorFor(const Type& clazz) {
|
||||||
|
return Method(clazz.name(), clazz, true);
|
||||||
|
}
|
||||||
|
bool constructor() const { return constructor_; }
|
||||||
|
const string& name() const { return name_; }
|
||||||
|
const Type& return_type() const { return return_type_; }
|
||||||
|
const string& description() const { return description_; }
|
||||||
|
Method& description(const string& description) {
|
||||||
|
description_ = description;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const string& return_description() const { return return_description_; }
|
||||||
|
Method& return_description(const string& description) {
|
||||||
|
return_description_ = description;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const std::vector<Variable>& arguments() const { return arguments_; }
|
||||||
|
Method& add_arguments(const std::vector<Variable>& args) {
|
||||||
|
arguments_.insert(arguments_.cend(), args.cbegin(), args.cend());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Method& add_argument(const Variable& var) {
|
||||||
|
arguments_.push_back(var);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
const std::vector<Annotation>& annotations() const { return annotations_; }
|
||||||
|
Method& add_annotation(const Annotation& annotation) {
|
||||||
|
annotations_.push_back(annotation);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
string name_;
|
||||||
|
Type return_type_;
|
||||||
|
bool constructor_;
|
||||||
|
string description_;
|
||||||
|
string return_description_;
|
||||||
|
std::vector<Variable> arguments_;
|
||||||
|
std::vector<Annotation> annotations_;
|
||||||
|
|
||||||
|
Method(const string& name, const Type& return_type, bool constructor)
|
||||||
|
: name_(name), return_type_(return_type), constructor_(constructor) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A piece of code to read from a file.
|
||||||
|
class Snippet {
|
||||||
|
public:
|
||||||
|
static Snippet Create(const string& fname, Env* env = Env::Default()) {
|
||||||
|
return Snippet(fname, env);
|
||||||
|
}
|
||||||
|
const string& data() const { return data_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
string data_;
|
||||||
|
|
||||||
|
Snippet(const string& fname, Env* env) {
|
||||||
|
TF_CHECK_OK(ReadFileToString(env, fname, &data_));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace java
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
#include "tensorflow/java/src/gen/cc/op_generator.h"
|
#include "tensorflow/java/src/gen/cc/op_generator.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace op_gen {
|
namespace java {
|
||||||
|
|
||||||
const char kUsageHeader[] =
|
const char kUsageHeader[] =
|
||||||
"\n\nGenerator of operation wrappers in Java.\n\n"
|
"\n\nGenerator of operation wrappers in Java.\n\n"
|
||||||
|
|
@ -51,7 +51,7 @@ const char kUsageHeader[] =
|
||||||
"Finally, the '--base_package' overrides the default parent package "
|
"Finally, the '--base_package' overrides the default parent package "
|
||||||
"under which the generated subpackage and classes are to be located.\n\n";
|
"under which the generated subpackage and classes are to be located.\n\n";
|
||||||
|
|
||||||
} // namespace op_gen
|
} // namespace java
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
|
|
@ -67,13 +67,13 @@ int main(int argc, char* argv[]) {
|
||||||
tensorflow::Flag(
|
tensorflow::Flag(
|
||||||
"base_package", &base_package,
|
"base_package", &base_package,
|
||||||
"Package parent to the generated subpackage and classes")};
|
"Package parent to the generated subpackage and classes")};
|
||||||
tensorflow::string usage = tensorflow::op_gen::kUsageHeader;
|
tensorflow::string usage = tensorflow::java::kUsageHeader;
|
||||||
usage += tensorflow::Flags::Usage(argv[0], flag_list);
|
usage += tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||||
QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
|
QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
|
||||||
|
|
||||||
tensorflow::OpGenerator generator;
|
tensorflow::java::OpGenerator generator;
|
||||||
tensorflow::OpList ops;
|
tensorflow::OpList ops;
|
||||||
tensorflow::OpRegistry::Global()->Export(true, &ops);
|
tensorflow::OpRegistry::Global()->Export(true, &ops);
|
||||||
tensorflow::Status status =
|
tensorflow::Status status =
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "tensorflow/java/src/gen/cc/op_generator.h"
|
#include "tensorflow/java/src/gen/cc/op_generator.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace java {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
string CamelCase(const string& str, char delimiter, bool upper) {
|
string CamelCase(const string& str, char delimiter, bool upper) {
|
||||||
|
|
@ -63,4 +64,5 @@ Status OpGenerator::Run(const OpList& ops, const string& lib_name,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace java
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace java {
|
||||||
|
|
||||||
/// \brief A generator of Java operation wrappers.
|
/// \brief A generator of Java operation wrappers.
|
||||||
///
|
///
|
||||||
|
|
@ -46,6 +47,7 @@ class OpGenerator {
|
||||||
Env* env;
|
Env* env;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace java
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
|
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ def tf_java_op_gen_srcjar(name,
|
||||||
|
|
||||||
# Generate a source archive containing generated code for these ops.
|
# Generate a source archive containing generated code for these ops.
|
||||||
gen_srcjar = out_dir + name + ".srcjar"
|
gen_srcjar = out_dir + name + ".srcjar"
|
||||||
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
|
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
|
||||||
gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
|
gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
|
||||||
gen_tools += tf_binary_additional_srcs()
|
gen_tools += tf_binary_additional_srcs()
|
||||||
native.genrule(
|
native.genrule(
|
||||||
|
|
|
||||||
|
|
@ -803,7 +803,7 @@ class Dataset(object):
|
||||||
```python
|
```python
|
||||||
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
|
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
|
||||||
# each file.
|
# each file.
|
||||||
filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ..."]
|
filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
|
||||||
dataset = (Dataset.from_tensor_slices(filenames)
|
dataset = (Dataset.from_tensor_slices(filenames)
|
||||||
.interleave(lambda x:
|
.interleave(lambda x:
|
||||||
TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
|
TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,9 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
# Filter out LOG(INFO)
|
||||||
|
export TF_CPP_MIN_LOG_LEVEL=1
|
||||||
|
|
||||||
IS_VIRTUALENV=0
|
IS_VIRTUALENV=0
|
||||||
PYTHON_BIN_PATH=""
|
PYTHON_BIN_PATH=""
|
||||||
while true; do
|
while true; do
|
||||||
|
|
|
||||||
|
|
@ -688,9 +688,10 @@ class Estimator(object):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if input_fn takes invalid arguments.
|
ValueError: if input_fn takes invalid arguments.
|
||||||
"""
|
"""
|
||||||
del mode # unused
|
|
||||||
input_fn_args = util.fn_args(input_fn)
|
input_fn_args = util.fn_args(input_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
if 'mode' in input_fn_args:
|
||||||
|
kwargs['mode'] = mode
|
||||||
if 'params' in input_fn_args:
|
if 'params' in input_fn_args:
|
||||||
kwargs['params'] = self.params
|
kwargs['params'] = self.params
|
||||||
if 'config' in input_fn_args:
|
if 'config' in input_fn_args:
|
||||||
|
|
|
||||||
|
|
@ -420,6 +420,7 @@ class EstimatorTrainTest(test.TestCase):
|
||||||
self.assertEqual(1, model_fn_call_count[0])
|
self.assertEqual(1, model_fn_call_count[0])
|
||||||
|
|
||||||
def test_callable_input_fn(self):
|
def test_callable_input_fn(self):
|
||||||
|
expected_mode = model_fn_lib.ModeKeys.TRAIN
|
||||||
expected_params = {'batch_size': 10}
|
expected_params = {'batch_size': 10}
|
||||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||||
input_fn_call_count = [0]
|
input_fn_call_count = [0]
|
||||||
|
|
@ -432,8 +433,9 @@ class EstimatorTrainTest(test.TestCase):
|
||||||
|
|
||||||
class InputFn(object):
|
class InputFn(object):
|
||||||
|
|
||||||
def __call__(self, params, config):
|
def __call__(self, mode, params, config):
|
||||||
input_fn_call_count[0] += 1
|
input_fn_call_count[0] += 1
|
||||||
|
test_self.assertEqual(expected_mode, mode)
|
||||||
test_self.assertEqual(expected_params, params)
|
test_self.assertEqual(expected_params, params)
|
||||||
test_self.assertEqual(4321, config.tf_random_seed)
|
test_self.assertEqual(4321, config.tf_random_seed)
|
||||||
return dummy_input_fn()
|
return dummy_input_fn()
|
||||||
|
|
@ -446,6 +448,7 @@ class EstimatorTrainTest(test.TestCase):
|
||||||
self.assertEqual(1, input_fn_call_count[0])
|
self.assertEqual(1, input_fn_call_count[0])
|
||||||
|
|
||||||
def test_input_fn_args(self):
|
def test_input_fn_args(self):
|
||||||
|
expected_mode = model_fn_lib.ModeKeys.TRAIN
|
||||||
expected_params = {'batch_size': 10}
|
expected_params = {'batch_size': 10}
|
||||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||||
input_fn_call_count = [0]
|
input_fn_call_count = [0]
|
||||||
|
|
@ -454,8 +457,9 @@ class EstimatorTrainTest(test.TestCase):
|
||||||
del params, config
|
del params, config
|
||||||
return model_fn_global_step_incrementer(features, labels, mode)
|
return model_fn_global_step_incrementer(features, labels, mode)
|
||||||
|
|
||||||
def _input_fn(params, config):
|
def _input_fn(mode, params, config):
|
||||||
input_fn_call_count[0] += 1
|
input_fn_call_count[0] += 1
|
||||||
|
self.assertEqual(expected_mode, mode)
|
||||||
self.assertEqual(expected_params, params)
|
self.assertEqual(expected_params, params)
|
||||||
self.assertEqual(4321, config.tf_random_seed)
|
self.assertEqual(4321, config.tf_random_seed)
|
||||||
return dummy_input_fn()
|
return dummy_input_fn()
|
||||||
|
|
@ -992,6 +996,7 @@ class EstimatorDatasetIntegrationTest(test.TestCase):
|
||||||
class EstimatorEvaluateTest(test.TestCase):
|
class EstimatorEvaluateTest(test.TestCase):
|
||||||
|
|
||||||
def test_input_fn_args(self):
|
def test_input_fn_args(self):
|
||||||
|
expected_mode = model_fn_lib.ModeKeys.EVAL
|
||||||
expected_params = {'batch_size': 10}
|
expected_params = {'batch_size': 10}
|
||||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||||
input_fn_call_count = [0]
|
input_fn_call_count = [0]
|
||||||
|
|
@ -1000,8 +1005,9 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||||
del params, config
|
del params, config
|
||||||
return model_fn_global_step_incrementer(features, labels, mode)
|
return model_fn_global_step_incrementer(features, labels, mode)
|
||||||
|
|
||||||
def _input_fn(params, config):
|
def _input_fn(mode, params, config):
|
||||||
input_fn_call_count[0] += 1
|
input_fn_call_count[0] += 1
|
||||||
|
self.assertEqual(expected_mode, mode)
|
||||||
self.assertEqual(expected_params, params)
|
self.assertEqual(expected_params, params)
|
||||||
self.assertEqual(4321, config.tf_random_seed)
|
self.assertEqual(4321, config.tf_random_seed)
|
||||||
return dummy_input_fn()
|
return dummy_input_fn()
|
||||||
|
|
@ -1265,6 +1271,7 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||||
class EstimatorPredictTest(test.TestCase):
|
class EstimatorPredictTest(test.TestCase):
|
||||||
|
|
||||||
def test_input_fn_args(self):
|
def test_input_fn_args(self):
|
||||||
|
expected_mode = model_fn_lib.ModeKeys.PREDICT
|
||||||
expected_params = {'batch_size': 10}
|
expected_params = {'batch_size': 10}
|
||||||
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
|
||||||
input_fn_call_count = [0]
|
input_fn_call_count = [0]
|
||||||
|
|
@ -1277,8 +1284,9 @@ class EstimatorPredictTest(test.TestCase):
|
||||||
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
train_op=state_ops.assign_add(training.get_global_step(), 1),
|
||||||
predictions=constant_op.constant([[10.]]))
|
predictions=constant_op.constant([[10.]]))
|
||||||
|
|
||||||
def _input_fn(params, config):
|
def _input_fn(mode, params, config):
|
||||||
input_fn_call_count[0] += 1
|
input_fn_call_count[0] += 1
|
||||||
|
self.assertEqual(expected_mode, mode)
|
||||||
self.assertEqual(expected_params, params)
|
self.assertEqual(expected_params, params)
|
||||||
self.assertEqual(4321, config.tf_random_seed)
|
self.assertEqual(4321, config.tf_random_seed)
|
||||||
return dummy_input_fn()
|
return dummy_input_fn()
|
||||||
|
|
|
||||||
|
|
@ -677,11 +677,19 @@ class _TrainingExecutor(object):
|
||||||
'RunConfig or set the TF_CONFIG environment variable.')
|
'RunConfig or set the TF_CONFIG environment variable.')
|
||||||
|
|
||||||
logging.info('Start Tensorflow server.')
|
logging.info('Start Tensorflow server.')
|
||||||
|
|
||||||
|
if config.session_config is None:
|
||||||
|
session_config=config_pb2.ConfigProto(log_device_placement=False)
|
||||||
|
else:
|
||||||
|
session_config=config_pb2.ConfigProto(
|
||||||
|
log_device_placement=False,
|
||||||
|
gpu_options=config.session_config.gpu_options)
|
||||||
|
|
||||||
server = server_lib.Server(
|
server = server_lib.Server(
|
||||||
config.cluster_spec,
|
config.cluster_spec,
|
||||||
job_name=config.task_type,
|
job_name=config.task_type,
|
||||||
task_index=config.task_id,
|
task_index=config.task_id,
|
||||||
config=config_pb2.ConfigProto(log_device_placement=False),
|
config=session_config,
|
||||||
start=False)
|
start=False)
|
||||||
server.start()
|
server.start()
|
||||||
return server
|
return server
|
||||||
|
|
|
||||||
|
|
@ -796,6 +796,20 @@ class Conv2DTest(test.TestCase):
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
use_gpu=use_gpu)
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testConv2DBackpropFilterWithEmptyInput(self):
|
||||||
|
expected = [0, 0, 0, 0]
|
||||||
|
for (data_format, use_gpu) in GetTestConfigs():
|
||||||
|
self._RunAndVerifyBackpropFilter(
|
||||||
|
input_sizes=[0, 2, 3, 1],
|
||||||
|
filter_sizes=[2, 2, 1, 1],
|
||||||
|
output_sizes=[0, 1, 2, 1],
|
||||||
|
strides=[1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
expected=expected,
|
||||||
|
data_format=data_format,
|
||||||
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testConv2D2x2Depth3ValidBackpropFilter(self):
|
def testConv2D2x2Depth3ValidBackpropFilter(self):
|
||||||
expected = [
|
expected = [
|
||||||
|
|
|
||||||
|
|
@ -361,6 +361,16 @@ class PoolingTest(test.TestCase):
|
||||||
expected=expected_output,
|
expected=expected_output,
|
||||||
use_gpu=use_gpu)
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
|
def _testAvgPoolEmptyInput(self, use_gpu):
|
||||||
|
self._VerifyValues(
|
||||||
|
nn_ops.avg_pool,
|
||||||
|
input_sizes=[0, 8, 8, 8],
|
||||||
|
ksize=[1, 3, 3, 1],
|
||||||
|
strides=[1, 2, 2, 1],
|
||||||
|
padding="SAME",
|
||||||
|
expected=[],
|
||||||
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
def testAvgPooling(self):
|
def testAvgPooling(self):
|
||||||
for use_gpu in True, False:
|
for use_gpu in True, False:
|
||||||
self._testAvgPoolValidPadding(use_gpu)
|
self._testAvgPoolValidPadding(use_gpu)
|
||||||
|
|
@ -371,6 +381,7 @@ class PoolingTest(test.TestCase):
|
||||||
self._testAvgPoolSamePadding4(use_gpu)
|
self._testAvgPoolSamePadding4(use_gpu)
|
||||||
self._testAvgPoolSamePaddingPacket4(use_gpu)
|
self._testAvgPoolSamePaddingPacket4(use_gpu)
|
||||||
self._testAvgPoolSamePaddingPacket8(use_gpu)
|
self._testAvgPoolSamePaddingPacket8(use_gpu)
|
||||||
|
self._testAvgPoolEmptyInput(use_gpu)
|
||||||
|
|
||||||
def _testMaxPoolValidPadding(self, use_gpu):
|
def _testMaxPoolValidPadding(self, use_gpu):
|
||||||
expected_output = [13.0, 14.0, 15.0]
|
expected_output = [13.0, 14.0, 15.0]
|
||||||
|
|
@ -543,6 +554,16 @@ class PoolingTest(test.TestCase):
|
||||||
use_gpu=use_gpu,
|
use_gpu=use_gpu,
|
||||||
v2=v2)
|
v2=v2)
|
||||||
|
|
||||||
|
def _testMaxPoolEmptyInput(self, use_gpu):
|
||||||
|
self._VerifyValues(
|
||||||
|
gen_nn_ops._max_pool_v2,
|
||||||
|
input_sizes=[0, 8, 8, 8],
|
||||||
|
ksize=[1, 3, 3, 1],
|
||||||
|
strides=[1, 2, 2, 1],
|
||||||
|
padding="SAME",
|
||||||
|
expected=[],
|
||||||
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
def testMaxPooling(self):
|
def testMaxPooling(self):
|
||||||
for use_gpu in True, False:
|
for use_gpu in True, False:
|
||||||
self._testMaxPoolValidPadding(use_gpu)
|
self._testMaxPoolValidPadding(use_gpu)
|
||||||
|
|
@ -551,6 +572,7 @@ class PoolingTest(test.TestCase):
|
||||||
self._testMaxPoolValidPaddingUnevenStride(use_gpu)
|
self._testMaxPoolValidPaddingUnevenStride(use_gpu)
|
||||||
self._testMaxPoolSamePaddingPacket4(use_gpu)
|
self._testMaxPoolSamePaddingPacket4(use_gpu)
|
||||||
self._testMaxPoolSamePaddingPacket8(use_gpu)
|
self._testMaxPoolSamePaddingPacket8(use_gpu)
|
||||||
|
self._testMaxPoolEmptyInput(use_gpu)
|
||||||
|
|
||||||
# Tests for DepthwiseMaxPooling on CPU only.
|
# Tests for DepthwiseMaxPooling on CPU only.
|
||||||
def testDepthwiseMaxPool1x1DepthWindow1(self):
|
def testDepthwiseMaxPool1x1DepthWindow1(self):
|
||||||
|
|
|
||||||
|
|
@ -135,5 +135,36 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
|
||||||
bitwise_ops.right_shift(lhs, rhs)])
|
bitwise_ops.right_shift(lhs, rhs)])
|
||||||
|
|
||||||
|
|
||||||
|
def testShapeInference(self):
|
||||||
|
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||||
|
dtypes.uint8, dtypes.uint16]
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
for dtype in dtype_list:
|
||||||
|
lhs = constant_op.constant([[0], [3], [5]], dtype=dtype)
|
||||||
|
rhs = constant_op.constant([[1, 2, 4]], dtype=dtype)
|
||||||
|
|
||||||
|
and_tensor = bitwise_ops.bitwise_and(lhs, rhs)
|
||||||
|
or_tensor = bitwise_ops.bitwise_or(lhs, rhs)
|
||||||
|
xor_tensor = bitwise_ops.bitwise_xor(lhs, rhs)
|
||||||
|
ls_tensor = bitwise_ops.left_shift(lhs, rhs)
|
||||||
|
rs_tensor = bitwise_ops.right_shift(lhs, rhs)
|
||||||
|
|
||||||
|
and_result, or_result, xor_result, ls_result, rs_result = sess.run(
|
||||||
|
[and_tensor, or_tensor, xor_tensor, ls_tensor, rs_tensor])
|
||||||
|
|
||||||
|
# Compare shape inference with result
|
||||||
|
self.assertAllEqual(and_tensor.get_shape().as_list(), and_result.shape)
|
||||||
|
self.assertAllEqual(and_tensor.get_shape().as_list(), [3, 3])
|
||||||
|
self.assertAllEqual(or_tensor.get_shape().as_list(), or_result.shape)
|
||||||
|
self.assertAllEqual(or_tensor.get_shape().as_list(), [3, 3])
|
||||||
|
self.assertAllEqual(xor_tensor.get_shape().as_list(), xor_result.shape)
|
||||||
|
self.assertAllEqual(xor_tensor.get_shape().as_list(), [3, 3])
|
||||||
|
self.assertAllEqual(ls_tensor.get_shape().as_list(), ls_result.shape)
|
||||||
|
self.assertAllEqual(ls_tensor.get_shape().as_list(), [3, 3])
|
||||||
|
self.assertAllEqual(rs_tensor.get_shape().as_list(), rs_result.shape)
|
||||||
|
self.assertAllEqual(rs_tensor.get_shape().as_list(), [3, 3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
|
|
||||||
|
|
@ -977,9 +977,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
|
||||||
|
|
||||||
`hessians()` adds ops to the graph to output the Hessian matrix of `ys`
|
`hessians()` adds ops to the graph to output the Hessian matrix of `ys`
|
||||||
with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
|
with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
|
||||||
where each tensor is the Hessian of `sum(ys)`. This function currently
|
where each tensor is the Hessian of `sum(ys)`.
|
||||||
only supports evaluating the Hessian with respect to (a list of) one-
|
|
||||||
dimensional tensors.
|
|
||||||
|
|
||||||
The Hessian is a matrix of second-order partial derivatives of a scalar
|
The Hessian is a matrix of second-order partial derivatives of a scalar
|
||||||
tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).
|
tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).
|
||||||
|
|
@ -1009,13 +1007,10 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
|
||||||
# Compute first-order derivatives and iterate for each x in xs.
|
# Compute first-order derivatives and iterate for each x in xs.
|
||||||
hessians = []
|
hessians = []
|
||||||
_gradients = gradients(ys, xs, **kwargs)
|
_gradients = gradients(ys, xs, **kwargs)
|
||||||
for i, _gradient, x in zip(range(len(xs)), _gradients, xs):
|
for gradient, x in zip(_gradients, xs):
|
||||||
# Ensure that x is a vector.
|
# change shape to one-dimension without graph branching
|
||||||
check_rank = check_ops.assert_rank(
|
gradient = array_ops.reshape(gradient, [-1])
|
||||||
x, 1, message='Cannot compute Hessian because element %d of `xs` does '
|
|
||||||
'not have rank one.' % i
|
|
||||||
)
|
|
||||||
with ops.control_dependencies([check_rank]):
|
|
||||||
# Declare an iterator and tensor array loop variables for the gradients.
|
# Declare an iterator and tensor array loop variables for the gradients.
|
||||||
n = array_ops.size(x)
|
n = array_ops.size(x)
|
||||||
loop_vars = [
|
loop_vars = [
|
||||||
|
|
@ -1027,9 +1022,13 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
|
||||||
_, hessian = control_flow_ops.while_loop(
|
_, hessian = control_flow_ops.while_loop(
|
||||||
lambda j, _: j < n,
|
lambda j, _: j < n,
|
||||||
lambda j, result: (j + 1,
|
lambda j, result: (j + 1,
|
||||||
result.write(j, gradients(_gradient[j], x)[0])),
|
result.write(j, gradients(gradient[j], x)[0])),
|
||||||
loop_vars
|
loop_vars
|
||||||
)
|
)
|
||||||
|
|
||||||
hessians.append(hessian.stack())
|
_shape = array_ops.shape(x)
|
||||||
|
_reshaped_hessian = array_ops.reshape(
|
||||||
|
hessian.stack(), array_ops.concat((_shape, _shape), 0)
|
||||||
|
)
|
||||||
|
hessians.append(_reshaped_hessian)
|
||||||
return hessians
|
return hessians
|
||||||
|
|
|
||||||
|
|
@ -621,6 +621,45 @@ class HessianTest(test_util.TensorFlowTestCase):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
gradients.hessians(x, x)
|
gradients.hessians(x, x)
|
||||||
|
|
||||||
|
def testHessian2D_square_matrix(self):
|
||||||
|
# Manually compute the Hessian explicitly for a low-dimensional problem
|
||||||
|
# and check that `hessian` matches. Specifically, the Hessian of
|
||||||
|
# f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
|
||||||
|
m = 3
|
||||||
|
rng = np.random.RandomState([1, 2, 3])
|
||||||
|
x_value = rng.randn(m, m).astype("float32")
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
x = constant_op.constant(x_value)
|
||||||
|
x_square = math_ops.reduce_sum(
|
||||||
|
math_ops.matmul(array_ops.transpose(x), x) * 0.5
|
||||||
|
)
|
||||||
|
hess = gradients.hessians(x_square, x)[0]
|
||||||
|
hess_actual = hess.eval()
|
||||||
|
hess_value = np.bmat([
|
||||||
|
[elem*np.ones((m, m)) for elem in vec]
|
||||||
|
for vec in np.eye(m)
|
||||||
|
]).astype("float32")
|
||||||
|
self.assertAllEqual((m, m, m, m), hess_actual.shape)
|
||||||
|
self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))
|
||||||
|
|
||||||
|
def testHessian2D_non_square_matrix(self):
|
||||||
|
m = 3
|
||||||
|
n = 4
|
||||||
|
rng = np.random.RandomState([1, 2, 3])
|
||||||
|
x_value = rng.randn(m, n).astype("float32")
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
x = constant_op.constant(x_value)
|
||||||
|
x_square = math_ops.reduce_sum(
|
||||||
|
math_ops.matmul(array_ops.transpose(x), x) * 0.5
|
||||||
|
)
|
||||||
|
hess = gradients.hessians(x_square, x)[0]
|
||||||
|
hess_actual = hess.eval()
|
||||||
|
hess_value = np.bmat([
|
||||||
|
[elem*np.ones((n, n)) for elem in vec]
|
||||||
|
for vec in np.eye(m)
|
||||||
|
]).astype("float32")
|
||||||
|
self.assertAllEqual((m, n, m, n), hess_actual.shape)
|
||||||
|
self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
|
||||||
|
|
||||||
@test_util.with_c_api
|
@test_util.with_c_api
|
||||||
class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
|
class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
|
||||||
|
|
@ -219,6 +219,7 @@ def random_flip_up_down(image, seed=None):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `image` not supported.
|
ValueError: if the shape of `image` not supported.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'random_flip_up_down', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
|
|
@ -226,7 +227,8 @@ def random_flip_up_down(image, seed=None):
|
||||||
mirror_cond = math_ops.less(uniform_random, .5)
|
mirror_cond = math_ops.less(uniform_random, .5)
|
||||||
result = control_flow_ops.cond(mirror_cond,
|
result = control_flow_ops.cond(mirror_cond,
|
||||||
lambda: array_ops.reverse(image, [0]),
|
lambda: array_ops.reverse(image, [0]),
|
||||||
lambda: image)
|
lambda: image,
|
||||||
|
name=scope)
|
||||||
return fix_image_flip_shape(image, result)
|
return fix_image_flip_shape(image, result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -248,6 +250,7 @@ def random_flip_left_right(image, seed=None):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `image` not supported.
|
ValueError: if the shape of `image` not supported.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'random_flip_left_right', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
|
|
@ -255,7 +258,10 @@ def random_flip_left_right(image, seed=None):
|
||||||
mirror_cond = math_ops.less(uniform_random, .5)
|
mirror_cond = math_ops.less(uniform_random, .5)
|
||||||
result = control_flow_ops.cond(mirror_cond,
|
result = control_flow_ops.cond(mirror_cond,
|
||||||
lambda: array_ops.reverse(image, [1]),
|
lambda: array_ops.reverse(image, [1]),
|
||||||
lambda: image)
|
lambda: image,
|
||||||
|
name=scope)
|
||||||
|
print('scope: ' + scope)
|
||||||
|
print('result name: ' + result.name)
|
||||||
return fix_image_flip_shape(image, result)
|
return fix_image_flip_shape(image, result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -276,10 +282,12 @@ def flip_left_right(image):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `image` not supported.
|
ValueError: if the shape of `image` not supported.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'flip_left_right', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
return fix_image_flip_shape(image, array_ops.reverse(image, [1]))
|
return fix_image_flip_shape(image,
|
||||||
|
array_ops.reverse(image, [1], name=scope))
|
||||||
|
|
||||||
|
|
||||||
def flip_up_down(image):
|
def flip_up_down(image):
|
||||||
|
|
@ -299,10 +307,12 @@ def flip_up_down(image):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `image` not supported.
|
ValueError: if the shape of `image` not supported.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'flip_up_down', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
|
return fix_image_flip_shape(image,
|
||||||
|
array_ops.reverse(image, [0], name=scope))
|
||||||
|
|
||||||
|
|
||||||
def rot90(image, k=1, name=None):
|
def rot90(image, k=1, name=None):
|
||||||
|
|
@ -356,10 +366,11 @@ def transpose_image(image):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `image` not supported.
|
ValueError: if the shape of `image` not supported.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'transpose_image', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
|
return array_ops.transpose(image, [1, 0, 2], name=scope)
|
||||||
|
|
||||||
|
|
||||||
def central_crop(image, central_fraction):
|
def central_crop(image, central_fraction):
|
||||||
|
|
@ -386,6 +397,7 @@ def central_crop(image, central_fraction):
|
||||||
Returns:
|
Returns:
|
||||||
3-D float Tensor
|
3-D float Tensor
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'central_crop', [image]):
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
if central_fraction <= 0.0 or central_fraction > 1.0:
|
if central_fraction <= 0.0 or central_fraction > 1.0:
|
||||||
raise ValueError('central_fraction must be within (0, 1]')
|
raise ValueError('central_fraction must be within (0, 1]')
|
||||||
|
|
@ -444,6 +456,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
|
||||||
`target_*` arguments, or either `offset_height` or `offset_width` is
|
`target_*` arguments, or either `offset_height` or `offset_width` is
|
||||||
negative.
|
negative.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'pad_to_bounding_box', [image]):
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
|
|
||||||
is_batch = True
|
is_batch = True
|
||||||
|
|
@ -459,10 +472,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
|
||||||
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
|
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
|
||||||
|
|
||||||
assert_ops = _CheckAtLeast3DImage(image, require_static=False)
|
assert_ops = _CheckAtLeast3DImage(image, require_static=False)
|
||||||
|
|
||||||
batch, height, width, depth = _ImageDimensions(image, rank=4)
|
batch, height, width, depth = _ImageDimensions(image, rank=4)
|
||||||
|
|
||||||
after_padding_width = target_width - offset_width - width
|
after_padding_width = target_width - offset_width - width
|
||||||
|
|
||||||
after_padding_height = target_height - offset_height - height
|
after_padding_height = target_height - offset_height - height
|
||||||
|
|
||||||
assert_ops += _assert(offset_height >= 0, ValueError,
|
assert_ops += _assert(offset_height >= 0, ValueError,
|
||||||
|
|
@ -523,6 +536,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
|
||||||
`target_*` arguments, or either `offset_height` or `offset_width` is
|
`target_*` arguments, or either `offset_height` or `offset_width` is
|
||||||
negative, or either `target_height` or `target_width` is not positive.
|
negative, or either `target_height` or `target_width` is not positive.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'crop_to_bounding_box', [image]):
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
|
|
||||||
is_batch = True
|
is_batch = True
|
||||||
|
|
@ -598,6 +612,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
|
||||||
If `images` was 3-D, a 3-D float Tensor of shape
|
If `images` was 3-D, a 3-D float Tensor of shape
|
||||||
`[new_height, new_width, channels]`.
|
`[new_height, new_width, channels]`.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'resize_image_with_crop_or_pad', [image]):
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image_shape = image.get_shape()
|
image_shape = image.get_shape()
|
||||||
is_batch = True
|
is_batch = True
|
||||||
|
|
@ -624,7 +639,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
|
||||||
target_height = control_flow_ops.with_dependencies(
|
target_height = control_flow_ops.with_dependencies(
|
||||||
assert_ops, target_height)
|
assert_ops, target_height)
|
||||||
if _is_tensor(target_width):
|
if _is_tensor(target_width):
|
||||||
target_width = control_flow_ops.with_dependencies(assert_ops, target_width)
|
target_width = control_flow_ops.with_dependencies(
|
||||||
|
assert_ops, target_width)
|
||||||
|
|
||||||
def max_(x, y):
|
def max_(x, y):
|
||||||
if _is_tensor(x) or _is_tensor(y):
|
if _is_tensor(x) or _is_tensor(y):
|
||||||
|
|
@ -736,6 +752,7 @@ def resize_images(images,
|
||||||
If `images` was 3-D, a 3-D float Tensor of shape
|
If `images` was 3-D, a 3-D float Tensor of shape
|
||||||
`[new_height, new_width, channels]`.
|
`[new_height, new_width, channels]`.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'resize_images', [images, size]):
|
||||||
images = ops.convert_to_tensor(images, name='images')
|
images = ops.convert_to_tensor(images, name='images')
|
||||||
if images.get_shape().ndims is None:
|
if images.get_shape().ndims is None:
|
||||||
raise ValueError('\'images\' contains no shape.')
|
raise ValueError('\'images\' contains no shape.')
|
||||||
|
|
@ -776,7 +793,8 @@ def resize_images(images,
|
||||||
elif method == ResizeMethod.NEAREST_NEIGHBOR:
|
elif method == ResizeMethod.NEAREST_NEIGHBOR:
|
||||||
images = gen_image_ops.resize_nearest_neighbor(images,
|
images = gen_image_ops.resize_nearest_neighbor(images,
|
||||||
size,
|
size,
|
||||||
align_corners=align_corners)
|
align_corners=
|
||||||
|
align_corners)
|
||||||
elif method == ResizeMethod.BICUBIC:
|
elif method == ResizeMethod.BICUBIC:
|
||||||
images = gen_image_ops.resize_bicubic(images,
|
images = gen_image_ops.resize_bicubic(images,
|
||||||
size,
|
size,
|
||||||
|
|
@ -816,6 +834,7 @@ def per_image_standardization(image):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of 'image' is incompatible with this function.
|
ValueError: if the shape of 'image' is incompatible with this function.
|
||||||
"""
|
"""
|
||||||
|
with ops.name_scope(None, 'per_image_standardization', [image]) as scope:
|
||||||
image = ops.convert_to_tensor(image, name='image')
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
image = control_flow_ops.with_dependencies(
|
image = control_flow_ops.with_dependencies(
|
||||||
_Check3DImage(image, require_static=False), image)
|
_Check3DImage(image, require_static=False), image)
|
||||||
|
|
@ -835,7 +854,7 @@ def per_image_standardization(image):
|
||||||
pixel_value_offset = image_mean
|
pixel_value_offset = image_mean
|
||||||
|
|
||||||
image = math_ops.subtract(image, pixel_value_offset)
|
image = math_ops.subtract(image, pixel_value_offset)
|
||||||
image = math_ops.div(image, pixel_value_scale)
|
image = math_ops.div(image, pixel_value_scale, name=scope)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -857,6 +857,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||||
y = image_ops.flip_left_right(x_tf)
|
y = image_ops.flip_left_right(x_tf)
|
||||||
|
self.assertTrue(y.op.name.startswith('flip_left_right'))
|
||||||
y_tf = y.eval()
|
y_tf = y.eval()
|
||||||
self.assertAllEqual(y_tf, y_np)
|
self.assertAllEqual(y_tf, y_np)
|
||||||
|
|
||||||
|
|
@ -867,6 +868,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||||
y = image_ops.random_flip_left_right(x_tf)
|
y = image_ops.random_flip_left_right(x_tf)
|
||||||
|
self.assertTrue(y.op.name.startswith('random_flip_left_right'))
|
||||||
|
|
||||||
count_flipped = 0
|
count_flipped = 0
|
||||||
count_unflipped = 0
|
count_unflipped = 0
|
||||||
|
|
@ -897,6 +899,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||||
y = image_ops.flip_up_down(x_tf)
|
y = image_ops.flip_up_down(x_tf)
|
||||||
|
self.assertTrue(y.op.name.startswith('flip_up_down'))
|
||||||
y_tf = y.eval()
|
y_tf = y.eval()
|
||||||
self.assertAllEqual(y_tf, y_np)
|
self.assertAllEqual(y_tf, y_np)
|
||||||
|
|
||||||
|
|
@ -907,6 +910,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||||
y = image_ops.random_flip_up_down(x_tf)
|
y = image_ops.random_flip_up_down(x_tf)
|
||||||
|
self.assertTrue(y.op.name.startswith('random_flip_up_down'))
|
||||||
count_flipped = 0
|
count_flipped = 0
|
||||||
count_unflipped = 0
|
count_unflipped = 0
|
||||||
for _ in range(50):
|
for _ in range(50):
|
||||||
|
|
@ -936,6 +940,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||||
y = image_ops.transpose_image(x_tf)
|
y = image_ops.transpose_image(x_tf)
|
||||||
|
self.assertTrue(y.op.name.startswith('transpose_image'))
|
||||||
y_tf = y.eval()
|
y_tf = y.eval()
|
||||||
self.assertAllEqual(y_tf, y_np)
|
self.assertAllEqual(y_tf, y_np)
|
||||||
|
|
||||||
|
|
@ -1160,6 +1165,7 @@ class PerImageWhiteningTest(test_util.TensorFlowTestCase):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
x = constant_op.constant(x_np, shape=x_shape)
|
x = constant_op.constant(x_np, shape=x_shape)
|
||||||
y = image_ops.per_image_standardization(x)
|
y = image_ops.per_image_standardization(x)
|
||||||
|
self.assertTrue(y.op.name.startswith('per_image_standardization'))
|
||||||
y_tf = y.eval()
|
y_tf = y.eval()
|
||||||
self.assertAllClose(y_tf, y_np, atol=1e-4)
|
self.assertAllClose(y_tf, y_np, atol=1e-4)
|
||||||
|
|
||||||
|
|
@ -1341,6 +1347,11 @@ class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||||
for params, err_msg in test_config:
|
for params, err_msg in test_config:
|
||||||
self._assertRaises(x, x_shape, *params, err_msg=err_msg)
|
self._assertRaises(x, x_shape, *params, err_msg=err_msg)
|
||||||
|
|
||||||
|
def testNameScope(self):
|
||||||
|
image = array_ops.placeholder(dtypes.float32, shape=[55, 66, 3])
|
||||||
|
y = image_ops.crop_to_bounding_box(image, 0, 0, 55, 66)
|
||||||
|
self.assertTrue(y.name.startswith('crop_to_bounding_box'))
|
||||||
|
|
||||||
|
|
||||||
class CentralCropTest(test_util.TensorFlowTestCase):
|
class CentralCropTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
@ -1417,6 +1428,13 @@ class CentralCropTest(test_util.TensorFlowTestCase):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = image_ops.central_crop(x, 1.01)
|
_ = image_ops.central_crop(x, 1.01)
|
||||||
|
|
||||||
|
def testNameScope(self):
|
||||||
|
x_shape = [13, 9, 3]
|
||||||
|
x_np = np.ones(x_shape, dtype=np.float32)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
y = image_ops.central_crop(x_np, 1.0)
|
||||||
|
self.assertTrue(y.op.name.startswith('central_crop'))
|
||||||
|
|
||||||
|
|
||||||
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
@ -1620,6 +1638,11 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
|
||||||
for config_item in test_config:
|
for config_item in test_config:
|
||||||
self._assertRaises(x, x_shape, *config_item)
|
self._assertRaises(x, x_shape, *config_item)
|
||||||
|
|
||||||
|
def testNameScope(self):
|
||||||
|
image = array_ops.placeholder(dtypes.float32, shape=[55, 66, 3])
|
||||||
|
y = image_ops.pad_to_bounding_box(image, 0, 0, 55, 66)
|
||||||
|
self.assertTrue(y.op.name.startswith('pad_to_bounding_box'))
|
||||||
|
|
||||||
|
|
||||||
class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase):
|
class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
@ -2224,6 +2247,13 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
|
||||||
self._assertShapeInference([59, 60, None], [55, 66], [55, 66, None])
|
self._assertShapeInference([59, 60, None], [55, 66], [55, 66, None])
|
||||||
self._assertShapeInference([None, None, None], [55, 66], [55, 66, None])
|
self._assertShapeInference([None, None, None], [55, 66], [55, 66, None])
|
||||||
|
|
||||||
|
def testNameScope(self):
|
||||||
|
img_shape = [1, 3, 2, 1]
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
single_image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3])
|
||||||
|
y = image_ops.resize_images(single_image, [55, 66])
|
||||||
|
self.assertTrue(y.op.name.startswith('resize_images'))
|
||||||
|
|
||||||
|
|
||||||
class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
|
class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
@ -2499,6 +2529,11 @@ class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
|
||||||
self._assertRaises(x, x_shape, target_height, target_width,
|
self._assertRaises(x, x_shape, target_height, target_width,
|
||||||
"target_width must be > 0")
|
"target_width must be > 0")
|
||||||
|
|
||||||
|
def testNameScope(self):
|
||||||
|
image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3])
|
||||||
|
y = image_ops.resize_image_with_crop_or_pad(image, 55, 66)
|
||||||
|
self.assertTrue(y.op.name.startswith('resize_image_with_crop_or_pad'))
|
||||||
|
|
||||||
|
|
||||||
def _SimpleColorRamp():
|
def _SimpleColorRamp():
|
||||||
"""Build a simple color ramp RGB image."""
|
"""Build a simple color ramp RGB image."""
|
||||||
|
|
|
||||||
|
|
@ -171,6 +171,10 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
|
x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
|
||||||
_, numerical_grad = gradient_checker.compute_gradient(
|
_, numerical_grad = gradient_checker.compute_gradient(
|
||||||
x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)
|
x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)
|
||||||
|
|
||||||
|
# If grad is empty, no error.
|
||||||
|
if theoretical_grad.size == 0 and numerical_grad.size == 0:
|
||||||
|
return 0
|
||||||
return np.fabs(theoretical_grad - numerical_grad).max()
|
return np.fabs(theoretical_grad - numerical_grad).max()
|
||||||
|
|
||||||
def _test_gradient(self,
|
def _test_gradient(self,
|
||||||
|
|
@ -371,6 +375,17 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
self._test_inference(
|
self._test_inference(
|
||||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||||
|
|
||||||
|
def testInferenceShape5(self):
|
||||||
|
x_shape = [0, 131, 127, 6]
|
||||||
|
for dtype in [np.float16, np.float32]:
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
self._test_inference(
|
||||||
|
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||||
|
self._test_inference(
|
||||||
|
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||||
|
self._test_inference(
|
||||||
|
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||||
|
|
||||||
def testTrainingShape1(self):
|
def testTrainingShape1(self):
|
||||||
x_shape = [1, 1, 6, 1]
|
x_shape = [1, 1, 6, 1]
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
|
|
@ -409,6 +424,17 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
self._test_training(
|
self._test_training(
|
||||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||||
|
|
||||||
|
def testTrainingShape5(self):
|
||||||
|
x_shape = [0, 131, 127, 6]
|
||||||
|
for dtype in [np.float16, np.float32]:
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
self._test_training(
|
||||||
|
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||||
|
self._test_training(
|
||||||
|
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||||
|
self._test_training(
|
||||||
|
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||||
|
|
||||||
def testBatchNormGradShape1(self):
|
def testBatchNormGradShape1(self):
|
||||||
for is_training in [True, False]:
|
for is_training in [True, False]:
|
||||||
x_shape = [1, 1, 6, 1]
|
x_shape = [1, 1, 6, 1]
|
||||||
|
|
@ -496,6 +522,33 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
data_format='NHWC',
|
data_format='NHWC',
|
||||||
is_training=is_training)
|
is_training=is_training)
|
||||||
|
|
||||||
|
def testBatchNormGradShape5(self):
|
||||||
|
for is_training in [True, False]:
|
||||||
|
x_shape = [0, 7, 11, 4]
|
||||||
|
for dtype in [np.float16, np.float32]:
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
self._test_gradient(
|
||||||
|
x_shape,
|
||||||
|
dtype, [7],
|
||||||
|
np.float32,
|
||||||
|
use_gpu=True,
|
||||||
|
data_format='NCHW',
|
||||||
|
is_training=is_training)
|
||||||
|
self._test_gradient(
|
||||||
|
x_shape,
|
||||||
|
dtype, [4],
|
||||||
|
np.float32,
|
||||||
|
use_gpu=True,
|
||||||
|
data_format='NHWC',
|
||||||
|
is_training=is_training)
|
||||||
|
self._test_gradient(
|
||||||
|
x_shape,
|
||||||
|
dtype, [4],
|
||||||
|
np.float32,
|
||||||
|
use_gpu=False,
|
||||||
|
data_format='NHWC',
|
||||||
|
is_training=is_training)
|
||||||
|
|
||||||
def _testBatchNormGradGrad(self, config):
|
def _testBatchNormGradGrad(self, config):
|
||||||
shape = config['shape']
|
shape = config['shape']
|
||||||
err_tolerance = config['err_tolerance']
|
err_tolerance = config['err_tolerance']
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,7 @@ def piecewise_constant(x, boundaries, values, name=None):
|
||||||
`float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
|
`float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
|
||||||
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
|
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
|
||||||
increasing entries, and with all elements having the same type as `x`.
|
increasing entries, and with all elements having the same type as `x`.
|
||||||
values: A list of `Tensor`s or float`s or `int`s that specifies the values
|
values: A list of `Tensor`s or `float`s or `int`s that specifies the values
|
||||||
for the intervals defined by `boundaries`. It should have one more element
|
for the intervals defined by `boundaries`. It should have one more element
|
||||||
than `boundaries`, and all elements should have the same type.
|
than `boundaries`, and all elements should have the same type.
|
||||||
name: A string. Optional name of the operation. Defaults to
|
name: A string. Optional name of the operation. Defaults to
|
||||||
|
|
@ -424,11 +424,12 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate,
|
||||||
return math_ops.div(learning_rate, denom, name=name)
|
return math_ops.div(learning_rate, denom, name=name)
|
||||||
|
|
||||||
|
|
||||||
def cosine_decay(learning_rate, global_step, decay_steps, name=None):
|
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
|
||||||
|
name=None):
|
||||||
"""Applies cosine decay to the learning rate.
|
"""Applies cosine decay to the learning rate.
|
||||||
|
|
||||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||||
with Warm Restarts.
|
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
When training a model, it is often recommended to lower the learning rate as
|
||||||
the training progresses. This function applies a cosine decay function
|
the training progresses. This function applies a cosine decay function
|
||||||
|
|
@ -439,7 +440,8 @@ def cosine_decay(learning_rate, global_step, decay_steps, name=None):
|
||||||
The function returns the decayed learning rate. It is computed as:
|
The function returns the decayed learning rate. It is computed as:
|
||||||
```python
|
```python
|
||||||
global_step = min(global_step, decay_steps)
|
global_step = min(global_step, decay_steps)
|
||||||
decayed = 0.5 * (1 + cos(pi * global_step / decay_steps))
|
cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
|
||||||
|
decayed = (1 - alpha) * cosine_decay + alpha
|
||||||
decayed_learning_rate = learning_rate * decayed
|
decayed_learning_rate = learning_rate * decayed
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -456,6 +458,8 @@ def cosine_decay(learning_rate, global_step, decay_steps, name=None):
|
||||||
Global step to use for the decay computation.
|
Global step to use for the decay computation.
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||||
Number of steps to decay over.
|
Number of steps to decay over.
|
||||||
|
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||||
|
Minimum learning rate value as a fraction of learning_rate.
|
||||||
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
|
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
|
||||||
Returns:
|
Returns:
|
||||||
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||||
|
|
@ -476,7 +480,96 @@ def cosine_decay(learning_rate, global_step, decay_steps, name=None):
|
||||||
cosine_decayed = 0.5 * (
|
cosine_decayed = 0.5 * (
|
||||||
1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
|
1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
|
||||||
|
|
||||||
return math_ops.multiply(learning_rate, cosine_decayed)
|
decayed = (1 - alpha) * cosine_decayed + alpha
|
||||||
|
return math_ops.multiply(learning_rate, decayed)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_decay_restarts(learning_rate, global_step, first_decay_steps,
|
||||||
|
t_mul=2.0, m_mul=1.0, alpha=0.0, name=None):
|
||||||
|
"""Applies cosine decay with restarts to the learning rate.
|
||||||
|
|
||||||
|
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
||||||
|
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||||
|
|
||||||
|
When training a model, it is often recommended to lower the learning rate as
|
||||||
|
the training progresses. This function applies a cosine decay function with
|
||||||
|
restarts to a provided initial learning rate. It requires a `global_step`
|
||||||
|
value to compute the decayed learning rate. You can just pass a TensorFlow
|
||||||
|
variable that you increment at each training step.
|
||||||
|
|
||||||
|
The function returns the decayed learning rate while taking into account
|
||||||
|
possible warm restarts. The learning rate multiplier first decays
|
||||||
|
from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
|
||||||
|
restart is performed. Each new warm restart runs for `t_mul` times more steps
|
||||||
|
and with `m_mul` times smaller initial learning rate.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
first_decay_steps = 1000
|
||||||
|
lr_decayed = cosine_decay_restarts(learning_rate, global_step,
|
||||||
|
first_decay_steps)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
||||||
|
The initial learning rate.
|
||||||
|
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||||
|
Global step to use for the decay computation.
|
||||||
|
first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||||
|
Number of steps to decay over.
|
||||||
|
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||||
|
Used to derive the number of iterations in the i-th period
|
||||||
|
m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||||
|
Used to derive the initial learning rate of the i-th period:
|
||||||
|
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||||
|
Minimum learning rate value as a fraction of the learning_rate.
|
||||||
|
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
|
||||||
|
Returns:
|
||||||
|
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||||
|
learning rate.
|
||||||
|
Raises:
|
||||||
|
ValueError: if `global_step` is not supplied.
|
||||||
|
"""
|
||||||
|
if global_step is None:
|
||||||
|
raise ValueError("cosine decay restarts requires global_step")
|
||||||
|
with ops.name_scope(name, "SGDRDecay",
|
||||||
|
[learning_rate, global_step]) as name:
|
||||||
|
learning_rate = ops.convert_to_tensor(learning_rate,
|
||||||
|
name="initial_learning_rate")
|
||||||
|
dtype = learning_rate.dtype
|
||||||
|
global_step = math_ops.cast(global_step, dtype)
|
||||||
|
first_decay_steps = math_ops.cast(first_decay_steps, dtype)
|
||||||
|
alpha = math_ops.cast(alpha, dtype)
|
||||||
|
t_mul = math_ops.cast(t_mul, dtype)
|
||||||
|
m_mul = math_ops.cast(m_mul, dtype)
|
||||||
|
|
||||||
|
completed_fraction = global_step / first_decay_steps
|
||||||
|
|
||||||
|
def compute_step(completed_fraction, geometric=False):
|
||||||
|
if geometric:
|
||||||
|
i_restart = math_ops.floor(math_ops.log(1.0 - completed_fraction * (
|
||||||
|
1.0 - t_mul)) / math_ops.log(t_mul))
|
||||||
|
|
||||||
|
sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
|
||||||
|
completed_fraction = (completed_fraction - sum_r) / t_mul ** i_restart
|
||||||
|
|
||||||
|
else:
|
||||||
|
i_restart = math_ops.floor(completed_fraction)
|
||||||
|
completed_fraction = completed_fraction - i_restart
|
||||||
|
|
||||||
|
return i_restart, completed_fraction
|
||||||
|
|
||||||
|
i_restart, completed_fraction = control_flow_ops.cond(
|
||||||
|
math_ops.equal(t_mul, 1.0),
|
||||||
|
lambda: compute_step(completed_fraction, geometric=False),
|
||||||
|
lambda: compute_step(completed_fraction, geometric=True))
|
||||||
|
|
||||||
|
m_fac = m_mul ** i_restart
|
||||||
|
cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
|
||||||
|
constant_op.constant(math.pi) * completed_fraction))
|
||||||
|
decayed = (1 - alpha) * cosine_decayed + alpha
|
||||||
|
|
||||||
|
return math_ops.multiply(learning_rate, decayed, name=name)
|
||||||
|
|
||||||
|
|
||||||
def linear_cosine_decay(learning_rate, global_step, decay_steps,
|
def linear_cosine_decay(learning_rate, global_step, decay_steps,
|
||||||
|
|
@ -487,6 +580,10 @@ def linear_cosine_decay(learning_rate, global_step, decay_steps,
|
||||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||||
https://arxiv.org/abs/1709.07417
|
https://arxiv.org/abs/1709.07417
|
||||||
|
|
||||||
|
For the idea of warm starts here controlled by `num_periods`,
|
||||||
|
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||||
|
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||||
|
|
||||||
Note that linear cosine decay is more aggressive than cosine decay and
|
Note that linear cosine decay is more aggressive than cosine decay and
|
||||||
larger initial learning rates can typically be used.
|
larger initial learning rates can typically be used.
|
||||||
|
|
||||||
|
|
@ -563,6 +660,10 @@ def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps,
|
||||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
||||||
https://arxiv.org/abs/1709.07417
|
https://arxiv.org/abs/1709.07417
|
||||||
|
|
||||||
|
For the idea of warm starts here controlled by `num_periods`,
|
||||||
|
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
||||||
|
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
||||||
|
|
||||||
Note that linear cosine decay is more aggressive than cosine decay and
|
Note that linear cosine decay is more aggressive than cosine decay and
|
||||||
larger initial learning rates can typically be used.
|
larger initial learning rates can typically be used.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -342,10 +342,11 @@ class InverseDecayTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
class CosineDecayTest(test_util.TensorFlowTestCase):
|
class CosineDecayTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def np_cosine_decay(self, step, decay_steps):
|
def np_cosine_decay(self, step, decay_steps, alpha=0.0):
|
||||||
step = min(step, decay_steps)
|
step = min(step, decay_steps)
|
||||||
completed_fraction = step / decay_steps
|
completed_fraction = step / decay_steps
|
||||||
return 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
|
decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
|
||||||
|
return (1.0 - alpha) * decay + alpha
|
||||||
|
|
||||||
def testDecay(self):
|
def testDecay(self):
|
||||||
num_training_steps = 1000
|
num_training_steps = 1000
|
||||||
|
|
@ -357,6 +358,77 @@ class CosineDecayTest(test_util.TensorFlowTestCase):
|
||||||
expected = self.np_cosine_decay(step, num_training_steps)
|
expected = self.np_cosine_decay(step, num_training_steps)
|
||||||
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
def testAlpha(self):
|
||||||
|
num_training_steps = 1000
|
||||||
|
initial_lr = 1.0
|
||||||
|
alpha = 0.1
|
||||||
|
for step in range(0, 1500, 250):
|
||||||
|
with self.test_session():
|
||||||
|
decayed_lr = learning_rate_decay.cosine_decay(
|
||||||
|
initial_lr, step, num_training_steps, alpha)
|
||||||
|
expected = self.np_cosine_decay(step, num_training_steps, alpha)
|
||||||
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
class CosineDecayRestartsTest(test_util.TensorFlowTestCase):
|
||||||
|
def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
|
||||||
|
alpha=0.0):
|
||||||
|
fac = 1.0
|
||||||
|
while step >= decay_steps:
|
||||||
|
step = step - decay_steps
|
||||||
|
decay_steps *= t_mul
|
||||||
|
fac *= m_mul
|
||||||
|
|
||||||
|
completed_fraction = step / decay_steps
|
||||||
|
decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
|
||||||
|
return (1.0 - alpha) * decay + alpha
|
||||||
|
|
||||||
|
def testDecay(self):
|
||||||
|
num_training_steps = 1000
|
||||||
|
initial_lr = 1.0
|
||||||
|
for step in range(0, 1500, 250):
|
||||||
|
with self.test_session():
|
||||||
|
decayed_lr = learning_rate_decay.cosine_decay_restarts(
|
||||||
|
initial_lr, step, num_training_steps)
|
||||||
|
expected = self.np_cosine_decay_restarts(step, num_training_steps)
|
||||||
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
def testAlpha(self):
|
||||||
|
num_training_steps = 1000
|
||||||
|
initial_lr = 1.0
|
||||||
|
alpha = 0.1
|
||||||
|
for step in range(0, 1500, 250):
|
||||||
|
with self.test_session():
|
||||||
|
decayed_lr = learning_rate_decay.cosine_decay_restarts(
|
||||||
|
initial_lr, step, num_training_steps, alpha=alpha)
|
||||||
|
expected = self.np_cosine_decay_restarts(step, num_training_steps,
|
||||||
|
alpha=alpha)
|
||||||
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
def testMMul(self):
|
||||||
|
num_training_steps = 1000
|
||||||
|
initial_lr = 1.0
|
||||||
|
m_mul = 0.9
|
||||||
|
for step in range(0, 1500, 250):
|
||||||
|
with self.test_session():
|
||||||
|
decayed_lr = learning_rate_decay.cosine_decay_restarts(
|
||||||
|
initial_lr, step, num_training_steps, m_mul=m_mul)
|
||||||
|
expected = self.np_cosine_decay_restarts(step, num_training_steps,
|
||||||
|
m_mul=m_mul)
|
||||||
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
def testTMul(self):
|
||||||
|
num_training_steps = 1000
|
||||||
|
initial_lr = 1.0
|
||||||
|
t_mul = 1.0
|
||||||
|
for step in range(0, 1500, 250):
|
||||||
|
with self.test_session():
|
||||||
|
decayed_lr = learning_rate_decay.cosine_decay_restarts(
|
||||||
|
initial_lr, step, num_training_steps, t_mul=t_mul)
|
||||||
|
expected = self.np_cosine_decay_restarts(step, num_training_steps,
|
||||||
|
t_mul=t_mul)
|
||||||
|
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
|
||||||
|
|
||||||
|
|
||||||
class LinearCosineDecayTest(test_util.TensorFlowTestCase):
|
class LinearCosineDecayTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ See the @{$python/train} guide.
|
||||||
@@clip_by_global_norm
|
@@clip_by_global_norm
|
||||||
@@global_norm
|
@@global_norm
|
||||||
@@cosine_decay
|
@@cosine_decay
|
||||||
|
@@cosine_decay_restarts
|
||||||
@@linear_cosine_decay
|
@@linear_cosine_decay
|
||||||
@@noisy_linear_cosine_decay
|
@@noisy_linear_cosine_decay
|
||||||
@@exponential_decay
|
@@exponential_decay
|
||||||
|
|
|
||||||
|
|
@ -266,7 +266,11 @@ tf_module {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "cosine_decay"
|
name: "cosine_decay"
|
||||||
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "cosine_decay_restarts"
|
||||||
|
argspec: "args=[\'learning_rate\', \'global_step\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "create_global_step"
|
name: "create_global_step"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL maintainer="Jan Prach <jendap@google.com>"
|
LABEL maintainer="Jan Prach <jendap@google.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL maintainer="Jan Prach <jendap@google.com>"
|
LABEL maintainer="Jan Prach <jendap@google.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL authors="Andrew Gibiansky <andrew.gibiansky@gmail.com>, Joel Hestness <jthestness@gmail.com>"
|
LABEL authors="Andrew Gibiansky <andrew.gibiansky@gmail.com>, Joel Hestness <jthestness@gmail.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL maintainer="Jonathan Hseu <jhseu@google.com>"
|
LABEL maintainer="Jonathan Hseu <jhseu@google.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL maintainer="Jan Prach <jendap@google.com>"
|
LABEL maintainer="Jan Prach <jendap@google.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:14.04
|
FROM ubuntu:16.04
|
||||||
|
|
||||||
LABEL maintainer="Jan Prach <jendap@google.com>"
|
LABEL maintainer="Jan Prach <jendap@google.com>"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,9 +111,9 @@ do_pylint() {
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ $1 == "PYTHON2" ]]; then
|
if [[ $1 == "PYTHON2" ]]; then
|
||||||
PYLINT_BIN="python /usr/local/lib/python2.7/dist-packages/pylint/lint.py"
|
PYLINT_BIN="python -m pylint"
|
||||||
elif [[ $1 == "PYTHON3" ]]; then
|
elif [[ $1 == "PYTHON3" ]]; then
|
||||||
PYLINT_BIN="python3 /usr/local/lib/python3.4/dist-packages/pylint/lint.py"
|
PYLINT_BIN="python3 -m pylint"
|
||||||
else
|
else
|
||||||
echo "Unrecognized python version (PYTHON2 | PYTHON3): $1"
|
echo "Unrecognized python version (PYTHON2 | PYTHON3): $1"
|
||||||
return 1
|
return 1
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
# Select bazel version.
|
# Select bazel version.
|
||||||
BAZEL_VERSION="0.5.4"
|
BAZEL_VERSION="0.8.0"
|
||||||
|
|
||||||
set +e
|
set +e
|
||||||
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
|
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,24 @@
|
||||||
|
|
||||||
# fkrull/deadsnakes is for Python3.6
|
# fkrull/deadsnakes is for Python3.6
|
||||||
add-apt-repository -y ppa:fkrull/deadsnakes
|
add-apt-repository -y ppa:fkrull/deadsnakes
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
|
apt-get upgrade
|
||||||
|
|
||||||
|
# Install python dep
|
||||||
|
apt-get install python-dev
|
||||||
|
# Install bz2 dep
|
||||||
|
apt-get install libbz2-dev
|
||||||
|
# Install curses dep
|
||||||
|
apt-get install libncurses5 libncurses5-dev
|
||||||
|
apt-get install libncursesw5 libncursesw5-dev
|
||||||
|
# Install readline dep
|
||||||
|
apt-get install libreadline6 libreadline6-dev
|
||||||
|
# Install sqlite3 dependencies
|
||||||
|
apt-get install libsqlite3-dev
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Install Python 3.6 and dev library
|
# Install Python 3.6 and dev library
|
||||||
wget https://www.python.org/ftp/python/3.6.1/Python-3.6.1.tar.xz
|
wget https://www.python.org/ftp/python/3.6.1/Python-3.6.1.tar.xz
|
||||||
tar xvf Python-3.6.1.tar.xz
|
tar xvf Python-3.6.1.tar.xz
|
||||||
|
|
@ -63,6 +78,10 @@ pip3 install scikit-learn==0.18.1
|
||||||
# pandas required by `inflow`
|
# pandas required by `inflow`
|
||||||
pip3 install pandas==0.19.2
|
pip3 install pandas==0.19.2
|
||||||
|
|
||||||
|
pip3 install gnureadline
|
||||||
|
|
||||||
|
pip3 install bz2file
|
||||||
|
|
||||||
# Install recent-enough version of wheel for Python 3.6 wheel builds
|
# Install recent-enough version of wheel for Python 3.6 wheel builds
|
||||||
pip3 install wheel==0.29.0
|
pip3 install wheel==0.29.0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,7 @@ load(
|
||||||
|
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "tf_upgrade",
|
name = "tf_upgrade",
|
||||||
srcs = [
|
srcs = ["tf_upgrade.py"],
|
||||||
"ast_edits.py",
|
|
||||||
"tf_upgrade.py",
|
|
||||||
],
|
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -22,7 +19,7 @@ py_test(
|
||||||
srcs = ["tf_upgrade_test.py"],
|
srcs = ["tf_upgrade_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"tf_upgrade",
|
":tf_upgrade",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
|
|
@ -48,11 +45,11 @@ genrule(
|
||||||
"test_file_v1_0.py",
|
"test_file_v1_0.py",
|
||||||
"report.txt",
|
"report.txt",
|
||||||
],
|
],
|
||||||
cmd = ("$(location tf_upgrade)" +
|
cmd = ("$(location :tf_upgrade)" +
|
||||||
" --infile $(location testdata/test_file_v0_11.py)" +
|
" --infile $(location testdata/test_file_v0_11.py)" +
|
||||||
" --outfile $(location test_file_v1_0.py)" +
|
" --outfile $(location test_file_v1_0.py)" +
|
||||||
" --reportfile $(location report.txt)"),
|
" --reportfile $(location report.txt)"),
|
||||||
tools = ["tf_upgrade"],
|
tools = [":tf_upgrade"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,486 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import ast
|
||||||
from tensorflow.tools.compatibility import ast_edits
|
import collections
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
class APIChangeSpec(object):
|
||||||
|
"""This class defines the transformations that need to happen.
|
||||||
|
|
||||||
|
This class must provide the following fields:
|
||||||
|
|
||||||
|
* `function_keyword_renames`: maps function names to a map of old -> new
|
||||||
|
argument names
|
||||||
|
* `function_renames`: maps function names to new function names
|
||||||
|
* `change_to_function`: a set of function names that have changed (for
|
||||||
|
notifications)
|
||||||
|
* `function_reorders`: maps functions whose argument order has changed to the
|
||||||
|
list of arguments in the new order
|
||||||
|
* `function_handle`: maps function names to custom handlers for the function
|
||||||
|
|
||||||
|
For an example, see `TFAPIChangeSpec`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _FileEditTuple(collections.namedtuple(
|
||||||
|
"_FileEditTuple", ["comment", "line", "start", "old", "new"])):
|
||||||
|
"""Each edit that is recorded by a _FileEditRecorder.
|
||||||
|
|
||||||
|
Fields:
|
||||||
|
comment: A description of the edit and why it was made.
|
||||||
|
line: The line number in the file where the edit occurs (1-indexed).
|
||||||
|
start: The line number in the file where the edit occurs (0-indexed).
|
||||||
|
old: text string to remove (this must match what was in file).
|
||||||
|
new: text string to add in place of `old`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
class _FileEditRecorder(object):
|
||||||
|
"""Record changes that need to be done to the file."""
|
||||||
|
|
||||||
|
def __init__(self, filename):
|
||||||
|
# all edits are lists of chars
|
||||||
|
self._filename = filename
|
||||||
|
|
||||||
|
self._line_to_edit = collections.defaultdict(list)
|
||||||
|
self._errors = []
|
||||||
|
|
||||||
|
def process(self, text):
|
||||||
|
"""Process a list of strings, each corresponding to the recorded changes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: A list of lines of text (assumed to contain newlines)
|
||||||
|
Returns:
|
||||||
|
A tuple of the modified text and a textual description of what is done.
|
||||||
|
Raises:
|
||||||
|
ValueError: if substitution source location does not have expected text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
change_report = ""
|
||||||
|
|
||||||
|
# Iterate of each line
|
||||||
|
for line, edits in self._line_to_edit.items():
|
||||||
|
offset = 0
|
||||||
|
# sort by column so that edits are processed in order in order to make
|
||||||
|
# indexing adjustments cumulative for changes that change the string
|
||||||
|
# length
|
||||||
|
edits.sort(key=lambda x: x.start)
|
||||||
|
|
||||||
|
# Extract each line to a list of characters, because mutable lists
|
||||||
|
# are editable, unlike immutable strings.
|
||||||
|
char_array = list(text[line - 1])
|
||||||
|
|
||||||
|
# Record a description of the change
|
||||||
|
change_report += "%r Line %d\n" % (self._filename, line)
|
||||||
|
change_report += "-" * 80 + "\n\n"
|
||||||
|
for e in edits:
|
||||||
|
change_report += "%s\n" % e.comment
|
||||||
|
change_report += "\n Old: %s" % (text[line - 1])
|
||||||
|
|
||||||
|
# Make underscore buffers for underlining where in the line the edit was
|
||||||
|
change_list = [" "] * len(text[line - 1])
|
||||||
|
change_list_new = [" "] * len(text[line - 1])
|
||||||
|
|
||||||
|
# Iterate for each edit
|
||||||
|
for e in edits:
|
||||||
|
# Create effective start, end by accounting for change in length due
|
||||||
|
# to previous edits
|
||||||
|
start_eff = e.start + offset
|
||||||
|
end_eff = start_eff + len(e.old)
|
||||||
|
|
||||||
|
# Make sure the edit is changing what it should be changing
|
||||||
|
old_actual = "".join(char_array[start_eff:end_eff])
|
||||||
|
if old_actual != e.old:
|
||||||
|
raise ValueError("Expected text %r but got %r" %
|
||||||
|
("".join(e.old), "".join(old_actual)))
|
||||||
|
# Make the edit
|
||||||
|
char_array[start_eff:end_eff] = list(e.new)
|
||||||
|
|
||||||
|
# Create the underline highlighting of the before and after
|
||||||
|
change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
|
||||||
|
change_list_new[start_eff:end_eff] = "~" * len(e.new)
|
||||||
|
|
||||||
|
# Keep track of how to generate effective ranges
|
||||||
|
offset += len(e.new) - len(e.old)
|
||||||
|
|
||||||
|
# Finish the report comment
|
||||||
|
change_report += " %s\n" % "".join(change_list)
|
||||||
|
text[line - 1] = "".join(char_array)
|
||||||
|
change_report += " New: %s" % (text[line - 1])
|
||||||
|
change_report += " %s\n\n" % "".join(change_list_new)
|
||||||
|
return "".join(text), change_report, self._errors
|
||||||
|
|
||||||
|
def add(self, comment, line, start, old, new, error=None):
|
||||||
|
"""Add a new change that is needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comment: A description of what was changed
|
||||||
|
line: Line number (1 indexed)
|
||||||
|
start: Column offset (0 indexed)
|
||||||
|
old: old text
|
||||||
|
new: new text
|
||||||
|
error: this "edit" is something that cannot be fixed automatically
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._line_to_edit[line].append(
|
||||||
|
_FileEditTuple(comment, line, start, old, new))
|
||||||
|
if error:
|
||||||
|
self._errors.append("%s:%d: %s" % (self._filename, line, error))
|
||||||
|
|
||||||
|
|
||||||
|
class _ASTCallVisitor(ast.NodeVisitor):
|
||||||
|
"""AST Visitor that processes function calls.
|
||||||
|
|
||||||
|
Updates function calls from old API version to new API version using a given
|
||||||
|
change spec.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filename, lines, api_change_spec):
|
||||||
|
self._filename = filename
|
||||||
|
self._file_edit = _FileEditRecorder(filename)
|
||||||
|
self._lines = lines
|
||||||
|
self._api_change_spec = api_change_spec
|
||||||
|
|
||||||
|
def process(self, lines):
|
||||||
|
return self._file_edit.process(lines)
|
||||||
|
|
||||||
|
def generic_visit(self, node):
|
||||||
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
|
|
||||||
|
def _rename_functions(self, node, full_name):
|
||||||
|
function_renames = self._api_change_spec.function_renames
|
||||||
|
try:
|
||||||
|
new_name = function_renames[full_name]
|
||||||
|
self._file_edit.add("Renamed function %r to %r" % (full_name,
|
||||||
|
new_name),
|
||||||
|
node.lineno, node.col_offset, full_name, new_name)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_attribute_full_path(self, node):
|
||||||
|
"""Traverse an attribute to generate a full name e.g. tf.foo.bar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: A Node of type Attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a '.'-delimited full-name or None if the tree was not a simple form.
|
||||||
|
i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
|
||||||
|
"""
|
||||||
|
curr = node
|
||||||
|
items = []
|
||||||
|
while not isinstance(curr, ast.Name):
|
||||||
|
if not isinstance(curr, ast.Attribute):
|
||||||
|
return None
|
||||||
|
items.append(curr.attr)
|
||||||
|
curr = curr.value
|
||||||
|
items.append(curr.id)
|
||||||
|
return ".".join(reversed(items))
|
||||||
|
|
||||||
|
def _find_true_position(self, node):
|
||||||
|
"""Return correct line number and column offset for a given node.
|
||||||
|
|
||||||
|
This is necessary mainly because ListComp's location reporting reports
|
||||||
|
the next token after the list comprehension list opening.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Node for which we wish to know the lineno and col_offset
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
find_open = re.compile("^\s*(\\[).*$")
|
||||||
|
find_string_chars = re.compile("['\"]")
|
||||||
|
|
||||||
|
if isinstance(node, ast.ListComp):
|
||||||
|
# Strangely, ast.ListComp returns the col_offset of the first token
|
||||||
|
# after the '[' token which appears to be a bug. Workaround by
|
||||||
|
# explicitly finding the real start of the list comprehension.
|
||||||
|
line = node.lineno
|
||||||
|
col = node.col_offset
|
||||||
|
# loop over lines
|
||||||
|
while 1:
|
||||||
|
# Reverse the text to and regular expression search for whitespace
|
||||||
|
text = self._lines[line-1]
|
||||||
|
reversed_preceding_text = text[:col][::-1]
|
||||||
|
# First find if a [ can be found with only whitespace between it and
|
||||||
|
# col.
|
||||||
|
m = find_open.match(reversed_preceding_text)
|
||||||
|
if m:
|
||||||
|
new_col_offset = col - m.start(1) - 1
|
||||||
|
return line, new_col_offset
|
||||||
|
else:
|
||||||
|
if (reversed_preceding_text=="" or
|
||||||
|
reversed_preceding_text.isspace()):
|
||||||
|
line = line - 1
|
||||||
|
prev_line = self._lines[line - 1]
|
||||||
|
# TODO(aselle):
|
||||||
|
# this is poor comment detection, but it is good enough for
|
||||||
|
# cases where the comment does not contain string literal starting/
|
||||||
|
# ending characters. If ast gave us start and end locations of the
|
||||||
|
# ast nodes rather than just start, we could use string literal
|
||||||
|
# node ranges to filter out spurious #'s that appear in string
|
||||||
|
# literals.
|
||||||
|
comment_start = prev_line.find("#")
|
||||||
|
if comment_start == -1:
|
||||||
|
col = len(prev_line) -1
|
||||||
|
elif find_string_chars.search(prev_line[comment_start:]) is None:
|
||||||
|
col = comment_start
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
# Most other nodes return proper locations (with notably does not), but
|
||||||
|
# it is not possible to use that in an argument.
|
||||||
|
return node.lineno, node.col_offset
|
||||||
|
|
||||||
|
|
||||||
|
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||||
|
"""Handle visiting a call node in the AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current Node
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Find a simple attribute name path e.g. "tf.foo.bar"
|
||||||
|
full_name = self._get_attribute_full_path(node.func)
|
||||||
|
|
||||||
|
# Make sure the func is marked as being part of a call
|
||||||
|
node.func.is_function_for_call = True
|
||||||
|
|
||||||
|
if full_name:
|
||||||
|
# Call special handlers
|
||||||
|
function_handles = self._api_change_spec.function_handle
|
||||||
|
if full_name in function_handles:
|
||||||
|
function_handles[full_name](self._file_edit, node)
|
||||||
|
|
||||||
|
# Examine any non-keyword argument and make it into a keyword argument
|
||||||
|
# if reordering required.
|
||||||
|
function_reorders = self._api_change_spec.function_reorders
|
||||||
|
function_keyword_renames = (
|
||||||
|
self._api_change_spec.function_keyword_renames)
|
||||||
|
|
||||||
|
if full_name in function_reorders:
|
||||||
|
reordered = function_reorders[full_name]
|
||||||
|
for idx, arg in enumerate(node.args):
|
||||||
|
lineno, col_offset = self._find_true_position(arg)
|
||||||
|
if lineno is None or col_offset is None:
|
||||||
|
self._file_edit.add(
|
||||||
|
"Failed to add keyword %r to reordered function %r"
|
||||||
|
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
|
||||||
|
"", "",
|
||||||
|
error="A necessary keyword argument failed to be inserted.")
|
||||||
|
else:
|
||||||
|
keyword_arg = reordered[idx]
|
||||||
|
if (full_name in function_keyword_renames and
|
||||||
|
keyword_arg in function_keyword_renames[full_name]):
|
||||||
|
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||||
|
self._file_edit.add("Added keyword %r to reordered function %r"
|
||||||
|
% (reordered[idx], full_name), lineno,
|
||||||
|
col_offset, "", keyword_arg + "=")
|
||||||
|
|
||||||
|
# Examine each keyword argument and convert it to the final renamed form
|
||||||
|
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||||
|
function_keyword_renames[full_name])
|
||||||
|
for keyword in node.keywords:
|
||||||
|
argkey = keyword.arg
|
||||||
|
argval = keyword.value
|
||||||
|
|
||||||
|
if argkey in renamed_keywords:
|
||||||
|
argval_lineno, argval_col_offset = self._find_true_position(argval)
|
||||||
|
if argval_lineno is not None and argval_col_offset is not None:
|
||||||
|
# TODO(aselle): We should scan backward to find the start of the
|
||||||
|
# keyword key. Unfortunately ast does not give you the location of
|
||||||
|
# keyword keys, so we are forced to infer it from the keyword arg
|
||||||
|
# value.
|
||||||
|
key_start = argval_col_offset - len(argkey) - 1
|
||||||
|
key_end = key_start + len(argkey) + 1
|
||||||
|
if (self._lines[argval_lineno - 1][key_start:key_end] ==
|
||||||
|
argkey + "="):
|
||||||
|
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||||
|
(argkey, renamed_keywords[argkey]),
|
||||||
|
argval_lineno,
|
||||||
|
argval_col_offset - len(argkey) - 1,
|
||||||
|
argkey + "=", renamed_keywords[argkey] + "=")
|
||||||
|
continue
|
||||||
|
self._file_edit.add(
|
||||||
|
"Failed to rename keyword argument from %r to %r" %
|
||||||
|
(argkey, renamed_keywords[argkey]),
|
||||||
|
argval.lineno,
|
||||||
|
argval.col_offset - len(argkey) - 1,
|
||||||
|
"", "",
|
||||||
|
error="Failed to find keyword lexographically. Fix manually.")
|
||||||
|
|
||||||
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
|
|
||||||
|
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||||
|
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Node that is of type ast.Attribute
|
||||||
|
"""
|
||||||
|
full_name = self._get_attribute_full_path(node)
|
||||||
|
if full_name:
|
||||||
|
self._rename_functions(node, full_name)
|
||||||
|
if full_name in self._api_change_spec.change_to_function:
|
||||||
|
if not hasattr(node, "is_function_for_call"):
|
||||||
|
new_text = full_name + "()"
|
||||||
|
self._file_edit.add("Changed %r to %r"%(full_name, new_text),
|
||||||
|
node.lineno, node.col_offset, full_name, new_text)
|
||||||
|
|
||||||
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
|
|
||||||
|
|
||||||
|
class ASTCodeUpgrader(object):
|
||||||
|
"""Handles upgrading a set of Python files using a given API change spec."""
|
||||||
|
|
||||||
|
def __init__(self, api_change_spec):
|
||||||
|
if not isinstance(api_change_spec, APIChangeSpec):
|
||||||
|
raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
|
||||||
|
type(api_change_spec))
|
||||||
|
self._api_change_spec = api_change_spec
|
||||||
|
|
||||||
|
def process_file(self, in_filename, out_filename):
|
||||||
|
"""Process the given python file for incompatible changes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_filename: filename to parse
|
||||||
|
out_filename: output file to write to
|
||||||
|
Returns:
|
||||||
|
A tuple representing number of files processed, log of actions, errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Write to a temporary file, just in case we are doing an implace modify.
|
||||||
|
with open(in_filename, "r") as in_file, \
|
||||||
|
tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
||||||
|
ret = self.process_opened_file(
|
||||||
|
in_filename, in_file, out_filename, temp_file)
|
||||||
|
|
||||||
|
shutil.move(temp_file.name, out_filename)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# Broad exceptions are required here because ast throws whatever it wants.
|
||||||
|
# pylint: disable=broad-except
|
||||||
|
def process_opened_file(self, in_filename, in_file, out_filename, out_file):
|
||||||
|
"""Process the given python file for incompatible changes.
|
||||||
|
|
||||||
|
This function is split out to facilitate StringIO testing from
|
||||||
|
tf_upgrade_test.py.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_filename: filename to parse
|
||||||
|
in_file: opened file (or StringIO)
|
||||||
|
out_filename: output file to write to
|
||||||
|
out_file: opened file (or StringIO)
|
||||||
|
Returns:
|
||||||
|
A tuple representing number of files processed, log of actions, errors
|
||||||
|
"""
|
||||||
|
process_errors = []
|
||||||
|
text = "-" * 80 + "\n"
|
||||||
|
text += "Processing file %r\n outputting to %r\n" % (in_filename,
|
||||||
|
out_filename)
|
||||||
|
text += "-" * 80 + "\n\n"
|
||||||
|
|
||||||
|
parsed_ast = None
|
||||||
|
lines = in_file.readlines()
|
||||||
|
try:
|
||||||
|
parsed_ast = ast.parse("".join(lines))
|
||||||
|
except Exception:
|
||||||
|
text += "Failed to parse %r\n\n" % in_filename
|
||||||
|
text += traceback.format_exc()
|
||||||
|
if parsed_ast:
|
||||||
|
visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
|
||||||
|
visitor.visit(parsed_ast)
|
||||||
|
out_text, new_text, process_errors = visitor.process(lines)
|
||||||
|
text += new_text
|
||||||
|
if out_file:
|
||||||
|
out_file.write(out_text)
|
||||||
|
text += "\n"
|
||||||
|
return 1, text, process_errors
|
||||||
|
# pylint: enable=broad-except
|
||||||
|
|
||||||
|
def process_tree(self, root_directory, output_root_directory,
|
||||||
|
copy_other_files):
|
||||||
|
"""Processes upgrades on an entire tree of python files in place.
|
||||||
|
|
||||||
|
Note that only Python files. If you have custom code in other languages,
|
||||||
|
you will need to manually upgrade those.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_directory: Directory to walk and process.
|
||||||
|
output_root_directory: Directory to use as base.
|
||||||
|
copy_other_files: Copy files that are not touched by this converter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of files processed, the report string ofr all files, and errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
# make sure output directory doesn't exist
|
||||||
|
if output_root_directory and os.path.exists(output_root_directory):
|
||||||
|
print("Output directory %r must not already exist." % (
|
||||||
|
output_root_directory))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# make sure output directory does not overlap with root_directory
|
||||||
|
norm_root = os.path.split(os.path.normpath(root_directory))
|
||||||
|
norm_output = os.path.split(os.path.normpath(output_root_directory))
|
||||||
|
if norm_root == norm_output:
|
||||||
|
print("Output directory %r same as input directory %r" % (
|
||||||
|
root_directory, output_root_directory))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Collect list of files to process (we do this to correctly handle if the
|
||||||
|
# user puts the output directory in some sub directory of the input dir)
|
||||||
|
files_to_process = []
|
||||||
|
files_to_copy = []
|
||||||
|
for dir_name, _, file_list in os.walk(root_directory):
|
||||||
|
py_files = [f for f in file_list if f.endswith(".py")]
|
||||||
|
copy_files = [f for f in file_list if not f.endswith(".py")]
|
||||||
|
for filename in py_files:
|
||||||
|
fullpath = os.path.join(dir_name, filename)
|
||||||
|
fullpath_output = os.path.join(
|
||||||
|
output_root_directory, os.path.relpath(fullpath, root_directory))
|
||||||
|
files_to_process.append((fullpath, fullpath_output))
|
||||||
|
if copy_other_files:
|
||||||
|
for filename in copy_files:
|
||||||
|
fullpath = os.path.join(dir_name, filename)
|
||||||
|
fullpath_output = os.path.join(
|
||||||
|
output_root_directory, os.path.relpath(fullpath, root_directory))
|
||||||
|
files_to_copy.append((fullpath, fullpath_output))
|
||||||
|
|
||||||
|
file_count = 0
|
||||||
|
tree_errors = []
|
||||||
|
report = ""
|
||||||
|
report += ("=" * 80) + "\n"
|
||||||
|
report += "Input tree: %r\n" % root_directory
|
||||||
|
report += ("=" * 80) + "\n"
|
||||||
|
|
||||||
|
for input_path, output_path in files_to_process:
|
||||||
|
output_directory = os.path.dirname(output_path)
|
||||||
|
if not os.path.isdir(output_directory):
|
||||||
|
os.makedirs(output_directory)
|
||||||
|
file_count += 1
|
||||||
|
_, l_report, l_errors = self.process_file(input_path, output_path)
|
||||||
|
tree_errors += l_errors
|
||||||
|
report += l_report
|
||||||
|
for input_path, output_path in files_to_copy:
|
||||||
|
output_directory = os.path.dirname(output_path)
|
||||||
|
if not os.path.isdir(output_directory):
|
||||||
|
os.makedirs(output_directory)
|
||||||
|
shutil.copy(input_path, output_path)
|
||||||
|
return file_count, report, tree_errors
|
||||||
|
|
||||||
|
|
||||||
|
class TFAPIChangeSpec(APIChangeSpec):
|
||||||
"""List of maps that describe what changed in the API."""
|
"""List of maps that describe what changed in the API."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -238,7 +713,7 @@ Simple usage:
|
||||||
default="report.txt")
|
default="report.txt")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
|
upgrade = ASTCodeUpgrader(TFAPIChangeSpec())
|
||||||
report_text = None
|
report_text = None
|
||||||
report_filename = args.report_filename
|
report_filename = args.report_filename
|
||||||
files_processed = 0
|
files_processed = 0
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ import tempfile
|
||||||
import six
|
import six
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test as test_lib
|
from tensorflow.python.platform import test as test_lib
|
||||||
from tensorflow.tools.compatibility import ast_edits
|
|
||||||
from tensorflow.tools.compatibility import tf_upgrade
|
from tensorflow.tools.compatibility import tf_upgrade
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,7 +36,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||||
def _upgrade(self, old_file_text):
|
def _upgrade(self, old_file_text):
|
||||||
in_file = six.StringIO(old_file_text)
|
in_file = six.StringIO(old_file_text)
|
||||||
out_file = six.StringIO()
|
out_file = six.StringIO()
|
||||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||||
count, report, errors = (
|
count, report, errors = (
|
||||||
upgrader.process_opened_file("test.py", in_file,
|
upgrader.process_opened_file("test.py", in_file,
|
||||||
"test_out.py", out_file))
|
"test_out.py", out_file))
|
||||||
|
|
@ -140,7 +139,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||||
upgraded = "tf.multiply(a, b)\n"
|
upgraded = "tf.multiply(a, b)\n"
|
||||||
temp_file.write(original)
|
temp_file.write(original)
|
||||||
temp_file.close()
|
temp_file.close()
|
||||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||||
upgrader.process_file(temp_file.name, temp_file.name)
|
upgrader.process_file(temp_file.name, temp_file.name)
|
||||||
self.assertAllEqual(open(temp_file.name).read(), upgraded)
|
self.assertAllEqual(open(temp_file.name).read(), upgraded)
|
||||||
os.unlink(temp_file.name)
|
os.unlink(temp_file.name)
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
|
||||||
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
||||||
>>/etc/bazel.bazelrc
|
>>/etc/bazel.bazelrc
|
||||||
# Install the most recent bazel release.
|
# Install the most recent bazel release.
|
||||||
ENV BAZEL_VERSION 0.5.4
|
ENV BAZEL_VERSION 0.8.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
|
||||||
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
||||||
>>/etc/bazel.bazelrc
|
>>/etc/bazel.bazelrc
|
||||||
# Install the most recent bazel release.
|
# Install the most recent bazel release.
|
||||||
ENV BAZEL_VERSION 0.5.4
|
ENV BAZEL_VERSION 0.8.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
|
|
||||||
|
|
@ -179,10 +179,15 @@ def find_files(pattern, root):
|
||||||
|
|
||||||
|
|
||||||
matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x]
|
matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x]
|
||||||
matches += ['../' + x for x in find_files('*', '_solib_k8') if '.py' not in x]
|
|
||||||
matches += [
|
so_lib_paths = [i for i in os.listdir('.')
|
||||||
'../' + x for x in find_files('*', '_solib_local') if '.py' not in x
|
if os.path.isdir(i)
|
||||||
]
|
and fnmatch.fnmatch(i, '_solib_*')]
|
||||||
|
|
||||||
|
for path in so_lib_paths:
|
||||||
|
matches.extend(
|
||||||
|
['../' + x for x in find_files('*', path) if '.py' not in x]
|
||||||
|
)
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd'
|
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd'
|
||||||
|
|
|
||||||
|
|
@ -76,11 +76,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "mkl_dnn",
|
name = "mkl_dnn",
|
||||||
urls = [
|
urls = [
|
||||||
"https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/aab753280e83137ba955f8f19d72cb6aaba545ef.tar.gz",
|
"https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/e0bfcaa7fcb2b1e1558f5f0676933c1db807a729.tar.gz",
|
||||||
"https://github.com/01org/mkl-dnn/archive/aab753280e83137ba955f8f19d72cb6aaba545ef.tar.gz",
|
"https://github.com/01org/mkl-dnn/archive/e0bfcaa7fcb2b1e1558f5f0676933c1db807a729.tar.gz",
|
||||||
],
|
],
|
||||||
sha256 = "fb67f255a96bd4ad39b8dd104eca5aa92200c95c1ed36e59641e6c0478eefd11",
|
sha256 = "02e244f63dd95402691a361392504c143eede9a89043426f174836638a9cbf09",
|
||||||
strip_prefix = "mkl-dnn-aab753280e83137ba955f8f19d72cb6aaba545ef",
|
strip_prefix = "mkl-dnn-e0bfcaa7fcb2b1e1558f5f0676933c1db807a729",
|
||||||
build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
|
build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user