mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[TF:XLA] Initial implementation of TensorArray ops.
The XLA implementation of TensorArrays is more restrictive than regular TensorArrays: * XLA TensorArrays must have dynamic_size=False. * all elements in an XLA TensorArray must have the same shape. * writes always add their values to any existing values; neither reads nor writes ever issue errors. Out-of-bounds writes currently wrap. Refactor Variable handling in the TF/XLA bridge. Use a XlaVariable* to refer to variables inside compilation rather than a numerical ID. Allow for variables that don't correspond to variables known to the user. Also use XlaVariable to handle TensorArrays. PiperOrigin-RevId: 158322041
This commit is contained in:
parent
b5e8d30865
commit
c19e6cac04
|
|
@ -346,6 +346,25 @@ tf_xla_py_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "tensor_array_ops_test",
|
||||
size = "small",
|
||||
srcs = ["tensor_array_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_array_grad",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "ternary_ops_test",
|
||||
size = "small",
|
||||
|
|
|
|||
1018
tensorflow/compiler/tests/tensor_array_ops_test.py
Normal file
1018
tensorflow/compiler/tests/tensor_array_ops_test.py
Normal file
File diff suppressed because it is too large
Load Diff
|
|
@ -54,16 +54,20 @@ class XLATestCase(test.TestCase):
|
|||
self.device = FLAGS.test_device
|
||||
self.has_custom_call = (self.device == 'XLA_CPU')
|
||||
self.all_tf_types = [
|
||||
dtypes.DType(types_pb2.DataType.Value(name))
|
||||
dtypes.as_dtype(types_pb2.DataType.Value(name))
|
||||
for name in FLAGS.types.split(',')
|
||||
]
|
||||
self.int_tf_types = [
|
||||
dtype for dtype in self.all_tf_types if dtype.is_integer
|
||||
]
|
||||
self.float_tf_types = [
|
||||
dtype for dtype in self.all_tf_types if dtype.is_floating
|
||||
]
|
||||
self.numeric_tf_types = self.int_tf_types + self.float_tf_types
|
||||
|
||||
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
|
||||
self.int_types = [
|
||||
dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
|
||||
]
|
||||
self.float_types = [
|
||||
dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
|
||||
]
|
||||
self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
|
||||
self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
|
||||
self.numeric_types = self.int_types + self.float_types
|
||||
|
||||
# Parse the manifest file, if any, into a regex identifying tests to
|
||||
|
|
|
|||
|
|
@ -89,6 +89,8 @@ Status BackwardsConstAnalysis(const Graph& g,
|
|||
{"StridedSliceGrad", "end"},
|
||||
{"StridedSliceGrad", "strides"},
|
||||
{"Sum", "reduction_indices"},
|
||||
{"TensorArrayV3", "size"},
|
||||
{"TensorArraySplitV3", "lengths"},
|
||||
{"Tile", "multiples"},
|
||||
{"Transpose", "perm"}};
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ tf_kernel_library(
|
|||
"spacetobatch_op.cc",
|
||||
"split_op.cc",
|
||||
"strided_slice_op.cc",
|
||||
"tensor_array_ops.cc",
|
||||
"tile_ops.cc",
|
||||
"training_ops.cc",
|
||||
"transpose_op.cc",
|
||||
|
|
|
|||
|
|
@ -49,14 +49,15 @@ class ArgOp : public XlaOpKernel {
|
|||
return;
|
||||
}
|
||||
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
const XlaContext::Argument& arg = tc.args()[index_];
|
||||
XlaContext& xc = XlaContext::Get(ctx);
|
||||
const XlaContext::Argument& arg = xc.args()[index_];
|
||||
if (arg.is_variable) {
|
||||
// We use the argument position of the variable input as a unique ID.
|
||||
// TODO(phawkins): this code assumes that variables do not alias.
|
||||
OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type,
|
||||
arg.value.handle));
|
||||
ctx->SetVariableOutput(0, index_);
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
|
||||
arg.value.handle, &var));
|
||||
var->tensor_array_size = arg.tensor_array_size;
|
||||
ctx->SetVariableOutput(0, var);
|
||||
} else if (arg.value.is_constant) {
|
||||
ctx->SetConstantOutput(0, arg.value.constant_value);
|
||||
} else {
|
||||
|
|
|
|||
538
tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
Normal file
538
tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// XLA TensorArray operators.
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Since the element shape is not always provided to the TensorArrayV3 operator,
|
||||
// we must support lazily initialization of the TensorArray at the time of the
|
||||
// first write.
|
||||
// If a TensorArray `var` has not been initialized, constructs storage for the
|
||||
// TensorArray with elements of `elem_shape`. For both initialized and
|
||||
// uninitialized TensorArrays, checks that the tensor has a type compatible with
|
||||
// 'dtype' and shape compatible with 'elem_shape'.
|
||||
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
|
||||
XlaVariable* var, DataType dtype,
|
||||
const TensorShape& elem_shape) {
|
||||
if (var->type != dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"TensorArray dtype is ", DataTypeString(var->type),
|
||||
" but op has dtype ", DataTypeString(dtype), ".");
|
||||
}
|
||||
|
||||
TF_RET_CHECK(var->tensor_array_size >= 0)
|
||||
<< var->name << " size " << var->tensor_array_size;
|
||||
TensorShape ta_shape;
|
||||
ta_shape.AddDim(var->tensor_array_size);
|
||||
ta_shape.AppendShape(elem_shape);
|
||||
|
||||
if (var->value.handle() == 0) {
|
||||
// TensorArray has not been initialized.
|
||||
xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
|
||||
var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
|
||||
} else {
|
||||
// Checks the elem_shape matches the TensorArray shape.
|
||||
auto shape_or_status = builder->GetShape(var->value);
|
||||
if (!shape_or_status.ok()) {
|
||||
return shape_or_status.status();
|
||||
}
|
||||
TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
|
||||
if (ta_shape != shape) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
|
||||
shape.DebugString());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
|
||||
xla::ComputationDataHandle PadIndexWithZeros(
|
||||
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
|
||||
int count) {
|
||||
xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
|
||||
std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
|
||||
xs[0] = builder->Reshape(x, {1});
|
||||
return builder->ConcatInDim(xs, 0);
|
||||
}
|
||||
|
||||
// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
|
||||
// relevant slice of 'operand'.
|
||||
xla::ComputationDataHandle DynamicAddSlice(
|
||||
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
|
||||
const xla::ComputationDataHandle& update,
|
||||
const gtl::ArraySlice<int64>& update_dims,
|
||||
const xla::ComputationDataHandle& start_indices) {
|
||||
xla::ComputationDataHandle current =
|
||||
builder->DynamicSlice(operand, start_indices, update_dims);
|
||||
xla::ComputationDataHandle sum = builder->Add(current, update);
|
||||
return builder->DynamicUpdateSlice(operand, sum, start_indices);
|
||||
}
|
||||
|
||||
class TensorArrayOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
bool dynamic_size;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size));
|
||||
OP_REQUIRES(
|
||||
ctx, !dynamic_size,
|
||||
errors::Unimplemented(
|
||||
"TensorArrays with dynamic size are not supported by XLA."));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
int64 size;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
|
||||
OP_REQUIRES(ctx, size >= 0,
|
||||
errors::InvalidArgument("TensorArray size must be >= 0"));
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
b->set_die_immediately_on_error(true);
|
||||
|
||||
// Initializes the TensorArray value if we know the element shape.
|
||||
// Otherwise, defer initialization to the first write.
|
||||
xla::ComputationDataHandle value;
|
||||
if (element_shape_.IsFullyDefined()) {
|
||||
TensorShape shape;
|
||||
CHECK(element_shape_.AsTensorShape(&shape));
|
||||
TensorShape ta_shape;
|
||||
ta_shape.AddDim(size);
|
||||
ta_shape.AppendShape(shape);
|
||||
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
|
||||
value = b->Broadcast(zero, ta_shape.dim_sizes());
|
||||
}
|
||||
|
||||
XlaContext& xc = XlaContext::Get(ctx);
|
||||
XlaVariable* var;
|
||||
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
|
||||
OP_REQUIRES_OK(ctx,
|
||||
xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
|
||||
var->tensor_array_size = size;
|
||||
ctx->SetVariableOutput(0, var);
|
||||
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
|
||||
}
|
||||
|
||||
private:
|
||||
PartialTensorShape element_shape_;
|
||||
DataType dtype_;
|
||||
string tensor_array_name_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
|
||||
|
||||
class TensorArrayWriteOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
TensorShape elem_shape = ctx->InputShape(2);
|
||||
|
||||
// Initializes the TensorArray, if the element shape was not known at
|
||||
// construction time.
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
|
||||
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
|
||||
|
||||
xla::ComputationDataHandle ta = var->value;
|
||||
xla::ComputationDataHandle index = ctx->Input(1);
|
||||
xla::ComputationDataHandle value = ctx->Input(2);
|
||||
|
||||
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
|
||||
auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
|
||||
|
||||
TensorShape slice_shape = elem_shape;
|
||||
slice_shape.InsertDim(0, 1LL);
|
||||
auto update = b->Reshape(value, slice_shape.dim_sizes());
|
||||
|
||||
xla::ComputationDataHandle written =
|
||||
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
|
||||
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp);
|
||||
|
||||
class TensorArrayReadOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType ta_type;
|
||||
TensorShape ta_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
|
||||
OP_REQUIRES(ctx, ta_type == dtype_,
|
||||
errors::InvalidArgument(
|
||||
"TensorArray dtype is ", DataTypeString(ta_type),
|
||||
" but Op requested dtype ", DataTypeString(dtype_), "."));
|
||||
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
|
||||
errors::InvalidArgument("TensorArray rank must be >= 1"));
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
xla::ComputationDataHandle ta;
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
|
||||
xla::ComputationDataHandle index = ctx->Input(1);
|
||||
|
||||
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
|
||||
auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
|
||||
|
||||
auto slice_shape = ta_shape.dim_sizes();
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
xla::ComputationDataHandle read =
|
||||
b->DynamicSlice(ta, start_indices, slice_shape);
|
||||
|
||||
// Remove the leading '1' dimension.
|
||||
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
|
||||
ctx->SetOutput(0, b->Reshape(read, value_shape));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp);
|
||||
|
||||
class TensorArrayGatherOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType ta_type;
|
||||
TensorShape ta_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
|
||||
OP_REQUIRES(ctx, ta_type == dtype_,
|
||||
errors::InvalidArgument("TensorArray type mismatch"));
|
||||
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
|
||||
errors::InvalidArgument("TensorArray rank must be >= 1"));
|
||||
|
||||
const TensorShape indices_shape = ctx->InputShape(1);
|
||||
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
|
||||
errors::InvalidArgument("indices must be rank 1"));
|
||||
const int num_indices = indices_shape.dim_size(0);
|
||||
auto indices = ctx->Input(1);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
xla::ComputationDataHandle ta;
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
|
||||
|
||||
// For each index in `indices`, add the corresponding slice to `slices`.
|
||||
std::vector<xla::ComputationDataHandle> slices(num_indices);
|
||||
for (int i = 0; i < num_indices; ++i) {
|
||||
// Slices the i-th index out of `indices`, and pads it with zeros in the
|
||||
// minor dimensions to form an index into the TensorArray storage.
|
||||
auto index = b->Slice(indices, {i}, {i + 1});
|
||||
|
||||
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
|
||||
auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
|
||||
|
||||
auto slice_shape = ta_shape.dim_sizes();
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
slices[i] = b->DynamicSlice(ta, start_indices, slice_shape);
|
||||
}
|
||||
|
||||
xla::ComputationDataHandle gather;
|
||||
if (slices.empty()) {
|
||||
auto shape = ta_shape.dim_sizes();
|
||||
shape[0] = 0;
|
||||
gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape);
|
||||
} else {
|
||||
gather = b->ConcatInDim(slices, 0);
|
||||
}
|
||||
ctx->SetOutput(0, gather);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp);
|
||||
|
||||
class TensorArrayScatterOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
const TensorShape value_shape = ctx->InputShape(2);
|
||||
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
|
||||
TensorShape elem_shape = value_shape;
|
||||
elem_shape.RemoveDim(0);
|
||||
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
|
||||
|
||||
const TensorShape indices_shape = ctx->InputShape(1);
|
||||
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
|
||||
errors::InvalidArgument("indices must be rank 1"));
|
||||
const int num_indices = indices_shape.dim_size(0);
|
||||
const xla::ComputationDataHandle indices = ctx->Input(1);
|
||||
|
||||
xla::ComputationDataHandle ta = var->value;
|
||||
const xla::ComputationDataHandle value = ctx->Input(2);
|
||||
|
||||
auto slice_dims = value_shape.dim_sizes();
|
||||
slice_dims[0] = 1LL;
|
||||
|
||||
std::vector<int64> value_starts(value_shape.dims(), 0);
|
||||
auto value_ends = value_shape.dim_sizes();
|
||||
|
||||
// For every (index, value) pair, update the corresponding TensorArray
|
||||
// storage.
|
||||
for (int i = 0; i < num_indices; ++i) {
|
||||
// Slice out part of the value.
|
||||
value_starts[0] = i;
|
||||
value_ends[0] = i + 1;
|
||||
auto slice = b->Slice(value, value_starts, value_ends);
|
||||
|
||||
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
|
||||
auto index = b->Slice(indices, {i}, {i + 1});
|
||||
auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
|
||||
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
|
||||
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp);
|
||||
|
||||
class TensorArrayConcatOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType ta_type;
|
||||
TensorShape ta_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
|
||||
OP_REQUIRES(ctx, ta_type == dtype_,
|
||||
errors::InvalidArgument("TensorArray type mismatch"));
|
||||
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
|
||||
errors::InvalidArgument("TensorArray rank must be >= 1"));
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
xla::ComputationDataHandle ta;
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
|
||||
|
||||
auto ta_dims = ta_shape.dim_sizes();
|
||||
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
|
||||
shape[0] *= ta_shape.dim_size(0);
|
||||
ctx->SetOutput(0, b->Reshape(ta, shape));
|
||||
|
||||
Tensor lengths(DT_INT64, {ta_dims[0]});
|
||||
auto lengths_vec = lengths.vec<int64>();
|
||||
for (int i = 0; i < ta_dims[0]; ++i) {
|
||||
lengths_vec(i) = ta_dims[1];
|
||||
}
|
||||
ctx->SetConstantOutput(1, lengths);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp);
|
||||
|
||||
class TensorArraySplitOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
std::vector<int64> lengths;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
|
||||
|
||||
int64 length = 0;
|
||||
if (!lengths.empty()) {
|
||||
length = lengths[0];
|
||||
for (int i = 1; i < lengths.size(); ++i) {
|
||||
OP_REQUIRES(ctx, lengths[i] == length,
|
||||
errors::InvalidArgument("lengths must be equal: ", length,
|
||||
" vs. ", lengths[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TensorShape value_shape = ctx->InputShape(1);
|
||||
OP_REQUIRES(ctx, value_shape.dims() >= 1,
|
||||
errors::InvalidArgument("value must have rank >= 1, got ",
|
||||
value_shape.DebugString()));
|
||||
TensorShape elem_shape = value_shape;
|
||||
elem_shape.set_dim(0, length);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
|
||||
OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
|
||||
xla::ComputationDataHandle ta = var->value;
|
||||
|
||||
TensorShape ta_shape;
|
||||
ta_shape.AddDim(var->tensor_array_size);
|
||||
ta_shape.AppendShape(elem_shape);
|
||||
|
||||
OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
|
||||
errors::InvalidArgument(
|
||||
"TensorArray's size is not equal to the size of lengths (",
|
||||
lengths.size(), " vs. ", var->tensor_array_size, ")"));
|
||||
|
||||
const xla::ComputationDataHandle value = ctx->Input(1);
|
||||
|
||||
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
|
||||
errors::InvalidArgument("mismatched element count ",
|
||||
value_shape.DebugString(), " vs. ",
|
||||
ta_shape.DebugString()));
|
||||
|
||||
ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
|
||||
|
||||
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
|
||||
|
||||
class TensorArraySizeOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
|
||||
Tensor size_tensor(DT_INT32, {});
|
||||
size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
|
||||
ctx->SetConstantOutput(0, size_tensor);
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp);
|
||||
|
||||
class TensorArrayGradOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
|
||||
XlaVariable* var;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
|
||||
|
||||
DataType ta_type;
|
||||
TensorShape ta_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
|
||||
OP_REQUIRES(ctx, ta_shape.dims() >= 1,
|
||||
errors::InvalidArgument("TensorArray rank must be >= 1"));
|
||||
|
||||
// Finds or looks up the corresponding gradient TensorArray, which stores
|
||||
// gradients computed during backpropagation.
|
||||
XlaVariable*& gradient = var->tensor_array_gradient[source_];
|
||||
if (!gradient) {
|
||||
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
|
||||
xla::ComputationDataHandle value =
|
||||
b->Broadcast(zero, ta_shape.dim_sizes());
|
||||
|
||||
XlaContext& xc = XlaContext::Get(ctx);
|
||||
string name = strings::StrCat("TensorArrayGrad: ", var->name);
|
||||
OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
|
||||
value, &gradient));
|
||||
gradient->tensor_array_size = var->tensor_array_size;
|
||||
}
|
||||
|
||||
ctx->SetVariableOutput(0, gradient);
|
||||
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
|
||||
}
|
||||
|
||||
private:
|
||||
string source_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
|
@ -119,6 +119,4 @@ void XlaExpression::set_constant_value(Tensor value) {
|
|||
constant_value_ = std::move(value);
|
||||
}
|
||||
|
||||
void XlaExpression::set_variable_id(int id) { variable_id_ = id; }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -64,6 +64,39 @@ class XlaCompilationDevice : public LocalDevice {
|
|||
std::unique_ptr<XlaCompilationAllocator> allocator_;
|
||||
};
|
||||
|
||||
struct XlaVariable {
|
||||
// If this variable is visible externally, what was its argument number?
|
||||
int arg_num = -1;
|
||||
|
||||
// A descriptive name for the variable, used in error messages.
|
||||
string name;
|
||||
|
||||
// Current type and value of the variable. Uninitialized variables are
|
||||
// represented by a default (zero) handle and type DT_INVALID.
|
||||
// While the type of a variable is notionally fixed during execution, when
|
||||
// a variable is first initialized we do not yet know its type, so we keep
|
||||
// track of its type dynamically.
|
||||
DataType type = DT_INVALID;
|
||||
xla::ComputationDataHandle value;
|
||||
|
||||
// Value of the variable at computation entry. Used to detect which
|
||||
// variables have new values that need to be written back.
|
||||
xla::ComputationDataHandle initial_value;
|
||||
|
||||
// We treat TensorArrays as a Variable with some extra metadata.
|
||||
|
||||
// 'tensor_array_size' stores the expected size of the TensorArray. We need
|
||||
// to store this since sometimes TensorArrays must be initialized lazily since
|
||||
// we do not know the element shape at construction time.
|
||||
int64 tensor_array_size = -1;
|
||||
|
||||
// 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
|
||||
// to an XlaVariable containing the gradient TensorArrays. We store a pointer
|
||||
// here since there should only be one gradient TensorArray per 'source'
|
||||
// string, irrespective of the number of calls to TensorArrayGrad.
|
||||
std::unordered_map<string, XlaVariable*> tensor_array_gradient;
|
||||
};
|
||||
|
||||
// A XlaExpression wraps an XLA computation. Each Tensor on an
|
||||
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
|
||||
// matches the shape of the subcomputation in the ComputationDataHandle. Each
|
||||
|
|
@ -82,8 +115,8 @@ class XlaExpression {
|
|||
bool has_constant_value() const { return has_constant_value_; }
|
||||
const Tensor& constant_value() const { return constant_value_; }
|
||||
|
||||
void set_variable_id(int id);
|
||||
int variable_id() const { return variable_id_; }
|
||||
void set_variable(XlaVariable* variable) { variable_ = variable; }
|
||||
XlaVariable* variable() const { return variable_; }
|
||||
|
||||
private:
|
||||
// The XLA handle of the expression's computation.
|
||||
|
|
@ -95,7 +128,7 @@ class XlaExpression {
|
|||
bool has_constant_value_ = false;
|
||||
Tensor constant_value_;
|
||||
|
||||
int variable_id_ = -1;
|
||||
XlaVariable* variable_ = nullptr; // Not owned.
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -59,8 +59,9 @@ Status CheckSignature(const DataTypeVector& types,
|
|||
|
||||
bool XlaCompiler::Argument::operator==(
|
||||
const XlaCompiler::Argument& other) const {
|
||||
if (std::tie(kind, type, shape, name) !=
|
||||
std::tie(other.kind, other.type, other.shape, other.name)) {
|
||||
if (std::tie(kind, type, shape, name, tensor_array_size) !=
|
||||
std::tie(other.kind, other.type, other.shape, other.name,
|
||||
other.tensor_array_size)) {
|
||||
return false;
|
||||
}
|
||||
if (constant_value.shape() != other.constant_value.shape()) {
|
||||
|
|
@ -264,8 +265,9 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
|||
switch (args[i].kind) {
|
||||
case XlaCompiler::Argument::kVariable:
|
||||
variables.push_back(i);
|
||||
context_arg.value.is_constant = false;
|
||||
context_arg.is_variable = true;
|
||||
context_arg.value.is_constant = false;
|
||||
context_arg.tensor_array_size = args[i].tensor_array_size;
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
parameters.push_back(i);
|
||||
|
|
@ -274,6 +276,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
|||
case XlaCompiler::Argument::kUninitializedVariable:
|
||||
context_arg.is_variable = true;
|
||||
context_arg.value.is_constant = true;
|
||||
context_arg.tensor_array_size = args[i].tensor_array_size;
|
||||
break;
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
context_arg.value.is_constant = true;
|
||||
|
|
@ -337,7 +340,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
|||
// type of the final output.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaContext::HandleOrConstant>& retvals,
|
||||
const std::unordered_map<int, XlaContext::Variable>& variable_map,
|
||||
const std::vector<std::unique_ptr<XlaVariable>>& variables,
|
||||
bool has_side_effects, bool return_updated_values_for_all_variables,
|
||||
xla::ComputationBuilder* builder, xla::Computation* computation,
|
||||
int* num_nonconst_outputs,
|
||||
|
|
@ -352,27 +355,27 @@ Status BuildComputation(
|
|||
*num_nonconst_outputs = elems.size();
|
||||
|
||||
// Add return values for variables whose values have changed.
|
||||
std::vector<std::pair<int, const XlaContext::Variable*>> variables;
|
||||
variables.reserve(variable_map.size());
|
||||
for (const auto& entry : variable_map) {
|
||||
variables.emplace_back(entry.first, &entry.second);
|
||||
std::vector<const XlaVariable*> arg_vars;
|
||||
arg_vars.reserve(variables.size());
|
||||
for (const auto& var : variables) {
|
||||
if (var->arg_num >= 0) {
|
||||
arg_vars.push_back(var.get());
|
||||
}
|
||||
std::sort(variables.begin(), variables.end(),
|
||||
[](const std::pair<int, const XlaContext::Variable*>& a,
|
||||
const std::pair<int, const XlaContext::Variable*>& b) {
|
||||
return a.first < b.first;
|
||||
}
|
||||
std::sort(arg_vars.begin(), arg_vars.end(),
|
||||
[](const XlaVariable* a, const XlaVariable* b) {
|
||||
return a->arg_num < b->arg_num;
|
||||
});
|
||||
|
||||
for (const auto& entry : variables) {
|
||||
bool modified =
|
||||
entry.second->value.handle() != entry.second->initial_value.handle();
|
||||
for (const XlaVariable* var : arg_vars) {
|
||||
bool modified = var->value.handle() != var->initial_value.handle();
|
||||
if (return_updated_values_for_all_variables || modified) {
|
||||
variable_updates->emplace_back();
|
||||
XlaCompiler::VariableUpdate& update = variable_updates->back();
|
||||
update.input_index = entry.first;
|
||||
update.type = entry.second->type;
|
||||
update.input_index = var->arg_num;
|
||||
update.type = var->type;
|
||||
update.modified = modified;
|
||||
elems.push_back(entry.second->value);
|
||||
elems.push_back(var->value);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -114,6 +114,10 @@ class XlaCompiler {
|
|||
// The name of this argument, used for debugging.
|
||||
string name;
|
||||
|
||||
// For a kVariable or kUninitializedVariable corresponding to a TensorArray,
|
||||
// what is the tensor array's declared size?
|
||||
int64 tensor_array_size = -1;
|
||||
|
||||
bool operator==(const Argument& other) const;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
|
|
@ -53,6 +54,10 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
|
|||
return *context;
|
||||
}
|
||||
|
||||
/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
|
||||
return Get(ctx->op_kernel_context());
|
||||
}
|
||||
|
||||
void XlaContext::set_args(std::vector<Argument> args) {
|
||||
args_ = std::move(args);
|
||||
}
|
||||
|
|
@ -124,29 +129,19 @@ void XlaContext::AddSideEffects() {
|
|||
|
||||
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
|
||||
|
||||
Status XlaContext::CreateVariable(int variable_id, string name, DataType type,
|
||||
const xla::ComputationDataHandle& handle) {
|
||||
auto result = variables_.emplace(variable_id, Variable());
|
||||
if (!result.second) {
|
||||
return errors::InvalidArgument("Duplicate ID ", variable_id,
|
||||
" for variable ", name);
|
||||
}
|
||||
Variable& var = result.first->second;
|
||||
Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
|
||||
const xla::ComputationDataHandle& handle,
|
||||
XlaVariable** variable) {
|
||||
variables_.emplace_back(new XlaVariable);
|
||||
*variable = variables_.back().get();
|
||||
XlaVariable& var = **variable;
|
||||
var.arg_num = arg_num;
|
||||
var.name = std::move(name);
|
||||
var.type = type;
|
||||
var.initial_value = var.value = handle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaContext::GetVariable(int variable_id, Variable** variable) {
|
||||
auto it = variables_.find(variable_id);
|
||||
if (it == variables_.end()) {
|
||||
return errors::InvalidArgument("Unknown variable ID ", variable_id);
|
||||
}
|
||||
*variable = &it->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
|
||||
return LookupOrCreate(type, &max_func_, [this, type] {
|
||||
const string type_string = DataTypeString(type);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/computation.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
|
@ -31,6 +30,8 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
class XlaOpKernelContext;
|
||||
|
||||
// The XlaContext is the data structure that holds the state of an XLA
|
||||
// compilation, that is accessible from OpKernelContexts when compiling a
|
||||
// subgraph of Ops using XLA.
|
||||
|
|
@ -55,16 +56,16 @@ class XlaContext : public ResourceBase {
|
|||
string name;
|
||||
|
||||
// Is this a variable?
|
||||
bool is_variable;
|
||||
bool is_variable = false;
|
||||
|
||||
HandleOrConstant value;
|
||||
|
||||
int64 tensor_array_size = -1;
|
||||
};
|
||||
|
||||
// Retrieves the XlaContext of the current compilation.
|
||||
static XlaContext& Get(const OpKernelContext* ctx);
|
||||
static XlaContext& Get(const XlaOpKernelContext* ctx) {
|
||||
return Get(ctx->op_kernel_context());
|
||||
}
|
||||
static XlaContext& Get(const XlaOpKernelContext* ctx);
|
||||
|
||||
// Creates a new XlaContext.
|
||||
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
|
||||
|
|
@ -105,33 +106,16 @@ class XlaContext : public ResourceBase {
|
|||
|
||||
bool has_side_effects() const { return has_side_effects_; }
|
||||
|
||||
struct Variable {
|
||||
// A descriptive name for the variable, used in error messages.
|
||||
string name;
|
||||
|
||||
// Current type and value of the variable. Uninitialized variables are
|
||||
// represented by a default (zero) handle and type DT_INVALID.
|
||||
// While the type of a variable is notionally fixed during execution, when
|
||||
// a variable is first initialized we do not yet know its type, so we keep
|
||||
// track of its type dynamically.
|
||||
DataType type = DT_INVALID;
|
||||
xla::ComputationDataHandle value;
|
||||
|
||||
// Value of the variable at computation entry. Used to detect which
|
||||
// variables have new values that need to be written back.
|
||||
xla::ComputationDataHandle initial_value;
|
||||
};
|
||||
|
||||
// Creates a variable with variable `variable_id` and initial type `type` and
|
||||
// value `handle`. `name` is a descriptive name for use in error messages.
|
||||
// Fails if the variable already exists.
|
||||
Status CreateVariable(int variable_id, string name, DataType type,
|
||||
const xla::ComputationDataHandle& handle);
|
||||
Status CreateVariable(int arg_num, string name, DataType type,
|
||||
const xla::ComputationDataHandle& handle,
|
||||
XlaVariable** variable);
|
||||
|
||||
// Retrieves variable `variable_id`. Fails if the variable does not exist.
|
||||
Status GetVariable(int variable_id, Variable** variable);
|
||||
|
||||
const std::unordered_map<int, Variable>& variables() { return variables_; }
|
||||
const std::vector<std::unique_ptr<XlaVariable>>& variables() {
|
||||
return variables_;
|
||||
}
|
||||
|
||||
// Get an XLA lambda to compute Max. This is cached in the
|
||||
// XlaContext since it may be used by multiple Ops. There is a
|
||||
|
|
@ -182,8 +166,8 @@ class XlaContext : public ResourceBase {
|
|||
// Does the computation have side effects, i.e., Send() calls?
|
||||
bool has_side_effects_ = false;
|
||||
|
||||
// Map from variable ID to the current value of each variable.
|
||||
std::unordered_map<int, Variable> variables_;
|
||||
// Holds ownership of variables. The variables are not ordered.
|
||||
std::vector<std::unique_ptr<XlaVariable>> variables_;
|
||||
|
||||
// Cache of prebuilt computations indexed by their type.
|
||||
using ComputationMap = std::map<DataType, xla::Computation>;
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const {
|
|||
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK(expression->handle().handle() != 0 || expression->variable_id() >= 0);
|
||||
CHECK(expression->handle().handle() != 0 ||
|
||||
expression->variable() != nullptr);
|
||||
VLOG(1) << "Fetched T" << expression->handle().handle();
|
||||
return expression;
|
||||
}
|
||||
|
|
@ -251,11 +252,8 @@ Status XlaOpKernelContext::ReadVariableInput(
|
|||
int index, xla::ComputationDataHandle* value) {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
int variable_id = expression->variable_id();
|
||||
|
||||
XlaContext::Variable* variable;
|
||||
XlaContext& context = XlaContext::Get(this);
|
||||
TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
|
||||
XlaVariable* variable = expression->variable();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
if (variable->value.handle() == 0) {
|
||||
return errors::InvalidArgument("Read of uninitialized variable ",
|
||||
variable->name);
|
||||
|
|
@ -267,11 +265,8 @@ Status XlaOpKernelContext::ReadVariableInput(
|
|||
string XlaOpKernelContext::VariableDebugString(int index) {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
int variable_id = expression->variable_id();
|
||||
|
||||
XlaContext::Variable* variable;
|
||||
XlaContext& context = XlaContext::Get(this);
|
||||
if (!context.GetVariable(variable_id, &variable).ok()) {
|
||||
XlaVariable* variable = expression->variable();
|
||||
if (!variable) {
|
||||
return "<invalid variable ID>";
|
||||
}
|
||||
return variable->name;
|
||||
|
|
@ -281,11 +276,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
|||
TensorShape* shape) const {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
int variable_id = expression->variable_id();
|
||||
|
||||
XlaContext::Variable* variable;
|
||||
XlaContext& context = XlaContext::Get(this);
|
||||
TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
|
||||
XlaVariable* variable = expression->variable();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
if (variable->value.handle() == 0) {
|
||||
return errors::InvalidArgument("Read of uninitialized variable ",
|
||||
variable->name);
|
||||
|
|
@ -345,14 +337,22 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
|
|||
expression->set_constant_value(constant);
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
|
||||
void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
|
||||
Tensor* output = nullptr;
|
||||
// The shape of the output tensor is the shape of the variable resource
|
||||
// (i.e., a scalar), not the shape of the variable's value.
|
||||
OP_REQUIRES_OK(context_,
|
||||
context_->allocate_output(index, TensorShape(), &output));
|
||||
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
|
||||
expression->set_variable_id(variable_id);
|
||||
expression->set_variable(variable);
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
|
||||
const XlaExpression* expression =
|
||||
CastExpressionFromTensor(context_->input(index));
|
||||
TF_RET_CHECK(expression->variable() != nullptr);
|
||||
*variable = expression->variable();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::AssignVariable(
|
||||
|
|
@ -362,9 +362,8 @@ Status XlaOpKernelContext::AssignVariable(
|
|||
|
||||
const XlaExpression* expression =
|
||||
CastExpressionFromTensor(context_->input(index));
|
||||
XlaContext& context = XlaContext::Get(this);
|
||||
XlaContext::Variable* variable;
|
||||
TF_RETURN_IF_ERROR(context.GetVariable(expression->variable_id(), &variable));
|
||||
XlaVariable* variable = expression->variable();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
|
||||
(variable->type == type))) {
|
||||
return errors::InvalidArgument(
|
||||
|
|
|
|||
|
|
@ -157,15 +157,18 @@ class XlaOpKernelContext {
|
|||
// 'index'.
|
||||
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
|
||||
|
||||
// Sets output 'index' to be a reference to variable 'variable_id'. Used
|
||||
// to propagate resource variables through the compilation.
|
||||
void SetVariableOutput(int index, int variable_id);
|
||||
|
||||
// Assigns the value `handle` to the variable referenced by input
|
||||
// `variable_index`. Marks the operator as having side effects.
|
||||
Status AssignVariable(int variable_index, DataType type,
|
||||
const xla::ComputationDataHandle& handle);
|
||||
|
||||
// Sets '*variable' to the variable associated with input `index`.
|
||||
Status GetVariableInput(int index, XlaVariable** variable);
|
||||
|
||||
// Sets output 'index' to be a reference to variable 'variable'. Used
|
||||
// to propagate resource variables through the compilation.
|
||||
void SetVariableOutput(int index, XlaVariable* variable);
|
||||
|
||||
// Returns a human-readable debug string describing 'variable_index'.
|
||||
string VariableDebugString(int variable_index);
|
||||
|
||||
|
|
|
|||
|
|
@ -1221,7 +1221,7 @@ of the forward TensorArray is known when this operation is called.
|
|||
|
||||
TensorArray gradient calls use an accumulator TensorArray object. If
|
||||
multiple gradients are calculated and run in the same session, the multiple
|
||||
gradient nodes may accidentally flow throuth the same accumulator TensorArray.
|
||||
gradient nodes may accidentally flow through the same accumulator TensorArray.
|
||||
This double counts and generally breaks the TensorArray gradient flow.
|
||||
|
||||
The solution is to identify which gradient call this particular
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user