tensorflow/tensorflow/python/ops/array_grad.py
Vijay Vasudevan 795f35da2d TensorFlow: upstream changes to git
Change:
	Clean up documentation for ReverseSequence
Change:
	Updated several TensorFlow operations to use 32-bit indices on GPU.
Change:
	Add attribute batch_dim to ReverseSequenceOp.
Change:
	Fix error in convert_to_records.py.  As reported in
	https://github.com/tensorflow/tensorflow/issues/370
	by AlexUnderMicrocontRoll.
Change:
	Update TensorBoard README.
Change:
	Fixes to boolean flags reported in
	https://github.com/tensorflow/tensorflow/issues/379.  Supports:

	--bool_flag=True  --> True
	--bool_flag=False  --> False
	--bool_flag=gibberish  --> False
	--bool_flag --> True
	--nobool_flag --> False

	Fixes #379
Change:
	Update generated Op docs.
Change:
	Enable local development of TensorBoard using gulp
	Also make tf-tensorboard a regular component rather than a special case

	This is mostly effected by creating tfserve.js, which is a small server
	with clever routing to load from bower_components/ and components/ using
	the paths that work within google3.

	Workflow: `gulp serve`
Change:
	Add a full working code example to the tensorboard and summaries tutorial
Change:
	Fix seq2seq_test when running on GPU.

	The "proj_w" and "proj_b" variables were being created before the
	`test_session()`'s device function took effect, which pushed the
	placement algorithm into making an incorrect decision.
Change:
	Add a sentence in the TensorBoard README on how to serialize summary data to logs and provide a link to the how-to tutorial on the TensorFlow website.
Change:
	Add error-catching code if string_input_producer is supplied a null input.
	Before this change, it would die with an opaque shape error from inside
	the queue.  This change catches (most) empty Python lists being
	passed directly in, and detects empty tensors at runtime.

	Adds two tests for this to input_test.py
Change:
	Speed up for models that use the same variable multiple times in the case
	where variables must be copied across devices:
	- Have Variables wrap the Variable op in an Identity op when converted to Tensor.
	  This avoids multiple copies across devices if a variable is used multiple times
	  in a computation.
	- Add Variable.mutable() to return the non-wrapped Variable op, for use when
	  assigning new values.
	- Add an as_ref parameter to convert_to_tensor() to allow callers to specify
	  whether they plan to assign a new value to the result of the conversion.  Make
	  Variable return the result of Variable.mutable() when as_ref is True.
	- Make all ops that assign values to variables pass as_ref=True when converting
	  their arguments.
Change:
	Change to reduce critical section times in gpu_event_mgr.h:
	(1) Call stream->ThenRecordEvent outside the EventMgr critical section
	(2) Do memory deallocation outside the critical section

	Speeds up one configuration of ptb_word_lm from 2924 words per
	second (wps) to 3278 wps on my desktop machine with a Titan X.
Change:
	Remove some colons that break the open source build

	::tensorflow::StringPiece breaks for @raingo, see
	https://github.com/tensorflow/tensorflow/issues/358.
	tensorflow::StringPiece (without the leading colons)
	seems to fix the problem.
Change:
	Added a check that the inputs to Operation form a list, and make a defensive copy of the inputs. This covers cases where the input list is changed, such as in _add_input.
Change:
	Use standard names for TensorFlow dtypes in the tutorial.
Change:
	Add tests for tensor inputs.
Change:
	Fix build after declaring more types for ops
Change:
	Switch to 32-bit indexing to speed up convolutions and concatenations.
Change:
	Add convert_image op to convert between types for images (similar to OpenCV's cvtScale).
Change:
	Make cast work between numeric types (bool, uint8, int16, int32, int64, float, double).
Change:

	Pad input data for an odd number of paddings, so we can use cuDNN anyway.
	+ Fix total padding computation when padding==VALID.
	+ This CL makes the GoogLeNet benchmark run 5x faster.

Change:
	Support IndexedSlices in ConcatGrad
Change:
	* sampled softmax op uses one embedding lookup for positive and negative samples
	* float64 support for sampled softmax
Change:
	Move RNN code out of models.rnn (without breaking existing code).  The API may still undergo minor changes until full documentation is added.
Change:
	Changed to use per-step stacks for the accumulators used in while-loop gradient computation. This addresses the problem caused by using concat without sufficient static shape information. It should also improve performance by avoiding those expensive concats.
Change:
	Update generated Op docs.
Change:
	Improve error messages when the optimizer finds no variables to minimize or
	when none of the variables has gradients.
Change:
	Say that -1 isn't just for flattening in reshape docs

	Also add scalar reshape (reshape(t, [])) as an example.

	This fixes https://github.com/tensorflow/tensorflow/issues/281.
Change:
	This is a test.

Base CL: 109118714
2015-12-01 13:26:53 -08:00

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op."""
  return array_ops.unpack(grad, num=op.get_attr("N"))
@ops.RegisterGradient("Unpack")
def _UnpackGrad(_, *grads):
"""Gradient for unpack op."""
return array_ops.pack(grads)
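
# Worked example: packing three [4]-vectors yields a [3, 4] tensor, so the
# incoming [3, 4] grad is unpacked along dimension 0 back into three [4]
# grads, one per input.  Unpack's gradient is the inverse operation: the
# per-output grads are packed back into a single tensor.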
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
"""Gradient for concat op."""
def _CreateDenseMaskAndBegin(sizes, concat_dim):
"""Create variables for iteratively slicing a dense gradients tensor."""
# Since shape is 1-D, shape_of_shape = [rank-of-inputs]
shape_of_shape = array_ops.shape(sizes[0])
# Make a vector of length equal to the input's dimensions,
# with 0's everywhere and 1 in the concat dim position.
# Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
mask = array_ops.concat(0,
[array_ops.fill(
array_ops.expand_dims(concat_dim, 0), 0),
[1],
array_ops.fill(
shape_of_shape - concat_dim - 1, 0)])
begin = array_ops.fill(shape_of_shape, 0)
return mask, begin
# Degenerate concatenation, just return grad.
if len(op.inputs) == 2:
return [None, grad]
concat_dim = op.inputs[0]
out_grads = []
if isinstance(grad, ops.Tensor):
# Get the inputs' tensor shapes
sizes = [array_ops.shape(x) for x in op.inputs[1:]]
mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim)
for size in sizes:
out_grads.append(array_ops.slice(grad, begin, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
elif isinstance(grad, ops.IndexedSlices):
concat_dim_static = tensor_util.ConstantValue(concat_dim)
if concat_dim_static is None:
raise ValueError("Can only compute IndexedSlices gradient with "
"statically-known concat_dim")
# Get the inputs' tensor shapes
sizes = [array_ops.shape(x) for x in op.inputs[1:]]
if concat_dim_static > 0:
# IndexedSlices, concat_dim > 0. Each input gets IndexedSlices gradients
# with all the indices, but with grad.values sliced accordingly. This
# is like the Tensor case, except shape(grad.values)[0] is not equal to
# shape(sizes[i])[0], since only a subset of the dim-0 values are stored.
mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim)
for size in sizes:
new_values = array_ops.slice(
grad.values,
begin,
array_ops.concat(0, [[-1], array_ops.slice(size, [1], [-1])]))
out_grads.append(
ops.IndexedSlices(new_values, grad.indices, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
else:
# IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
# only for the relevant indices.
start = constant_op.constant(0, dtype=grad.indices.dtype)
for size in sizes:
size_concat_dim = array_ops.gather(size, concat_dim)
if size_concat_dim.dtype != grad.indices.dtype:
size_concat_dim = math_ops.cast(size_concat_dim,
dtype=grad.indices.dtype)
end = start + size_concat_dim
# Compute the 1-D Tensor of indices relevant for this input.
indices_to_select = array_ops.squeeze(
array_ops.where(math_ops.logical_and(grad.indices >= start,
grad.indices < end)),
squeeze_dims=[1])
new_indices = array_ops.gather(grad.indices, indices_to_select) - start
new_values = array_ops.gather(grad.values, indices_to_select)
out_grads.append(
ops.IndexedSlices(new_values, new_indices, size))
start = end
else:
raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))
return [None] + out_grads
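
# Worked example for the dense (ops.Tensor) case: concatenating x1 with shape
# [2, 3] and x2 with shape [2, 5] along concat_dim = 1 produces a [2, 8]
# output, so the incoming grad has shape [2, 8].  Here mask = [0, 1] and
# begin = [0, 0], so the loop emits slice(grad, [0, 0], [2, 3]) for x1,
# advances begin to [0, 3], and emits slice(grad, [0, 3], [2, 5]) for x2.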
@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
"""Gradient for Slice op."""
# Create an Nx2 padding where the first column represents how many
# zeros are to be prepended for each dimension, and the second
# column indicates how many zeros are appended.
#
# The number of zeros to append is the shape of the input
# elementwise-subtracted by both the begin vector and sizes vector.
#
# Some more reshaping is needed to assemble this tensor with the
# right dimensions.
input_vec = op.inputs[0]
begin_vec = op.inputs[1]
input_rank = array_ops.rank(input_vec)
slice_size = array_ops.shape(op.outputs[0])
shape = array_ops.pack([input_rank, 1])
before_pad = array_ops.reshape(begin_vec, shape)
after_pad = array_ops.reshape(
array_ops.shape(input_vec) - slice_size - begin_vec, shape)
paddings = array_ops.concat(1, [before_pad, after_pad])
return array_ops.pad(grad, paddings), None, None
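
# Worked example: slicing a [5, 4] input with begin = [1, 0] and size = [3, 2]
# gives before_pad = [[1], [0]] and after_pad = [[5-3-1], [4-2-0]] = [[1], [2]],
# so paddings = [[1, 1], [0, 2]].  Padding the [3, 2] grad with those amounts
# restores a [5, 4] gradient that is zero everywhere outside the sliced region.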
@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
return None, array_ops.concat(op.inputs[0], list(grads))
ops.NoGradient("Const")
# TODO(liqzhang): The gradient for Diag operator would be
# the diagonal of the backprop. Implement if there is a need.
ops.NoGradient("Diag")
# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NoGradient("EditDistance")
ops.NoGradient("Fill")
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
return [
ops.IndexedSlices(grad, op.inputs[1], array_ops.shape(op.inputs[0])), None
]
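
# The Gather gradient is sparse: for params of shape [10, 3] and indices
# [7, 2], grad has shape [2, 3] and the result is
# IndexedSlices(values=grad, indices=[7, 2], dense_shape=[10, 3]), so only the
# gathered rows carry gradient and no dense [10, 3] tensor is materialized.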
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
return grad
@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
return grad
ops.NoGradient("StopGradient")
@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
ops.NoGradient("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  return _ReshapeToInput(op, grad)
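
# ExpandDims and Squeeze only change the shape, not the data, so their
# gradients just reshape grad back to the input's shape.  For example, for
# x of shape [2, 3], expand_dims(x, 0) has shape [1, 2, 3]; the incoming
# [1, 2, 3] grad is reshaped back to [2, 3].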
@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
"""Returns unshuffle(grad)."""
p = op.inputs[1]
return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]
ops.NoGradient("Shape")
ops.NoGradient("Rank")
ops.NoGradient("Size")
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
"""Sum reduces grad along the tiled dimensions."""
assert isinstance(grad, ops.Tensor)
return [gen_array_ops._tile_grad(grad, op.inputs[1]), None]
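
# Every tiled copy of an input element contributes to the output, so the
# gradients of all copies are summed.  For example, tiling a [3]-vector with
# multiples [2] gives a [6]-vector; the gradient for input element i is
# grad[i] + grad[i + 3].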
ops.NoGradient("TileGrad")
ops.NoGradient("BroadcastGradientArgs")
@ops.RegisterGradient("Pad")
def _PadGrad(op, grad):
"""Gradient for Pad."""
# Pad introduces values around the original tensor, so the gradient function
# slices the original shape out of the gradient."""
x = op.inputs[0]
a = op.inputs[1] # [Rank(x), 2]
# Takes a slice of a. The 1st column. [Rank(x), 1].
pad_before = array_ops.slice(a, [0, 0],
array_ops.pack([array_ops.rank(x), 1]))
# Make it a 1-D tensor.
begin = array_ops.reshape(pad_before, [-1])
sizes = array_ops.shape(x)
return array_ops.slice(grad, begin, sizes), None
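
# Worked example: for x of shape [2, 3] and paddings [[1, 1], [2, 0]], the
# padded output (and hence grad) has shape [4, 5].  Here pad_before is
# [[1], [2]], so begin = [1, 2] and sizes = [2, 3], and slice(grad, [1, 2],
# [2, 3]) recovers exactly the gradient region corresponding to x.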


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  seq_lengths = op.inputs[1]
  return [array_ops.reverse_sequence(grad,
                                     batch_dim=op.get_attr("batch_dim"),
                                     seq_dim=op.get_attr("seq_dim"),
                                     seq_lengths=seq_lengths),
          None]
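
# Because ReverseSequence only permutes elements, applying the same reversal
# (same seq_lengths, batch_dim, and seq_dim) to grad sends each gradient value
# back to the position of the input element that produced it.  For example,
# with seq_dim along columns and seq_length 3, the row [a, b, c, d] maps to
# [c, b, a, d]; reversing the grad row the same way restores the original
# ordering.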