Merge changes from github.

Change: 124644444
This commit is contained in:
Vijay Vasudevan 2016-06-11 10:45:56 -08:00 committed by TensorFlower Gardener
parent f34e397622
commit 5a65d43a9e
28 changed files with 676 additions and 68 deletions

View File

@ -16,7 +16,7 @@ Installed version of CUDA and cuDNN:
If installed from binary pip package, provide:
1. Which pip package you installed.
2. The output from python -c "import tensorflow; print(tensorflow.__version__)".
2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
If installed from sources, provide the commit hash:

View File

@ -145,6 +145,12 @@ cc_library(
"include",
".",
],
defines = [
"GPR_BACKWARDS_COMPATIBILITY_MODE",
],
copts = [
"-std=c99",
],
deps = [
],
)

View File

@ -285,7 +285,7 @@
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(SRCROOT)/../../makefile/gen/proto",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../../..",
@ -300,6 +300,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
"-Xlinker",
"-S",
"-Xlinker",
"-x",
"-Xlinker",
"-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
PRODUCT_NAME = "$(TARGET_NAME)";
@ -344,7 +350,7 @@
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(SRCROOT)/../../makefile/gen/proto",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../../..",
@ -359,6 +365,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
"-Xlinker",
"-S",
"-Xlinker",
"-x",
"-Xlinker",
"-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
PRODUCT_NAME = "$(TARGET_NAME)";

View File

@ -276,7 +276,7 @@
"$(SRCROOT)/../../../..",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/gen/proto",
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
@ -289,6 +289,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
"-Xlinker",
"-S",
"-Xlinker",
"-x",
"-Xlinker",
"-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
PRODUCT_NAME = "$(TARGET_NAME)";
@ -304,7 +310,7 @@
"$(SRCROOT)/../../../..",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/gen/proto",
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
@ -314,10 +320,16 @@
"$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
"$(SRCROOT)/../../makefile/gen/lib",
);
ONLY_ACTIVE_ARCH = NO;
ONLY_ACTIVE_ARCH = YES;
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
"-Xlinker",
"-S",
"-Xlinker",
"-x",
"-Xlinker",
"-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
PRODUCT_NAME = "$(TARGET_NAME)";

View File

@ -358,7 +358,7 @@ class DataFeeder(object):
else:
if self.n_classes > 1:
if len(self.output_shape) == 2:
out.itemset((i, self.y[sample]), 1.0)
out.itemset((i, int(self.y[sample])), 1.0)
else:
for idx, value in enumerate(self.y[sample]):
out.itemset(tuple([i, idx, value]), 1.0)

View File

@ -54,22 +54,31 @@ def batch_normalize(tensor_in,
initializer=init_ops.random_normal_initializer(1., 0.02))
beta = vs.get_variable("beta", [shape[-1]],
initializer=init_ops.constant_initializer(0.))
ema = moving_averages.ExponentialMovingAverage(decay=decay)
if convnet:
assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2])
else:
assign_mean, assign_var = nn.moments(tensor_in, [0])
ema_assign_op = ema.apply([assign_mean, assign_var])
ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)
moving_mean = vs.get_variable(
'moving_mean',
shape=[shape[-1]],
initializer=init_ops.zeros_initializer,
trainable=False)
moving_var = vs.get_variable(
'moving_var',
shape=[shape[-1]],
initializer=init_ops.ones_initializer,
trainable=False)
def _update_mean_var():
"""Internal function that updates mean and variance during training."""
with ops.control_dependencies([ema_assign_op]):
return array_ops_.identity(assign_mean), array_ops_.identity(assign_var)
axis = [0, 1, 2] if convnet else [0]
mean, var = nn.moments(tensor_in, axis)
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay)
update_moving_var = moving_averages.assign_moving_average(
moving_var, var, decay)
with ops.control_dependencies([update_moving_mean, update_moving_var]):
return array_ops_.identity(mean), array_ops_.identity(var)
is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
mean, variance = control_flow_ops.cond(is_training, _update_mean_var,
lambda: (ema_mean, ema_var))
lambda: (moving_mean, moving_var))
return nn.batch_norm_with_global_normalization(
tensor_in,
mean,

View File

@ -35,8 +35,8 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/
HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/
HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/
# Which Eigen version we're using.
EIGEN_HASH := d02e6a705c30
# Find the current Eigen version name from the Bazel build file
EIGEN_HASH := $(shell cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
# Settings for the host compiler.
HOST_CXX := gcc
@ -56,9 +56,15 @@ HOST_LIBS := \
# If we're on Linux, also link in the dl library.
ifeq ($(HOST_OS),LINUX)
HOST_LIBS += -ldl
HOST_LIBS += -ldl -lpthread
endif
# If we're on a Pi, link in pthreads and dl
ifeq ($(HOST_OS),PI)
HOST_LIBS += -ldl -lpthread
endif
# proto_text is a tool that converts protobufs into a form we can use more
# compactly within TensorFlow. It's a bit like protoc, but is designed to
# produce a much more minimal result so we can save binary space.
@ -125,13 +131,13 @@ ifeq ($(TARGET),LINUX)
endif
# If we're on Linux, also link in the dl library.
ifeq ($(TARGET),LINUX)
LIBS += -ldl
LIBS += -ldl -lpthread
endif
# If we're cross-compiling for the Raspberry Pi, use the right gcc.
ifeq ($(TARGET),PI)
CXX := arm-linux-gnueabihf-g++
LDFLAGS := -L$(GENDIR)protobuf_pi/lib/ -Wl,--no-whole-archive
LIBS += -ldl
CXXFLAGS += -D__ANDROID_TYPES_SLIM__
LDFLAGS := -Wl,--no-whole-archive
LIBS += -ldl -lpthread
LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive
endif
@ -169,12 +175,16 @@ ifeq ($(TARGET),IOS)
-Wno-c++11-narrowing \
-mno-thumb \
-DTF_LEAN_BINARY \
-D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
${IPHONEOS_SYSROOT}
LDFLAGS := -arch armv7 \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-Xlinker -S \
-Xlinker -x \
-Xlinker -dead_strip \
-all_load \
-L$(GENDIR)protobuf_ios/lib \
-lz
@ -186,6 +196,7 @@ ifeq ($(TARGET),IOS)
-Wno-c++11-narrowing \
-mno-thumb \
-DTF_LEAN_BINARY \
-D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@ -205,6 +216,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
-D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@ -224,6 +236,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
-D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@ -243,6 +256,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
-D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \

View File

@ -42,7 +42,7 @@ at `tensorflow/contrib/makefile/gen/bin/benchmark`. To run the executable, use:
tensorflow/contrib/makefile/gen/bin/benchmark --graph=tensorflow_inception_graph.pb
```
You should download the example graph from [http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
You should download the example graph from [https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip).
## Supported Systems
@ -132,24 +132,31 @@ static library in a simple app.
## Raspberry Pi
The easiest way to build for the Raspberry Pi is to cross-compile from Linux.
To use this makefile to do that, you first need to install the right version of
the compiler to target the Pi, using a command like this on your Linux machine:
Building on the Raspberry Pi is similar to a normal Linux system, though we
recommend starting by compiling and installing protobuf:
```bash
sudo apt-get install g++-arm-linux-gnueabihf
cd tensorflow/contrib/makefile/downloads/protobuf/
./autogen.sh
./configure
make
sudo make install
cd ../../../../..
```
After that, run `tensorflow/contrib/makefile/compile_pi_protobuf.sh` to build a
version of the protobuf library aimed at the Pi. Then you should be able to run:
Once that's done, you can use make to build the library and example:
```bash
make -f tensorflow/contrib/makefile/Makefile TARGET=PI
make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI OPTFLAGS="-Os"
```
This will build the static library, and the example benchmark executable. You
can then copy the `tensorflow/contrib/makefile/gen/bin/benchmark` program over
to your Raspberry Pi, and run it there.
If you're only interested in building for the Raspberry Pi 2 and 3, you can supply
some extra optimization flags to give you code that will run faster:
```bash
make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI \
OPTFLAGS="-Os -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize"
```
## Dependencies

View File

@ -18,7 +18,12 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
mkdir ${DOWNLOADS_DIR}
EIGEN_HASH=d02e6a705c30
EIGEN_HASH=62a2305d5734
if [ -f eigen.BUILD ]; then
# Grab the current Eigen version name from the Bazel build file
EIGEN_HASH=$(cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
fi
curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \
-o /tmp/eigen-${EIGEN_HASH}.tar.gz
tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR}

View File

@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
// Sparse variant of the RMSProp update. Only the rows of var, ms and mom
// selected by `indices` are touched; row i of `grad` is applied to row
// indices(i) of the three variables.
// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyRMSPropOp : public OpKernel {
 public:
  explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});

    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", def().input(0)));
    OP_REQUIRES(
        ctx, ms.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", def().input(1)));
    OP_REQUIRES(
        ctx, mom.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", def().input(2)));

    const Tensor& lr = ctx->input(3);
    const Tensor& rho = ctx->input(4);
    const Tensor& momentum = ctx->input(5);
    const Tensor& epsilon = ctx->input(6);
    const Tensor& grad = ctx->input(7);
    const Tensor& indices = ctx->input(8);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        ms.shape().DebugString()));
    OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var.shape().DebugString(), " ", mom.shape().DebugString()));

    // grad carries one row per entry of `indices`, so it must NOT be required
    // to have var's full shape: only the inner dimensions have to agree. Its
    // first dimension is checked against the number of indices below.
    // The rank checks also keep the dim_size(0) calls further down valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));
    OP_REQUIRES(ctx, grad.dims() == var.dims(),
                errors::InvalidArgument("var and grad must have the same rank: ",
                                        var.shape().DebugString(), " ",
                                        grad.shape().DebugString()));
    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(
          ctx, var.dim_size(d) == grad.dim_size(d),
          errors::InvalidArgument("var and grad must match in dimension ", d));
    }

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    if (N > 0) {
      const Tindex first_dim_size = var.dim_size(0);
      // Validate all the indices are in range before any row is modified.
      auto indices_vec = indices.vec<Tindex>();
      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);
        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
                    errors::InvalidArgument(
                        strings::StrCat("Index ", index, " at offset ", i,
                                        " in indices is out of range")));
      }

      auto var_flat = var.flat_outer_dims<T>();
      auto ms_flat = ms.flat_outer_dims<T>();
      auto mom_flat = mom.flat_outer_dims<T>();
      auto grad_flat = grad.flat_outer_dims<T>();
      const T lr_scalar = lr.scalar<T>()();
      const T rho_scalar = rho.scalar<T>()();
      const T epsilon_scalar = epsilon.scalar<T>()();
      const T momentum_scalar = momentum.scalar<T>()();

      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);

        auto ms_ = ms_flat.template chip<0>(index);
        auto mom_ = mom_flat.template chip<0>(index);
        auto grad_ = grad_flat.template chip<0>(i);

        // ms <- rho * ms + (1 - rho) * grad^2
        ms_ = ms_ * ms_.constant(rho_scalar) +
              grad_.square() * grad_.constant(T(1) - rho_scalar);

        // mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
        mom_ = mom_ * mom_.constant(momentum_scalar) +
               (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
                   ms_.constant(lr_scalar) * grad_;

        // var <- var - mom
        auto v = var_flat.template chip<0>(index);
        v -= mom_;
      }
    }

    ctx->forward_ref_input_to_ref_output(0, 0);
  }

 private:
  // Value of the "use_locking" attr; passed to the mutex helper and to
  // mutable_input() for var, ms and mom.
  bool use_exclusive_lock_;
};
// Registers SparseApplyRMSPropOp as the CPU kernel for the
// "SparseApplyRMSProp" op, once per (value type T, index type Tindices)
// combination listed below. The helper macro is undefined again afterwards.
#define REGISTER_KERNELS(T, Tindices) \
REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
SparseApplyRMSPropOp<T, Tindices>);
REGISTER_KERNELS(Eigen::half, int32);
REGISTER_KERNELS(Eigen::half, int64);
REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(float, int64);
REGISTER_KERNELS(double, int32);
REGISTER_KERNELS(double, int64);
#undef REGISTER_KERNELS
} // namespace tensorflow

View File

@ -1715,7 +1715,7 @@ REGISTER_OP("ExtractImagePatches")
.Attr("T: realnumbertype")
.Attr(GetPaddingAttrString())
.Doc(R"doc(
Extract `patches` from `images` and puth them in the "depth" output dimension.
Extract `patches` from `images` and put them in the "depth" output dimension.
images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
patches: 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows *

View File

@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp")
.Attr("use_locking: bool = false")
.Doc(R"doc(
Update '*var' according to the RMSProp algorithm.
Note that in the dense implementation of this algorithm, ms and mom will
update even if the grad is zero, but in the sparse implementation, ms
and mom will not update in iterations during which the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
@ -462,4 +465,45 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
contention.
)doc");
// Op definition for the sparse RMSProp update: var, ms and mom are Ref
// inputs, the hyper-parameters and grad share type T, and `indices` selects
// the rows to update.
REGISTER_OP("SparseApplyRMSProp")
    .Input("var: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Doc(R"doc(
Update '*var' according to the RMSProp algorithm.

Note that in the dense implementation of this algorithm, ms and mom will
update even if the grad is zero, but in this sparse implementation, ms
and mom will not update in iterations during which the grad is zero.

mean_square = decay * mean_square + (1-decay) * gradient ** 2
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)

ms <- rho * ms_{t-1} + (1-rho) * grad * grad
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
var <- var - mom

var: Should be from a Variable().
ms: Should be from a Variable().
mom: Should be from a Variable().
lr: Scaling factor. Must be a scalar.
epsilon: Ridge term. Must be a scalar.
rho: Decay rate. Must be a scalar.
momentum: Factor applied to the previous value of mom. Must be a scalar.
grad: The gradient.
indices: A vector of indices into the first dimension of var, ms and mom.
out: Same as "var".
use_locking: If `True`, updating of the var, ms, and mom tensors will be
  protected by a lock; otherwise the behavior is undefined, but may exhibit
  less contention.
)doc");
} // namespace tensorflow

View File

@ -37,6 +37,13 @@ limitations under the License.
#define IS_MOBILE_PLATFORM
#endif
#elif defined(__arm__)
#define PLATFORM_POSIX
// Since there's no macro for the Raspberry Pi, assume we're on a mobile
// platform if we're compiling for the ARM CPU.
#define IS_MOBILE_PLATFORM
#else
// If no platform specified, use:
#define PLATFORM_POSIX

View File

@ -33,6 +33,7 @@ Some examples use the `pandas` library for data processing (`sudo pip install pa
## Image classification
* [Convolutional Neural Networks on MNIST Data](mnist.py)
* [Recurrent Neural Networks on MNIST Data](mnist_rnn.py)
* [Deep Residual Networks on MNIST Data](resnet.py)

View File

@ -0,0 +1,78 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example builds an RNN network for MNIST data.
Borrowed structure from here: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sklearn import metrics, preprocessing
import tensorflow as tf
from tensorflow.contrib import learn
# Parameters
learning_rate = 0.1
training_steps = 3000
batch_size = 128
# Network Parameters
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)
### Download and load MNIST data.
mnist = learn.datasets.load_dataset('mnist')
X_train = mnist.train.images
y_train = mnist.train.labels
X_test = mnist.test.images
y_test = mnist.test.labels
# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing.
# The scaler is fitted on the training split only; the test split is then
# transformed with those same statistics so both splits share one scaling.
scaler = preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
def rnn_model(X, y):
    """Model function: run an LSTM over the image rows, then classify.

    Each input is viewed as `n_steps` rows of `n_input` values, fed to a
    `BasicLSTMCell` one row per time step via `tf.nn.rnn`; the second value
    returned by `tf.nn.rnn` is passed to `learn.models.logistic_regression`
    together with `y`, and that call's result is returned.
    """
    # (batch_size, n_steps, n_input)
    sequences = tf.reshape(X, [-1, n_steps, n_input])
    # Permute n_steps and batch_size.
    time_major = tf.transpose(sequences, [1, 0, 2])
    # Reshape to prepare input to hidden activation:
    # (n_steps*batch_size, n_input)
    stacked_rows = tf.reshape(time_major, [-1, n_input])
    # Split data because the rnn cell needs a list of inputs for the RNN
    # inner loop: n_steps * (batch_size, n_input)
    step_inputs = tf.split(0, n_steps, stacked_rows)

    # Define an LSTM cell with tensorflow and run it over the sequence.
    cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
    _, encoding = tf.nn.rnn(cell, step_inputs, dtype=tf.float32)

    return learn.models.logistic_regression(encoding, y)
# Wrap rnn_model in an estimator configured with the hyper-parameters defined
# at the top of the script.
classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes,
batch_size=batch_size,
steps=training_steps,
learning_rate=learning_rate)
# Train on the training split. NOTE(review): logdir is presumably where
# training logs/summaries are written -- confirm against the estimator API.
classifier.fit(X_train, y_train, logdir="/tmp/mnist_rnn")
# Score predictions for the test split against the true labels and report.
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))

View File

@ -106,7 +106,7 @@ idempotent operation that simply divides `total` by `count`.
To facilitate the estimation of the accuracy over a stream of data, the
function utilizes two operations. First, an `is_correct` operation that
computes a tensor whose shape matches `predictions` and whose elements are
set to 1.0 when the corresponding values of `predictions` and `labels match
set to 1.0 when the corresponding values of `predictions` and `labels` match
and 0.0 otherwise. Second, an `update_op` operation whose behavior is
dependent on the value of `weights`. If `weights` is None, then `update_op`
increments `total` with the number of elements of `predictions` that match

View File

@ -62,7 +62,7 @@ Install TensorFlow:
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@ -77,7 +77,7 @@ For python3:
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@ -137,7 +137,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@ -155,7 +155,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@ -228,7 +228,7 @@ $ source activate tensorflow
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@ -245,7 +245,7 @@ $ source activate tensorflow
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@ -314,7 +314,7 @@ $ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow
The option `-p 8888:8888` is used to publish the Docker container's internal port to the host machine, in this case to ensure Jupyter notebook connection.
The format of the port mapping `hostPort:containerPort`. You can speficy any valid port number for the host port but has to be `8888` for the container port portion.
The format of the port mapping is `hostPort:containerPort`. You can specify any valid port number for the host port but have to use `8888` for the container port portion.
If you're using a container with GPU support, some additional flags must be
passed to expose the GPU device to the container. For the default config, we
@ -526,7 +526,7 @@ empty to use system default]: 7.5
Please specify the location where CUDA 7.5 toolkit is installed. Refer to
README.md for more details. [default is: /usr/local/cuda]: /usr/local/cuda
Please specify the Cudnn version you want to use. [Leave empty to use system
Please specify the cuDNN version you want to use. [Leave empty to use system
default]: 4.0.4
Please specify the location where the cuDNN 4.0.4 library is installed. Refer to
@ -549,7 +549,7 @@ Configuration finished
This creates a canonical set of symbolic links to the Cuda libraries on your system.
Every time you change the Cuda library paths you need to run this step again before
you invoke the bazel build command. For the Cudnn libraries, use '6.5' for R2, '7.0'
you invoke the bazel build command. For the cuDNN libraries, use '6.5' for R2, '7.0'
for R3, and '4.0.4' for R4-RC.
@ -672,7 +672,7 @@ GPU support will be enabled for TensorFlow
Please specify which gcc nvcc should use as the host compiler. [Default is /usr/bin/gcc]:
Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: 7.5
Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
Please specify the Cudnn version you want to use. [Leave empty to use system default]: 5
Please specify the cuDNN version you want to use. [Leave empty to use system default]: 5
Please specify the location where cuDNN 5 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.

View File

@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase):
tf.reverse(data_2d_t, dims_3d_t)
class MeshgridTest(test_util.TensorFlowTestCase):
  """Compares `array_ops.meshgrid` with `np.meshgrid`."""

  def _compare(self, n, np_dtype, use_gpu):
    """Checks that both implementations agree for `n` 5-point inputs."""
    needs_imaginary_part = np_dtype in (np.complex64, np.complex128)
    inputs = []
    for _ in range(n):
      coords = np.linspace(-10, 10, 5).astype(np_dtype)
      if needs_imaginary_part:
        coords += 1j
      inputs.append(coords)

    expected = np.meshgrid(*inputs)
    with self.test_session(use_gpu=use_gpu):
      actual = array_ops.meshgrid(*inputs)
      for want, got in zip(expected, actual):
        self.assertAllEqual(want, got.eval())

  def testCompare(self):
    dtypes_to_check = (np.float16, np.float32, np.float64, np.int32, np.int64,
                       np.complex64, np.complex128)
    for t in dtypes_to_check:
      # Don't test the one-dimensional case, as
      # old numpy versions don't support it
      for num_inputs in (2, 3, 4, 5):
        self._compare(num_inputs, t, False)

    # Test for inputs with rank not equal to 1
    x = [[1, 1], [1, 1]]
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "needs to have rank 1"):
      with self.test_session():
        X, _ = array_ops.meshgrid(x, x)
        X.eval()
if __name__ == "__main__":
googletest.main()

View File

@ -38,6 +38,7 @@ of a tensor and change the shape of a tensor.
@@reshape
@@squeeze
@@expand_dims
@@meshgrid
## Slicing and Joining
@ -1047,6 +1048,82 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
raise ValueError("Unknown padding mode: %s" % mode)
def meshgrid(*args, **kwargs):
  """Broadcasts parameters for evaluation on an N-D grid.

  Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
  of N-D coordinate arrays for evaluating expressions on an N-D grid.

  Notes:

  `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
  When the `indexing` argument is set to 'xy' (the default), the broadcasting
  instructions for the first two dimensions are swapped.

  Examples:

  Calling `X, Y = meshgrid(x, y)` with the tensors

  ```prettyprint
    x = [1, 2, 3]
    y = [4, 5, 6]
  ```

  results in

  ```prettyprint
    X = [[1, 2, 3],
         [1, 2, 3],
         [1, 2, 3]]
    Y = [[4, 4, 4],
         [5, 5, 5],
         [6, 6, 6]]
  ```

  Args:
    *args: `Tensor`s with rank 1
    indexing: Either 'xy' or 'ij' (optional, default: 'xy')
    name: A name for the operation (optional).

  Returns:
    outputs: A list of N `Tensor`s with rank N

  Raises:
    TypeError: If a keyword argument other than `indexing` or `name` is given.
    ValueError: If `indexing` is neither 'xy' nor 'ij'.
  """
  indexing = kwargs.pop("indexing", "xy")
  name = kwargs.pop("name", "meshgrid")
  if kwargs:
    key = list(kwargs.keys())[0]
    raise TypeError("'{}' is an invalid keyword argument "
                    "for this function".format(key))

  if indexing not in ("xy", "ij"):
    raise ValueError("indexing parameter must be either 'xy' or 'ij'")

  with ops.op_scope(args, name, "meshgrid") as name:
    num_inputs = len(args)
    ones = (1,) * num_inputs

    asserts = [logging_ops.Assert(
        gen_math_ops.equal(rank(x), 1),
        ["Input %d needs to have rank 1: " % i, rank(x)],
    ) for i, x in enumerate(args)]

    # axes[i] is the output dimension along which input i varies, and
    # sizes[d] is the extent of output dimension d.
    axes = list(range(num_inputs))
    sizes = [size(x) for x in args]

    # By default, the numpy version swaps the roles of the first and second
    # dimension. Both the axis assignment and the per-dimension sizes have to
    # be swapped; otherwise inputs of different lengths are tiled to the
    # wrong shape.
    if indexing == "xy" and num_inputs > 1:
      axes[0], axes[1] = axes[1], axes[0]
      sizes[0], sizes[1] = sizes[1], sizes[0]

    # Prepare reshape by inserting dimensions with size 1 where needed
    shapes = [ones[:axis] + (-1,) + ones[axis + 1:] for axis in axes]
    # Create parameters for broadcasting each tensor to the full size
    bcast = [sizes[:axis] + [1] + sizes[axis + 1:] for axis in axes]

    results = []
    with ops.control_dependencies(asserts):
      for a, r, e in zip(args, shapes, bcast):
        results.append(tile(reshape(a, r), e))
    return results
@ops.RegisterShape("Placeholder")
def _PlaceholderShape(op):
given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))

View File

@ -184,7 +184,7 @@ class _VariableStore(object):
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, we use a new
`UniformUnitScalingInitializer`. If initializer is a Tensor, we use
`uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
it as a value and derive the shape from the initializer.
If the initializer is a callable, then it will be called for each
@ -681,7 +681,7 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
`UniformUnitScalingInitializer` will be used. The initializer can also be
`uniform_unit_scaling_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.
Similarly, if the regularizer is `None` (the default), the default regularizer
@ -757,7 +757,7 @@ def _get_partitioned_variable(
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, we use a new
`UniformUnitScalingInitializer`. If initializer is a Tensor, we use
`uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
it as a value and derive the shape from the initializer.
If the initializer is a callable, then it will be called for each

View File

@ -64,6 +64,10 @@ class AdamOptimizer(optimizer.Optimizer):
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1.
Note that in the dense implementation of this algorithm, m_t, v_t and the
variable are updated even if g is zero, but in the sparse implementation,
m_t, v_t and the variable are not updated in iterations where g is zero.
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.

View File

@ -57,6 +57,10 @@ class RMSPropOptimizer(optimizer.Optimizer):
name="RMSProp"):
"""Construct a new RMSProp optimizer.
Note that in the dense implementation of this algorithm, m_t and v_t are
updated even if g is zero, but in the sparse implementation, m_t and v_t
are not updated in iterations where g is zero.
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
decay: Discounting factor for the history/coming gradient
@ -105,4 +109,14 @@ class RMSPropOptimizer(optimizer.Optimizer):
grad, use_locking=self._use_locking).op
def _apply_sparse(self, grad, var):
  """Build the RMSProp update for a sparse gradient.

  Args:
    grad: Sparse gradient; its `values` and `indices` attributes are handed
      to the sparse RMSProp training op.
    var: The variable to update. Its "rms" and "momentum" slots are looked
      up and passed to the training op together with it.

  Returns:
    The result of `training_ops.sparse_apply_rms_prop` for `var`.
  """
  # NOTE: the former unconditional `raise NotImplementedError()` is gone; it
  # made everything below unreachable.
  rms = self.get_slot(var, "rms")
  mom = self.get_slot(var, "momentum")
  # Hyperparameter tensors are cast to the variable's base dtype before
  # being passed to the op.
  var_dtype = var.dtype.base_dtype
  return training_ops.sparse_apply_rms_prop(
      var, rms, mom,
      math_ops.cast(self._learning_rate_tensor, var_dtype),
      math_ops.cast(self._decay_tensor, var_dtype),
      math_ops.cast(self._momentum_tensor, var_dtype),
      math_ops.cast(self._epsilon_tensor, var_dtype),
      grad.values,
      grad.indices,
      use_locking=self._use_locking)

View File

@ -26,6 +26,131 @@ import tensorflow as tf
class RMSPropOptimizerTest(tf.test.TestCase):
def _rmsprop_update_numpy(self, var, g, rms, mom, lr, decay, momentum,
epsilon):
rms_t = rms * decay + (1-decay) * g * g
mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
var_t = var - mom_t
return var_t, rms_t, mom_t
def testSparseWithMomentum(self):
  """Compares sparse RMSProp (momentum=0.5) with the numpy reference.

  For half and float32, four update steps are run with IndexedSlices
  gradients, and after every step the "rms" and "momentum" slots and the
  variables are compared against `_rmsprop_update_numpy`.
  """
  lr, decay, momentum, epsilon = 2.0, 0.9, 0.5, 1e-5
  for dtype in [tf.half, tf.float32]:
    with self.test_session():
      np_dtype = dtype.as_numpy_dtype
      # Numpy-side copies of the parameters and gradients.
      var0_np = np.array([1.0, 2.0], dtype=np_dtype)
      grads0_np = np.array([0.1, 0.1], dtype=np_dtype)
      var1_np = np.array([3.0, 4.0], dtype=np_dtype)
      grads1_np = np.array([0.01, 0.01], dtype=np_dtype)

      var0 = tf.Variable(var0_np)
      var1 = tf.Variable(var1_np)
      # Both gradients touch rows 0 and 1 of a dense shape [2].
      indices_np = np.array([0, 1], dtype=np.int32)
      grads0 = tf.IndexedSlices(tf.constant(grads0_np),
                                tf.constant(indices_np),
                                tf.constant([2]))
      grads1 = tf.IndexedSlices(tf.constant(grads1_np),
                                tf.constant(indices_np),
                                tf.constant([2]))

      opt = tf.train.RMSPropOptimizer(learning_rate=lr, decay=decay,
                                      momentum=momentum, epsilon=epsilon)
      update = opt.apply_gradients([(grads0, var0), (grads1, var1)])
      tf.initialize_all_variables().run()

      rms0 = opt.get_slot(var0, "rms")
      self.assertTrue(rms0 is not None)
      rms1 = opt.get_slot(var1, "rms")
      self.assertTrue(rms1 is not None)
      mom0 = opt.get_slot(var0, "momentum")
      self.assertTrue(mom0 is not None)
      mom1 = opt.get_slot(var1, "momentum")
      self.assertTrue(mom1 is not None)

      # Starting values used for the numpy-side slot state.
      rms0_np = np.array([1.0, 1.0], dtype=np_dtype)
      rms1_np = np.array([1.0, 1.0], dtype=np_dtype)
      mom0_np = np.array([0.0, 0.0], dtype=np_dtype)
      mom1_np = np.array([0.0, 0.0], dtype=np_dtype)

      # The variables must still hold their initial values.
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())

      # Four RMSProp steps, each checked against the reference.
      for _ in range(4):
        update.run()

        var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
            var0_np, grads0_np, rms0_np, mom0_np,
            lr, decay, momentum, epsilon)
        var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
            var1_np, grads1_np, rms1_np, mom1_np,
            lr, decay, momentum, epsilon)

        expected_and_actual = [
            (rms0_np, rms0), (rms1_np, rms1),
            (mom0_np, mom0), (mom1_np, mom1),
            (var0_np, var0), (var1_np, var1),
        ]
        for expected, actual in expected_and_actual:
          self.assertAllCloseAccordingToType(expected, actual.eval())
def testSparseWithoutMomentum(self):
  """Compares sparse RMSProp (momentum=0.0) with the numpy reference.

  For half and float32, four update steps are run with IndexedSlices
  gradients, and after every step the "rms" and "momentum" slots and the
  variables are compared against `_rmsprop_update_numpy`.
  """
  lr, decay, momentum, epsilon = 2.0, 0.9, 0.0, 1.0
  for dtype in [tf.half, tf.float32]:
    with self.test_session():
      np_dtype = dtype.as_numpy_dtype
      # Numpy-side copies of the parameters and gradients.
      var0_np = np.array([1.0, 2.0], dtype=np_dtype)
      grads0_np = np.array([0.1, 0.1], dtype=np_dtype)
      var1_np = np.array([3.0, 4.0], dtype=np_dtype)
      grads1_np = np.array([0.01, 0.01], dtype=np_dtype)

      var0 = tf.Variable(var0_np)
      var1 = tf.Variable(var1_np)
      # Both gradients touch rows 0 and 1 of a dense shape [2].
      indices_np = np.array([0, 1], dtype=np.int32)
      grads0 = tf.IndexedSlices(tf.constant(grads0_np),
                                tf.constant(indices_np),
                                tf.constant([2]))
      grads1 = tf.IndexedSlices(tf.constant(grads1_np),
                                tf.constant(indices_np),
                                tf.constant([2]))

      opt = tf.train.RMSPropOptimizer(learning_rate=lr, decay=decay,
                                      momentum=momentum, epsilon=epsilon)
      update = opt.apply_gradients([(grads0, var0), (grads1, var1)])
      tf.initialize_all_variables().run()

      rms0 = opt.get_slot(var0, "rms")
      self.assertTrue(rms0 is not None)
      rms1 = opt.get_slot(var1, "rms")
      self.assertTrue(rms1 is not None)
      mom0 = opt.get_slot(var0, "momentum")
      self.assertTrue(mom0 is not None)
      mom1 = opt.get_slot(var1, "momentum")
      self.assertTrue(mom1 is not None)

      # Starting values used for the numpy-side slot state.
      rms0_np = np.array([1.0, 1.0], dtype=np_dtype)
      rms1_np = np.array([1.0, 1.0], dtype=np_dtype)
      mom0_np = np.array([0.0, 0.0], dtype=np_dtype)
      mom1_np = np.array([0.0, 0.0], dtype=np_dtype)

      # The variables must still hold their initial values.
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())

      # Four RMSProp steps, each checked against the reference.
      for _ in range(4):
        update.run()

        var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
            var0_np, grads0_np, rms0_np, mom0_np,
            lr, decay, momentum, epsilon)
        var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
            var1_np, grads1_np, rms1_np, mom1_np,
            lr, decay, momentum, epsilon)

        expected_and_actual = [
            (rms0_np, rms0), (rms1_np, rms1),
            (mom0_np, mom0), (mom1_np, mom1),
            (var0_np, var0), (var1_np, var1),
        ]
        for expected, actual in expected_and_actual:
          self.assertAllCloseAccordingToType(expected, actual.eval())
def testWithoutMomentum(self):
for dtype in [tf.half, tf.float32]:
with self.test_session():

View File

@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op):
return [var_shape]
@ops.RegisterShape("SparseApplyRMSProp")
def _SparseApplyRMSPropShape(op):
  """Shape function for the SparseApplyRMSProp op.

  Merges the shapes of inputs 0 (var), 1 (ms) and 2 (mom), requires inputs
  3-6 to be scalars, and checks that input 7 (grad) matches the merged shape
  in all but its first dimension and that input 8 (indices) is a vector of
  grad's first dimension.

  Args:
    op: The SparseApplyRMSProp operation whose inputs are checked.

  Returns:
    A single-element list holding the merged var/ms/mom shape.
  """
  inputs = op.inputs
  # Fold var, ms and mom into one mutually compatible shape.
  state_shape = inputs[0].get_shape()
  for state_input in (inputs[1], inputs[2]):
    state_shape = state_input.get_shape().merge_with(state_shape)
  # Inputs 3..6 are lr, rho, momentum and epsilon, in that order.
  for scalar_index in (3, 4, 5, 6):
    _AssertInputIsScalar(op, scalar_index)
  # grad may have any leading dimension; the rest must match the state.
  grad_input_shape = inputs[7].get_shape()
  expected_grad_shape = tensor_shape.TensorShape([None]).concatenate(
      state_shape[1:])
  grad_shape = grad_input_shape.merge_with(expected_grad_shape)
  # indices is only validated; its merged shape is not used.
  inputs[8].get_shape().merge_with(tensor_shape.vector(grad_shape[0]))
  return [state_shape]
@ops.RegisterShape("SparseApplyAdadelta")
def _SparseApplyAdadeltaShape(op):
"""Shape function for the SparseApplyAdadelta op."""

View File

@ -25,6 +25,12 @@ limitations under the License.
#define EIGEN_HAS_CUDA_FP16
#endif
// SE_CUDA_DATA_HALF is the half-precision data-type constant passed to
// cublasSgemmEx in this file. It expands to CUDA_R_16F when
// CUDA_VERSION >= 8000 and to CUBLAS_DATA_HALF otherwise.
#if CUDA_VERSION >= 8000
#define SE_CUDA_DATA_HALF CUDA_R_16F
#else
#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
#endif
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
#include <dlfcn.h>
@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm(
return DoBlasInternal(
dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
CUDAMemory(a), CUBLAS_DATA_HALF, lda,
CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
CUDAMemory(a), SE_CUDA_DATA_HALF, lda,
CUDAMemory(b), SE_CUDA_DATA_HALF, ldb,
&beta,
CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
#else
LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
<< "(need at least CUDA 7.5)";

View File

@ -267,8 +267,11 @@ export class Minimap {
downloadContext.drawImage(image, 0, 0,
this.downloadCanvas.width, this.downloadCanvas.height);
};
let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
image.src = URL.createObjectURL(blob);
image.onerror = () => {
let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
image.src = URL.createObjectURL(blob);
}
image.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml);
}
/**

View File

@ -9987,8 +9987,11 @@ var tf;
downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
};
var blob = new Blob([svgXml], { type: 'image/svg+xml;charset=utf-8' });
image.src = URL.createObjectURL(blob);
image.onerror = function() {
var blob = new Blob([svgXml], {type: "image/svg+xml;charset=utf-8"});
image.src = URL.createObjectURL(blob);
};
image.src = "data:image/svg+xml;charset=utf-8," + encodeURIComponent(svgXml);
};
/**
* Handles changes in zooming/panning. Should be called from the main svg

View File

@ -14,16 +14,8 @@
# limitations under the License.
# ==============================================================================
set -e
export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
if [ ! -d ${CUDA_HOME}/lib64 ]; then
echo "Failed to locate CUDA libs at ${CUDA_HOME}/lib64."
exit 1
fi
export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | \
xargs -I{} echo '-v {}:{}')
export DEVICES=$(\ls /dev/nvidia* | \