Merge changes from github.
Change: 124644444
Commit: 5a65d43a9e (parent: f34e397622)

@@ -16,7 +16,7 @@ Installed version of CUDA and cuDNN:
 If installed from binary pip package, provide:
 
 1. Which pip package you installed.
-2. The output from python -c "import tensorflow; print(tensorflow.__version__)".
+2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
 
 If installed from sources, provide the commit hash:

@@ -145,6 +145,12 @@ cc_library(
         "include",
         ".",
     ],
+    defines = [
+        "GPR_BACKWARDS_COMPATIBILITY_MODE",
+    ],
+    copts = [
+        "-std=c99",
+    ],
     deps = [
     ],
 )

@@ -285,7 +285,7 @@
 GCC_WARN_UNUSED_VARIABLE = YES;
 HEADER_SEARCH_PATHS = (
 "$(SRCROOT)/../../makefile/gen/proto",
-"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
 "$(SRCROOT)/../../makefile/downloads",
 "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
 "$(SRCROOT)/../../../..",

@@ -300,6 +300,12 @@
 OTHER_LDFLAGS = (
 "-force_load",
 "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+"-Xlinker",
+"-S",
+"-Xlinker",
+"-x",
+"-Xlinker",
+"-dead_strip",
 );
 PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
 PRODUCT_NAME = "$(TARGET_NAME)";

@@ -344,7 +350,7 @@
 GCC_WARN_UNUSED_VARIABLE = YES;
 HEADER_SEARCH_PATHS = (
 "$(SRCROOT)/../../makefile/gen/proto",
-"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
 "$(SRCROOT)/../../makefile/downloads",
 "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
 "$(SRCROOT)/../../../..",

@@ -359,6 +365,12 @@
 OTHER_LDFLAGS = (
 "-force_load",
 "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+"-Xlinker",
+"-S",
+"-Xlinker",
+"-x",
+"-Xlinker",
+"-dead_strip",
 );
 PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
 PRODUCT_NAME = "$(TARGET_NAME)";

@@ -276,7 +276,7 @@
 "$(SRCROOT)/../../../..",
 "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
 "$(SRCROOT)/../../makefile/downloads",
-"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
 "$(SRCROOT)/../../makefile/gen/proto",
 );
 INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";

@@ -289,6 +289,12 @@
 OTHER_LDFLAGS = (
 "-force_load",
 "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+"-Xlinker",
+"-S",
+"-Xlinker",
+"-x",
+"-Xlinker",
+"-dead_strip",
 );
 PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
 PRODUCT_NAME = "$(TARGET_NAME)";

@@ -304,7 +310,7 @@
 "$(SRCROOT)/../../../..",
 "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
 "$(SRCROOT)/../../makefile/downloads",
-"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
 "$(SRCROOT)/../../makefile/gen/proto",
 );
 INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";

@@ -314,10 +320,16 @@
 "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
 "$(SRCROOT)/../../makefile/gen/lib",
 );
-ONLY_ACTIVE_ARCH = NO;
+ONLY_ACTIVE_ARCH = YES;
 OTHER_LDFLAGS = (
 "-force_load",
 "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+"-Xlinker",
+"-S",
+"-Xlinker",
+"-x",
+"-Xlinker",
+"-dead_strip",
 );
 PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
 PRODUCT_NAME = "$(TARGET_NAME)";

@@ -358,7 +358,7 @@ class DataFeeder(object):
       else:
         if self.n_classes > 1:
           if len(self.output_shape) == 2:
-            out.itemset((i, self.y[sample]), 1.0)
+            out.itemset((i, int(self.y[sample])), 1.0)
           else:
             for idx, value in enumerate(self.y[sample]):
               out.itemset(tuple([i, idx, value]), 1.0)
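Editorial aside on why that one-character change matters: `ndarray.itemset` (the older NumPy API used here) wants integer indices, and the labels in `self.y` may arrive as floats. A minimal sketch under that assumption; the `label` value is hypothetical:

```python
import numpy as np

out = np.zeros((2, 4))
label = np.float64(3.0)            # hypothetical float label, e.g. loaded from CSV
out.itemset((0, int(label)), 1.0)  # works: the index is an int
# out.itemset((0, label), 1.0)     # a float index is rejected (TypeError on modern NumPy)
```
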
@@ -54,22 +54,31 @@ def batch_normalize(tensor_in,
                            initializer=init_ops.random_normal_initializer(1., 0.02))
     beta = vs.get_variable("beta", [shape[-1]],
                            initializer=init_ops.constant_initializer(0.))
-    ema = moving_averages.ExponentialMovingAverage(decay=decay)
-    if convnet:
-      assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2])
-    else:
-      assign_mean, assign_var = nn.moments(tensor_in, [0])
-    ema_assign_op = ema.apply([assign_mean, assign_var])
-    ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)
+    moving_mean = vs.get_variable(
+        'moving_mean',
+        shape=[shape[-1]],
+        initializer=init_ops.zeros_initializer,
+        trainable=False)
+    moving_var = vs.get_variable(
+        'moving_var',
+        shape=[shape[-1]],
+        initializer=init_ops.ones_initializer,
+        trainable=False)
 
     def _update_mean_var():
       """Internal function that updates mean and variance during training."""
-      with ops.control_dependencies([ema_assign_op]):
-        return array_ops_.identity(assign_mean), array_ops_.identity(assign_var)
+      axis = [0, 1, 2] if convnet else [0]
+      mean, var = nn.moments(tensor_in, axis)
+      update_moving_mean = moving_averages.assign_moving_average(
+          moving_mean, mean, decay)
+      update_moving_var = moving_averages.assign_moving_average(
+          moving_var, var, decay)
+      with ops.control_dependencies([update_moving_mean, update_moving_var]):
+        return array_ops_.identity(mean), array_ops_.identity(var)
 
     is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
     mean, variance = control_flow_ops.cond(is_training, _update_mean_var,
-                                           lambda: (ema_mean, ema_var))
+                                           lambda: (moving_mean, moving_var))
     return nn.batch_norm_with_global_normalization(
         tensor_in,
         mean,
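For readers skimming the hunk above: the rewrite swaps `ExponentialMovingAverage` bookkeeping for explicit non-trainable `moving_mean`/`moving_var` variables updated with `assign_moving_average`. A rough numpy sketch of that update rule; the helper name here is illustrative, not TensorFlow API:

```python
import numpy as np

# moving <- decay * moving + (1 - decay) * value, which is what
# moving_averages.assign_moving_average applies to the variable.
def update_moving(moving, value, decay=0.9):
  return decay * moving + (1.0 - decay) * value

moving_mean = np.zeros(3)
for batch_mean in (np.ones(3), 3.0 * np.ones(3)):
  moving_mean = update_moving(moving_mean, batch_mean)
print(moving_mean)  # [0.39 0.39 0.39]: drifts toward recent batch statistics
```
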
@@ -35,8 +35,8 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/
 HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/
 HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/
 
-# Which Eigen version we're using.
-EIGEN_HASH := d02e6a705c30
+# Find the current Eigen version name from the Bazel build file
+EIGEN_HASH := $(shell cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
 
 # Settings for the host compiler.
 HOST_CXX := gcc
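The `$(shell ...)` pipeline above scrapes the Eigen archive name out of `eigen.BUILD`. A Python rendering of the same `grep | head | cut` logic, assuming a BUILD line like `archive_dir = "eigen-eigen-d02e6a705c30"` (the exact BUILD syntax is an assumption here):

```python
def eigen_hash(build_file_text):
  # grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\"
  for line in build_file_text.splitlines():
    if "archive_dir" in line:
      return line.split("-")[2].split('"')[0]
  return None

print(eigen_hash('archive_dir = "eigen-eigen-d02e6a705c30"'))  # d02e6a705c30
```
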
@@ -56,9 +56,15 @@ HOST_LIBS := \
 
 # If we're on Linux, also link in the dl library.
 ifeq ($(HOST_OS),LINUX)
-	HOST_LIBS += -ldl
+	HOST_LIBS += -ldl -lpthread
 endif
 
+# If we're on a Pi, link in pthreads and dl
+ifeq ($(HOST_OS),PI)
+	HOST_LIBS += -ldl -lpthread
+endif
+
 
 # proto_text is a tool that converts protobufs into a form we can use more
 # compactly within TensorFlow. It's a bit like protoc, but is designed to
 # produce a much more minimal result so we can save binary space.

@@ -125,13 +131,13 @@ ifeq ($(TARGET),LINUX)
 endif
 # If we're on Linux, also link in the dl library.
 ifeq ($(TARGET),LINUX)
-	LIBS += -ldl
+	LIBS += -ldl -lpthread
 endif
 # If we're cross-compiling for the Raspberry Pi, use the right gcc.
 ifeq ($(TARGET),PI)
-	CXX := arm-linux-gnueabihf-g++
-	LDFLAGS := -L$(GENDIR)protobuf_pi/lib/ -Wl,--no-whole-archive
-	LIBS += -ldl
+	CXXFLAGS += -D__ANDROID_TYPES_SLIM__
+	LDFLAGS := -Wl,--no-whole-archive
+	LIBS += -ldl -lpthread
+	LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive
 endif
 
@@ -169,12 +175,16 @@ ifeq ($(TARGET),IOS)
 	-Wno-c++11-narrowing \
 	-mno-thumb \
 	-DTF_LEAN_BINARY \
+	-D__ANDROID_TYPES_SLIM__ \
 	-DMIN_LOG_LEVEL=0 \
 	-fno-exceptions \
 	-isysroot \
 	${IPHONEOS_SYSROOT}
 	LDFLAGS := -arch armv7 \
 	-miphoneos-version-min=${MIN_SDK_VERSION} \
+	-Xlinker -S \
+	-Xlinker -x \
+	-Xlinker -dead_strip \
 	-all_load \
 	-L$(GENDIR)protobuf_ios/lib \
 	-lz

@@ -186,6 +196,7 @@ ifeq ($(TARGET),IOS)
 	-Wno-c++11-narrowing \
 	-mno-thumb \
 	-DTF_LEAN_BINARY \
+	-D__ANDROID_TYPES_SLIM__ \
 	-DMIN_LOG_LEVEL=0 \
 	-fno-exceptions \
 	-isysroot \

@@ -205,6 +216,7 @@ ifeq ($(TARGET),IOS)
 	-D__thread= \
 	-Wno-c++11-narrowing \
 	-DTF_LEAN_BINARY \
+	-D__ANDROID_TYPES_SLIM__ \
 	-DMIN_LOG_LEVEL=0 \
 	-fno-exceptions \
 	-isysroot \

@@ -224,6 +236,7 @@ ifeq ($(TARGET),IOS)
 	-D__thread= \
 	-Wno-c++11-narrowing \
 	-DTF_LEAN_BINARY \
+	-D__ANDROID_TYPES_SLIM__ \
 	-DMIN_LOG_LEVEL=0 \
 	-fno-exceptions \
 	-isysroot \

@@ -243,6 +256,7 @@ ifeq ($(TARGET),IOS)
 	-D__thread= \
 	-Wno-c++11-narrowing \
 	-DTF_LEAN_BINARY \
+	-D__ANDROID_TYPES_SLIM__ \
 	-DMIN_LOG_LEVEL=0 \
 	-fno-exceptions \
 	-isysroot \

@@ -42,7 +42,7 @@ at `tensorflow/contrib/makefile/gen/bin/benchmark`. To run the executable, use:
 tensorflow/contrib/makefile/gen/bin/benchmark --graph=tensorflow_inception_graph.pb
 ```
 
-You should download the example graph from [http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
+You should download the example graph from [https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip).
 
 ## Supported Systems

@@ -132,24 +132,31 @@ static library in a simple app.
 
 ## Raspberry Pi
 
-The easiest way to build for the Raspberry Pi is to cross-compile from Linux.
-To use this makefile to do that, you first need to install the right version of
-the compiler to target the Pi, using a command like this on your Linux machine:
+Building on the Raspberry Pi is similar to a normal Linux system, though we
+recommend starting by compiling and installing protobuf:
 
 ```bash
-sudo apt-get install g++-arm-linux-gnueabihf
+cd tensorflow/contrib/makefile/downloads/protobuf/
+./autogen.sh
+./configure
+make
+sudo make install
+cd ../../../../..
 ```
 
-After that, run `tensorflow/contrib/makefile/compile_pi_protobuf.sh` to build a
-version of the protobuf library aimed at the Pi. Then you should be able to run:
+Once that's done, you can use make to build the library and example:
 
 ```bash
-make -f tensorflow/contrib/makefile/Makefile TARGET=PI
+make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI OPTFLAGS="-Os"
 ```
 
-This will build the static library, and the example benchmark executable. You
-can then copy the `tensorflow/contrib/makefile/gen/bin/benchmark` program over
-to your Raspberry Pi, and run it there.
+If you're only interested in building for Raspberry Pi 2 and 3, you can supply
+some extra optimization flags to give you code that will run faster:
 
+```bash
+make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI \
+OPTFLAGS="-Os -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize"
+```
 
 ## Dependencies

@@ -18,7 +18,12 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
 
 mkdir ${DOWNLOADS_DIR}
 
-EIGEN_HASH=d02e6a705c30
+EIGEN_HASH=62a2305d5734
+if [ -f eigen.BUILD ]; then
+  # Grab the current Eigen version name from the Bazel build file
+  EIGEN_HASH=$(cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
+fi
 
 curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \
     -o /tmp/eigen-${EIGEN_HASH}.tar.gz
 tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR}

@@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double);
 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS
 
+
+// Note, this op works on cpu only.
+template <typename T, typename Tindex>
+class SparseApplyRMSPropOp : public OpKernel {
+ public:
+  explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(0)));
+    OP_REQUIRES(
+        ctx, ms.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(1)));
+    OP_REQUIRES(
+        ctx, mom.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(2)));
+
+    const Tensor& lr = ctx->input(3);
+    const Tensor& rho = ctx->input(4);
+    const Tensor& momentum = ctx->input(5);
+    const Tensor& epsilon = ctx->input(6);
+    const Tensor& grad = ctx->input(7);
+    const Tensor& indices = ctx->input(8);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+                errors::InvalidArgument("momentum is not a scalar: ",
+                                        momentum.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+
+    OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
+                errors::InvalidArgument("var and ms do not have the same shape",
+                                        var.shape().DebugString(), " ",
+                                        ms.shape().DebugString()));
+
+    OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
+                errors::InvalidArgument(
+                    "var and mom do not have the same shape",
+                    var.shape().DebugString(), " ", mom.shape().DebugString()));
+
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(grad.shape()),
+        errors::InvalidArgument("var and grad do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                grad.shape().DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+                errors::InvalidArgument("indices must be one-dimensional"));
+
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    if (N > 0) {
+      const Tindex first_dim_size = var.dim_size(0);
+      // Validate all the indices are in range
+      auto indices_vec = indices.vec<Tindex>();
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+                    errors::InvalidArgument(
+                        strings::StrCat("Index ", index, " at offset ", i,
+                                        " in indices is out of range")));
+      }
+
+      auto var_flat = var.flat_outer_dims<T>();
+      auto ms_flat = ms.flat_outer_dims<T>();
+      auto mom_flat = mom.flat_outer_dims<T>();
+      auto grad_flat = grad.flat_outer_dims<T>();
+      const T lr_scalar = lr.scalar<T>()();
+      const T rho_scalar = rho.scalar<T>()();
+      const T epsilon_scalar = epsilon.scalar<T>()();
+      const T momentum_scalar = momentum.scalar<T>()();
+
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+
+        auto ms_ = ms_flat.template chip<0>(index);
+        auto mom_ = mom_flat.template chip<0>(index);
+        auto grad_ = grad_flat.template chip<0>(i);
+
+        ms_ = ms_ * ms_.constant(rho_scalar) +
+              grad_.square() * grad_.constant(T(1) - rho_scalar);
+        mom_ = mom_ * mom_.constant(momentum_scalar) +
+               (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
+                   ms_.constant(lr_scalar) * grad_;
+
+        auto v = var_flat.template chip<0>(index);
+        v -= mom_;
+      }
+    }
+
+    ctx->forward_ref_input_to_ref_output(0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices)                                \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                 \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyRMSPropOp<T, Tindices>);
+
+REGISTER_KERNELS(Eigen::half, int32);
+REGISTER_KERNELS(Eigen::half, int64);
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+
+#undef REGISTER_KERNELS
+
 
 } // namespace tensorflow

@@ -1715,7 +1715,7 @@ REGISTER_OP("ExtractImagePatches")
     .Attr("T: realnumbertype")
     .Attr(GetPaddingAttrString())
     .Doc(R"doc(
-Extract `patches` from `images` and puth them in the "depth" output dimension.
+Extract `patches` from `images` and put them in the "depth" output dimension.
 
 images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
 patches: 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows *

@@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp")
     .Attr("use_locking: bool = false")
     .Doc(R"doc(
 Update '*var' according to the RMSProp algorithm.
+Note that in the dense implementation of this algorithm, ms and mom will
+update even if the grad is zero, but in the sparse implementation, ms
+and mom will not update in iterations during which the grad is zero.
 
 mean_square = decay * mean_square + (1-decay) * gradient ** 2
 Delta = learning_rate * gradient / sqrt(mean_square + epsilon)

@@ -461,5 +464,46 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
 by a lock; otherwise the behavior is undefined, but may exhibit less
 contention.
 )doc");
 
+REGISTER_OP("SparseApplyRMSProp")
+    .Input("var: Ref(T)")
+    .Input("ms: Ref(T)")
+    .Input("mom: Ref(T)")
+    .Input("lr: T")
+    .Input("rho: T")
+    .Input("momentum: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Input("indices: Tindices")
+    .Output("out: Ref(T)")
+    .Attr("T: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = false")
+    .Doc(R"doc(
+Update '*var' according to the RMSProp algorithm.
+Note that in the dense implementation of this algorithm, ms and mom will
+update even if the grad is zero, but in the sparse implementation, ms
+and mom will not update in iterations during which the grad is zero.
+
+mean_square = decay * mean_square + (1-decay) * gradient ** 2
+Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+
+ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+var <- var - mom
+
+var: Should be from a Variable().
+ms: Should be from a Variable().
+mom: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+rho: Decay rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var, ms and mom.
+out: Same as "var".
+use_locking: If `True`, updating of the var, m, and v tensors will be protected
+by a lock; otherwise the behavior is undefined, but may exhibit less
+contention.
+)doc");
+
 } // namespace tensorflow
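The `ms`/`mom`/`var` recurrences in the doc above are easiest to check against a few lines of numpy. This is a hedged sketch of the per-row sparse update, not the kernel itself; `rows` plays the role of the op's `indices` input:

```python
import numpy as np

def sparse_rmsprop_step(var, ms, mom, lr, rho, momentum, epsilon, grad, rows):
  # Only the rows named in `rows` are touched; every other slot keeps its
  # state, which is exactly the dense-vs-sparse difference the doc calls out.
  for g, i in zip(grad, rows):
    ms[i] = rho * ms[i] + (1.0 - rho) * g * g
    mom[i] = momentum * mom[i] + lr * g / np.sqrt(ms[i] + epsilon)
    var[i] -= mom[i]
  return var, ms, mom

var = np.array([1.0, 2.0, 3.0])
ms, mom = np.ones(3), np.zeros(3)
sparse_rmsprop_step(var, ms, mom, lr=2.0, rho=0.9, momentum=0.5,
                    epsilon=1e-5, grad=np.array([0.1]), rows=[2])
print(var)  # only row 2 moved
```
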
@@ -37,6 +37,13 @@ limitations under the License.
 #define IS_MOBILE_PLATFORM
 #endif
 
+#elif defined(__arm__)
+#define PLATFORM_POSIX
+
+// Since there's no macro for the Raspberry Pi, assume we're on a mobile
+// platform if we're compiling for the ARM CPU.
+#define IS_MOBILE_PLATFORM
+
 #else
 // If no platform specified, use:
 #define PLATFORM_POSIX

@@ -33,6 +33,7 @@ Some examples use the `pandas` library for data processing (`sudo pip install pa
 ## Image classification
 
 * [Convolutional Neural Networks on MNIST Data](mnist.py)
+* [Recurrent Neural Networks on MNIST Data](mnist_rnn.py)
 * [Deep Residual Networks on MNIST Data](resnet.py)

|
|||
78
tensorflow/examples/skflow/mnist_rnn.py
Normal file
78
tensorflow/examples/skflow/mnist_rnn.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This example builds rnn network for mnist data.
|
||||
Borrowed structure from here: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from sklearn import metrics, preprocessing
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib import learn
|
||||
|
||||
# Parameters
|
||||
learning_rate = 0.1
|
||||
training_steps = 3000
|
||||
batch_size = 128
|
||||
|
||||
# Network Parameters
|
||||
n_input = 28 # MNIST data input (img shape: 28*28)
|
||||
n_steps = 28 # timesteps
|
||||
n_hidden = 128 # hidden layer num of features
|
||||
n_classes = 10 # MNIST total classes (0-9 digits)
|
||||
|
||||
### Download and load MNIST data.
|
||||
mnist = learn.datasets.load_dataset('mnist')
|
||||
|
||||
X_train = mnist.train.images
|
||||
y_train = mnist.train.labels
|
||||
X_test = mnist.test.images
|
||||
y_test = mnist.test.labels
|
||||
|
||||
# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing
|
||||
scaler = preprocessing.StandardScaler()
|
||||
X_train = scaler.fit_transform(X_train)
|
||||
X_test = scaler.fit_transform(X_test)
|
||||
|
||||
|
||||
def rnn_model(X, y):
|
||||
X = tf.reshape(X, [-1, n_steps, n_input]) # (batch_size, n_steps, n_input)
|
||||
# # permute n_steps and batch_size
|
||||
X = tf.transpose(X, [1, 0, 2])
|
||||
# # Reshape to prepare input to hidden activation
|
||||
X = tf.reshape(X, [-1, n_input]) # (n_steps*batch_size, n_input)
|
||||
# # Split data because rnn cell needs a list of inputs for the RNN inner loop
|
||||
X = tf.split(0, n_steps, X) # n_steps * (batch_size, n_input)
|
||||
|
||||
# Define a GRU cell with tensorflow
|
||||
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
|
||||
# Get lstm cell output
|
||||
_, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32)
|
||||
|
||||
return learn.models.logistic_regression(encoding, y)
|
||||
|
||||
|
||||
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes,
|
||||
batch_size=batch_size,
|
||||
steps=training_steps,
|
||||
learning_rate=learning_rate)
|
||||
|
||||
classifier.fit(X_train, y_train, logdir="/tmp/mnist_rnn")
|
||||
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
|
||||
print('Accuracy: {0:f}'.format(score))
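A quick shape walk-through of the reshape/transpose/split sequence in `rnn_model`, done in numpy so the intermediate shapes are visible (toy batch size; the script itself uses 128):

```python
import numpy as np

batch_size, n_steps, n_input = 4, 28, 28
X = np.zeros((batch_size, n_steps * n_input))  # flat MNIST rows, as fed in
X = X.reshape(-1, n_steps, n_input)            # (4, 28, 28)
X = X.transpose(1, 0, 2)                       # (28, 4, 28): time-major
X = X.reshape(-1, n_input)                     # (112, 28)
steps = np.split(X, n_steps)                   # 28 chunks of shape (4, 28)
print(len(steps), steps[0].shape)              # 28 (4, 28)
```
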
@@ -106,7 +106,7 @@ idempotent operation that simply divides `total` by `count`.
 To facilitate the estimation of the accuracy over a stream of data, the
 function utilizes two operations. First, an `is_correct` operation that
 computes a tensor whose shape matches `predictions` and whose elements are
-set to 1.0 when the corresponding values of `predictions` and `labels match
+set to 1.0 when the corresponding values of `predictions` and `labels` match
 and 0.0 otherwise. Second, an `update_op` operation whose behavior is
 dependent on the value of `weights`. If `weights` is None, then `update_op`
 increments `total` with the number of elements of `predictions` that match

@@ -62,7 +62,7 @@ Install TensorFlow:
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7:
 $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl

@@ -77,7 +77,7 @@ For python3:
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4:
 $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl

@@ -137,7 +137,7 @@ $ source ~/tensorflow/bin/activate.csh  # If using csh
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7:
 (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl

@@ -155,7 +155,7 @@ $ source ~/tensorflow/bin/activate.csh  # If using csh
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4:
 (tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 (tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl

@@ -228,7 +228,7 @@ $ source activate tensorflow
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7:
 (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl

@@ -245,7 +245,7 @@ $ source activate tensorflow
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4:
 (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
 # For other versions, see "Install from sources" below.
 (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl

@@ -314,7 +314,7 @@ $ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow
 
 The option `-p 8888:8888` is used to publish the Docker container's internal port to the host machine, in this case to ensure Jupyter notebook connection.
 
-The format of the port mapping `hostPort:containerPort`. You can speficy any valid port number for the host port but has to be `8888` for the container port portion.
+The format of the port mapping is `hostPort:containerPort`. You can specify any valid port number for the host port but have to use `8888` for the container port portion.
 
 If you're using a container with GPU support, some additional flags must be
 passed to expose the GPU device to the container. For the default config, we

@@ -526,7 +526,7 @@ empty to use system default]: 7.5
 Please specify the location where CUDA 7.5 toolkit is installed. Refer to
 README.md for more details. [default is: /usr/local/cuda]: /usr/local/cuda
 
-Please specify the Cudnn version you want to use. [Leave empty to use system
+Please specify the cuDNN version you want to use. [Leave empty to use system
 default]: 4.0.4
 
 Please specify the location where the cuDNN 4.0.4 library is installed. Refer to

@@ -549,7 +549,7 @@ Configuration finished
 
 This creates a canonical set of symbolic links to the Cuda libraries on your system.
 Every time you change the Cuda library paths you need to run this step again before
-you invoke the bazel build command. For the Cudnn libraries, use '6.5' for R2, '7.0'
+you invoke the bazel build command. For the cuDNN libraries, use '6.5' for R2, '7.0'
 for R3, and '4.0.4' for R4-RC.

@@ -672,7 +672,7 @@ GPU support will be enabled for TensorFlow
 Please specify which gcc nvcc should use as the host compiler. [Default is /usr/bin/gcc]:
 Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: 7.5
 Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-Please specify the Cudnn version you want to use. [Leave empty to use system default]: 5
+Please specify the cuDNN version you want to use. [Leave empty to use system default]: 5
 Please specify the location where cuDNN 5 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
 Please specify a list of comma-separated Cuda compute capabilities you want to build with.
 You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.

@@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase):
       tf.reverse(data_2d_t, dims_3d_t)
 
 
+class MeshgridTest(test_util.TensorFlowTestCase):
+
+  def _compare(self, n, np_dtype, use_gpu):
+    inputs = []
+    for i in range(n):
+      x = np.linspace(-10, 10, 5).astype(np_dtype)
+      if np_dtype in (np.complex64, np.complex128):
+        x += 1j
+      inputs.append(x)
+
+    numpy_out = np.meshgrid(*inputs)
+    with self.test_session(use_gpu=use_gpu):
+      tf_out = array_ops.meshgrid(*inputs)
+      for X, _X in zip(numpy_out, tf_out):
+        self.assertAllEqual(X, _X.eval())
+
+  def testCompare(self):
+    for t in (np.float16, np.float32, np.float64, np.int32, np.int64,
+              np.complex64, np.complex128):
+      # Don't test the one-dimensional case, as
+      # old numpy versions don't support it
+      self._compare(2, t, False)
+      self._compare(3, t, False)
+      self._compare(4, t, False)
+      self._compare(5, t, False)
+
+    # Test for inputs with rank not equal to 1
+    x = [[1, 1], [1, 1]]
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 "needs to have rank 1"):
+      with self.test_session():
+        X, _ = array_ops.meshgrid(x, x)
+        X.eval()
+
+
 if __name__ == "__main__":
   googletest.main()

@@ -38,6 +38,7 @@ of a tensor and change the shape of a tensor.
 @@reshape
 @@squeeze
 @@expand_dims
+@@meshgrid
 
 ## Slicing and Joining
 

@@ -125,7 +126,7 @@ def shape(input, name=None):
   else:
     return gen_array_ops.shape(input, name=name)
 
 
 def rank(input, name=None):
   """Returns the rank of a tensor.
 
@@ -1047,6 +1048,82 @@ def pad(tensor, paddings, mode="CONSTANT", name=None):  # pylint: disable=invali
   raise ValueError("Unknown padding mode: %s" % mode)
 
 
+def meshgrid(*args, **kwargs):
+  """Broadcasts parameters for evaluation on an N-D grid.
+
+  Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
+  of N-D coordinate arrays for evaluating expressions on an N-D grid.
+
+  Notes:
+
+  `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
+  When the `indexing` argument is set to 'xy' (the default), the broadcasting
+  instructions for the first two dimensions are swapped.
+
+  Examples:
+
+  Calling `X, Y = meshgrid(x, y)` with the tensors
+  ```prettyprint
+  x = [1, 2, 3]
+  y = [4, 5, 6]
+  ```
+  results in
+  ```prettyprint
+  X = [[1, 1, 1],
+       [2, 2, 2],
+       [3, 3, 3]]
+  Y = [[4, 5, 6],
+       [4, 5, 6],
+       [4, 5, 6]]
+  ```
+
+  Args:
+    *args: `Tensor`s with rank 1
+    indexing: Either 'xy' or 'ij' (optional, default: 'xy')
+    name: A name for the operation (optional).
+
+  Returns:
+    outputs: A list of N `Tensor`s with rank N
+  """
+  indexing = kwargs.pop("indexing", "xy")
+  name = kwargs.pop("name", "meshgrid")
+  if len(kwargs) > 0:
+    key = list(kwargs.keys())[0]
+    raise TypeError("'{}' is an invalid keyword argument "
+                    "for this function".format(key))
+
+  if indexing not in ("xy", "ij"):
+    raise ValueError("indexing parameter must be either 'xy' or 'ij'")
+
+  with ops.op_scope(args, name, "meshgrid") as name:
+    num_inputs = len(args)
+    ones = (1,) * num_inputs
+
+    asserts = [logging_ops.Assert(
+        gen_math_ops.equal(rank(x), 1),
+        ["Input %d needs to have rank 1: " % i, rank(x)],
+    ) for i, x in enumerate(args)]
+
+    # Prepare reshape by inserting dimensions with size 1 where needed
+    shapes = [ones[:i] + (-1,) + ones[i + 1:] for i in range(num_inputs)]
+    # Create parameters for broadcasting each tensor to the full size
+    sizes = [size(x) for x in args]
+    bcast = [sizes[:i] + [1] + sizes[i + 1:] for i in range(num_inputs)]
+
+    # By default, the numpy version swaps the instructions
+    # for the first and second dimension
+    if indexing == "xy" and num_inputs > 1:
+      shapes[0], shapes[1] = shapes[1], shapes[0]
+      bcast[0], bcast[1] = bcast[1], bcast[0]
+
+    results = []
+    with ops.control_dependencies(asserts):
+      for a, r, e in zip(args, shapes, bcast):
+        results.append(tile(reshape(a, r), e))
+
+    return results
+
+
 @ops.RegisterShape("Placeholder")
 def _PlaceholderShape(op):
   given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
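The reshape-then-tile trick the new `meshgrid` uses is the same one numpy users reach for by hand. A pure-numpy sketch for two inputs in 'ij' indexing, for intuition only:

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
# Put each input's length on its own axis, then tile to the full grid.
X = np.tile(x.reshape(-1, 1), (1, y.size))  # X[i, j] == x[i]
Y = np.tile(y.reshape(1, -1), (x.size, 1))  # Y[i, j] == y[j]
assert (X == np.meshgrid(x, y, indexing="ij")[0]).all()
assert (Y == np.meshgrid(x, y, indexing="ij")[1]).all()
```
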
@@ -184,7 +184,7 @@ class _VariableStore(object):
 
 If initializer is `None` (the default), the default initializer passed in
 the constructor is used. If that one is `None` too, we use a new
-`UniformUnitScalingInitializer`. If initializer is a Tensor, we use
+`uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
 it as a value and derive the shape from the initializer.
 
 If the initializer is a callable, then it will be called for each

@@ -681,7 +681,7 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
 
 If initializer is `None` (the default), the default initializer passed in
 the variable scope will be used. If that one is `None` too, a
-`UniformUnitScalingInitializer` will be used. The initializer can also be
+`uniform_unit_scaling_initializer` will be used. The initializer can also be
 a Tensor, in which case the variable is initialized to this value and shape.
 
 Similarly, if the regularizer is `None` (the default), the default regularizer

@@ -757,7 +757,7 @@ def _get_partitioned_variable(
 
 If initializer is `None` (the default), the default initializer passed in
 the constructor is used. If that one is `None` too, we use a new
-`UniformUnitScalingInitializer`. If initializer is a Tensor, we use
+`uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
 it as a value and derive the shape from the initializer.
 
 If the initializer is a callable, then it will be called for each

@@ -64,6 +64,10 @@ class AdamOptimizer(optimizer.Optimizer):
 general. For example, when training an Inception network on ImageNet a
 current good choice is 1.0 or 0.1.
 
+Note that in the dense implementation of this algorithm, m_t, v_t and variable
+will update even if g is zero, but in the sparse implementation, m_t, v_t and
+variable will not update in iterations during which g is zero.
+
 Args:
   learning_rate: A Tensor or a floating point value. The learning rate.
   beta1: A float value or a constant float tensor.

@@ -57,6 +57,10 @@ class RMSPropOptimizer(optimizer.Optimizer):
                name="RMSProp"):
     """Construct a new RMSProp optimizer.
 
+    Note that in the dense implementation of this algorithm, m_t and v_t will
+    update even if g is zero, but in the sparse implementation, m_t and v_t
+    will not update in iterations during which g is zero.
+
     Args:
       learning_rate: A Tensor or a floating point value. The learning rate.
       decay: Discounting factor for the history/coming gradient

@@ -105,4 +109,14 @@ class RMSPropOptimizer(optimizer.Optimizer):
         grad, use_locking=self._use_locking).op
 
   def _apply_sparse(self, grad, var):
-    raise NotImplementedError()
+    rms = self.get_slot(var, "rms")
+    mom = self.get_slot(var, "momentum")
+    return training_ops.sparse_apply_rms_prop(
+        var, rms, mom,
+        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+        math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
+        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
+        math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
+        grad.values,
+        grad.indices,
+        use_locking=self._use_locking)

@@ -26,6 +26,131 @@ import tensorflow as tf
 
 class RMSPropOptimizerTest(tf.test.TestCase):
 
+  def _rmsprop_update_numpy(self, var, g, rms, mom, lr, decay, momentum,
+                            epsilon):
+    rms_t = rms * decay + (1-decay) * g * g
+    mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
+    var_t = var - mom_t
+    return var_t, rms_t, mom_t
+
+  def testSparseWithMomentum(self):
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = tf.Variable(var0_np)
+        var1 = tf.Variable(var1_np)
+        grads0_np_indices = np.array([0, 1], dtype=np.int32)
+        grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+                                  tf.constant(grads0_np_indices),
+                                  tf.constant([2]))
+        grads1_np_indices = np.array([0, 1], dtype=np.int32)
+        grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+                                  tf.constant(grads1_np_indices),
+                                  tf.constant([2]))
+        opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+                                        momentum=0.5, epsilon=1e-5)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        tf.initialize_all_variables().run()
+
+        rms0 = opt.get_slot(var0, "rms")
+        self.assertTrue(rms0 is not None)
+        rms1 = opt.get_slot(var1, "rms")
+        self.assertTrue(rms1 is not None)
+        mom0 = opt.get_slot(var0, "momentum")
+        self.assertTrue(mom0 is not None)
+        mom1 = opt.get_slot(var1, "momentum")
+        self.assertTrue(mom1 is not None)
+
+        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        # Run 4 steps of RMSProp
+        for t in range(1, 5):
+          update.run()
+
+          var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+              grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.5, 1e-5)
+          var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+              grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.5, 1e-5)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+  def testSparseWithoutMomentum(self):
+    for dtype in [tf.half, tf.float32]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = tf.Variable(var0_np)
+        var1 = tf.Variable(var1_np)
+        grads0_np_indices = np.array([0, 1], dtype=np.int32)
+        grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+                                  tf.constant(grads0_np_indices),
+                                  tf.constant([2]))
+        grads1_np_indices = np.array([0, 1], dtype=np.int32)
+        grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+                                  tf.constant(grads1_np_indices),
+                                  tf.constant([2]))
+        opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+                                        momentum=0.0, epsilon=1.0)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        tf.initialize_all_variables().run()
+
+        rms0 = opt.get_slot(var0, "rms")
+        self.assertTrue(rms0 is not None)
+        rms1 = opt.get_slot(var1, "rms")
+        self.assertTrue(rms1 is not None)
+        mom0 = opt.get_slot(var0, "momentum")
+        self.assertTrue(mom0 is not None)
+        mom1 = opt.get_slot(var1, "momentum")
+        self.assertTrue(mom1 is not None)
+
+        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        # Run 4 steps of RMSProp
+        for t in range(1, 5):
+          update.run()
+
+          var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+              grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.0, 1.0)
+          var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+              grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.0, 1.0)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
   def testWithoutMomentum(self):
     for dtype in [tf.half, tf.float32]:
       with self.test_session():

@@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op):
   return [var_shape]
 
 
+@ops.RegisterShape("SparseApplyRMSProp")
+def _SparseApplyRMSPropShape(op):
+  """Shape function for the SparseApplyRMSProp op."""
+  var_shape = op.inputs[0].get_shape()
+  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+  _AssertInputIsScalar(op, 3)  # lr
+  _AssertInputIsScalar(op, 4)  # rho
+  _AssertInputIsScalar(op, 5)  # momentum
+  _AssertInputIsScalar(op, 6)  # epsilon
+  grad_shape = op.inputs[7].get_shape().merge_with(
+      tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
+  unused_indices_shape = op.inputs[8].get_shape().merge_with(
+      tensor_shape.vector(grad_shape[0]))
+  return [mom_shape]
+
+
 @ops.RegisterShape("SparseApplyAdadelta")
 def _SparseApplyAdadeltaShape(op):
   """Shape function for the SparseApplyAdadelta op."""

@@ -25,6 +25,12 @@ limitations under the License.
 #define EIGEN_HAS_CUDA_FP16
 #endif
 
+#if CUDA_VERSION >= 8000
+#define SE_CUDA_DATA_HALF CUDA_R_16F
+#else
+#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
+#endif
+
 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
 
 #include <dlfcn.h>

@@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm(
     return DoBlasInternal(
         dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
         CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-        CUDAMemory(a), CUBLAS_DATA_HALF, lda,
-        CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
+        CUDAMemory(a), SE_CUDA_DATA_HALF, lda,
+        CUDAMemory(b), SE_CUDA_DATA_HALF, ldb,
         &beta,
-        CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
+        CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
 #else
     LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
                << "(need at least CUDA 7.5)";

@@ -267,8 +267,11 @@ export class Minimap {
       downloadContext.drawImage(image, 0, 0,
         this.downloadCanvas.width, this.downloadCanvas.height);
     };
-    let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
-    image.src = URL.createObjectURL(blob);
+    image.onerror = () => {
+      let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
+      image.src = URL.createObjectURL(blob);
+    };
+    image.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml);
   }
 
   /**

@@ -9987,8 +9987,11 @@ var tf;
             downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
             downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
         };
-        var blob = new Blob([svgXml], { type: 'image/svg+xml;charset=utf-8' });
-        image.src = URL.createObjectURL(blob);
+        image.onerror = function() {
+            var blob = new Blob([svgXml], {type: "image/svg+xml;charset=utf-8"});
+            image.src = URL.createObjectURL(blob);
+        };
+        image.src = "data:image/svg+xml;charset=utf-8," + encodeURIComponent(svgXml);
     };
     /**
     * Handles changes in zooming/panning. Should be called from the main svg

@@ -14,16 +14,8 @@
 # limitations under the License.
 # ==============================================================================
 
-
-set -e
-
-export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
-
-if [ ! -d ${CUDA_HOME}/lib64 ]; then
-  echo "Failed to locate CUDA libs at ${CUDA_HOME}/lib64."
-  exit 1
-fi
-
 export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | \
     xargs -I{} echo '-v {}:{}')
 export DEVICES=$(\ls /dev/nvidia* | \
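For reference, the `CUDA_SO` line kept above expands every host `libcuda.*` file into a docker `-v path:path` bind-mount flag. A rough Python equivalent; the library glob is the script's own, the rest is illustrative:

```python
import glob

# Turn each host libcuda.* path into a "-v path:path" docker flag.
cuda_so = " ".join("-v {0}:{0}".format(p)
                   for p in glob.glob("/usr/lib/x86_64-linux-gnu/libcuda.*"))
print(cuda_so)  # e.g. "-v /usr/lib/.../libcuda.so:/usr/lib/.../libcuda.so ..."
```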