mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add TFE_Py_TensorShapeSlice function
TFE_Py_TensorShapeSlice takes a list of EagerTensors and returns a list of their i'th dimensions. This utility is fairly niche but it is simple and reduces SPINN training time by over 12%. PiperOrigin-RevId: 174065044
This commit is contained in:
parent
585432cc21
commit
8a09bbc4a5
|
|
@ -657,3 +657,71 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
|||
EagerTensorType->tp_dictoffset = 0;
|
||||
return reinterpret_cast<PyObject*>(EagerTensorType);
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
|
||||
if (!PyList_Check(tensor_list)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"tensor_list argument must be a list. Got \"",
|
||||
Py_TYPE(tensor_list)->tp_name, "\"")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (slice_dim < 0) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
tensorflow::strings::StrCat("Slice dimension must be non-negative. "
|
||||
"Got ",
|
||||
slice_dim)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Py_ssize_t num_tensors = PyList_Size(tensor_list);
|
||||
int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
|
||||
auto tensor = tensorflow::make_safe(TF_AllocateTensor(
|
||||
TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
|
||||
int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
|
||||
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
|
||||
PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
|
||||
if (!EagerTensor_CheckExact(tensor_obj)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expected a list of EagerTensors but "
|
||||
"element ",
|
||||
i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
|
||||
TFE_TensorHandle* handle = t->handle;
|
||||
if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
|
||||
PyErr_SetString(PyExc_IndexError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Slice dimension (", slice_dim,
|
||||
") must be smaller than rank of all "
|
||||
"tensors, but tensor at index ",
|
||||
i, " has rank ", TFE_TensorHandleNumDims(handle))
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
|
||||
data[i] = dim;
|
||||
}
|
||||
|
||||
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
|
||||
if (TF_GetCode(status.get()) != TF_OK) {
|
||||
PyErr_SetString(
|
||||
PyExc_RuntimeError,
|
||||
tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
|
||||
TF_Message(status.get()))
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
// handle now owns the tensor. Release it from the smart pointer.
|
||||
tensor.release();
|
||||
|
||||
return EagerTensorFromHandle(handle);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -105,4 +105,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
|
|||
PyObject* backward_function);
|
||||
PyObject* TFE_Py_TapeExport(PyObject* tape);
|
||||
|
||||
// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
|
||||
// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
|
||||
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
|
||||
// `tensor_list`. For example, if `tensor_list` contains tensors of with shapes
|
||||
// [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
|
||||
// `slice_dim` equal to 1 will return [2, 5, 7].
|
||||
// On error, returns nullptr and sets python exception.
|
||||
// REQUIRES: `tensor_list` is a python list of EagerTensors
|
||||
// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
|
||||
// tensors in `tensor_list`.
|
||||
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import copy
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import test
|
||||
|
|
@ -216,5 +217,93 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||
_create_tensor("test string")
|
||||
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testListOfThree(self):
|
||||
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
||||
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
|
||||
|
||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
||||
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
|
||||
|
||||
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
||||
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
|
||||
|
||||
def testEmptyTensorList(self):
|
||||
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
|
||||
self.assertTrue(isinstance(a, ops.EagerTensor))
|
||||
self.assertEqual(0, a.numpy().size)
|
||||
|
||||
def testTensorListContainsNonTensors(self):
|
||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"Expected a list of EagerTensors but element 1 has type \"str\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"Expected a list of EagerTensors but element 0 has type \"int\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
|
||||
|
||||
def testTensorListNotList(self):
|
||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"tensor_list argument must be a list. Got \"EagerTensor\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"tensor_list argument must be a list. Got \"tuple\""):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice((t1,), -2)
|
||||
|
||||
def testNegativeSliceDim(self):
|
||||
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"Slice dimension must be non-negative. Got -2"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
|
||||
|
||||
def testSliceDimOutOfRange(self):
|
||||
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||
t2 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
t3 = _create_tensor(2, dtype=dtypes.int32)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 2"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 1"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 1 has rank 1"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 0 has rank 0"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
IndexError,
|
||||
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||
"but tensor at index 2 has rank 0"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||
|
||||
from math import ceil
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
|
|
@ -102,32 +103,46 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
|||
|
||||
concat_dim = op.inputs[dim_index]
|
||||
input_values = op.inputs[start_value_index:end_value_index]
|
||||
# Using mod here for convenience since concat_dim is already verified
|
||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
||||
|
||||
out_grads = []
|
||||
if isinstance(grad, ops.Tensor):
|
||||
# Get the inputs' tensor shapes
|
||||
sizes = _ExtractInputShapes(input_values)
|
||||
# The magic number of 16 was found through benchmarking a range of sizes
|
||||
# on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
|
||||
# cases when switching implementations at N=16, but it is possible that
|
||||
# there will be a small number of performance regressions.
|
||||
# pylint: disable=protected-access
|
||||
if len(sizes) > 16:
|
||||
# extract the size of each input along the concat dimension
|
||||
sizes = array_ops.squeeze(
|
||||
array_ops.slice(
|
||||
array_ops.stack(
|
||||
sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
|
||||
if context.in_eager_mode():
|
||||
# Using mod here for convenience since concat_dim is already verified
|
||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||
non_neg_concat_dim = (
|
||||
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
|
||||
# All inputs are guaranteed to be EagerTensors in eager mode
|
||||
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
|
||||
non_neg_concat_dim)
|
||||
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
||||
else:
|
||||
offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
|
||||
for (begin, size) in zip(offset, sizes):
|
||||
out_grads.append(array_ops.slice(grad, begin, size))
|
||||
# pylint: enable=protected-access
|
||||
# Using mod here for convenience since concat_dim is already verified
|
||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
||||
|
||||
# Get the inputs' tensor shapes
|
||||
sizes = _ExtractInputShapes(input_values)
|
||||
# The magic number of 16 was found through benchmarking a range of sizes
|
||||
# on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
|
||||
# cases when switching implementations at N=16, but it is possible that
|
||||
# there will be a small number of performance regressions.
|
||||
# pylint: disable=protected-access
|
||||
if len(sizes) > 16:
|
||||
# extract the size of each input along the concat dimension
|
||||
sizes = array_ops.squeeze(
|
||||
array_ops.slice(
|
||||
array_ops.stack(
|
||||
sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
|
||||
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
||||
else:
|
||||
offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
|
||||
for (begin, size) in zip(offset, sizes):
|
||||
out_grads.append(array_ops.slice(grad, begin, size))
|
||||
# pylint: enable=protected-access
|
||||
elif isinstance(grad, ops.IndexedSlices):
|
||||
# Using mod here for convenience since concat_dim is already verified
|
||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
||||
concat_dim_static = tensor_util.constant_value(concat_dim)
|
||||
if concat_dim_static is None:
|
||||
raise ValueError("Can only compute IndexedSlices gradient with "
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_DeleteContextOptions;
|
||||
%rename("%s") TFE_Py_TensorShapeSlice;
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user