mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add TFE_Py_TensorShapeSlice function
TFE_Py_TensorShapeSlice takes a list of EagerTensors and returns a list of their i'th dimensions. This utility is fairly niche but it is simple and reduces SPINN training time by over 12%. PiperOrigin-RevId: 174065044
This commit is contained in:
parent
585432cc21
commit
8a09bbc4a5
|
|
@ -657,3 +657,71 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
||||||
EagerTensorType->tp_dictoffset = 0;
|
EagerTensorType->tp_dictoffset = 0;
|
||||||
return reinterpret_cast<PyObject*>(EagerTensorType);
|
return reinterpret_cast<PyObject*>(EagerTensorType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
|
||||||
|
if (!PyList_Check(tensor_list)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"tensor_list argument must be a list. Got \"",
|
||||||
|
Py_TYPE(tensor_list)->tp_name, "\"")
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (slice_dim < 0) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_ValueError,
|
||||||
|
tensorflow::strings::StrCat("Slice dimension must be non-negative. "
|
||||||
|
"Got ",
|
||||||
|
slice_dim)
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Py_ssize_t num_tensors = PyList_Size(tensor_list);
|
||||||
|
int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
|
||||||
|
auto tensor = tensorflow::make_safe(TF_AllocateTensor(
|
||||||
|
TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
|
||||||
|
int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
|
||||||
|
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
|
||||||
|
PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
|
||||||
|
if (!EagerTensor_CheckExact(tensor_obj)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Expected a list of EagerTensors but "
|
||||||
|
"element ",
|
||||||
|
i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
|
||||||
|
TFE_TensorHandle* handle = t->handle;
|
||||||
|
if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
|
||||||
|
PyErr_SetString(PyExc_IndexError,
|
||||||
|
tensorflow::strings::StrCat(
|
||||||
|
"Slice dimension (", slice_dim,
|
||||||
|
") must be smaller than rank of all "
|
||||||
|
"tensors, but tensor at index ",
|
||||||
|
i, " has rank ", TFE_TensorHandleNumDims(handle))
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
|
||||||
|
data[i] = dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
|
||||||
|
if (TF_GetCode(status.get()) != TF_OK) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_RuntimeError,
|
||||||
|
tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
|
||||||
|
TF_Message(status.get()))
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// handle now owns the tensor. Release it from the smart pointer.
|
||||||
|
tensor.release();
|
||||||
|
|
||||||
|
return EagerTensorFromHandle(handle);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -105,4 +105,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
|
||||||
PyObject* backward_function);
|
PyObject* backward_function);
|
||||||
PyObject* TFE_Py_TapeExport(PyObject* tape);
|
PyObject* TFE_Py_TapeExport(PyObject* tape);
|
||||||
|
|
||||||
|
// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
|
||||||
|
// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
|
||||||
|
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
|
||||||
|
// `tensor_list`. For example, if `tensor_list` contains tensors of with shapes
|
||||||
|
// [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
|
||||||
|
// `slice_dim` equal to 1 will return [2, 5, 7].
|
||||||
|
// On error, returns nullptr and sets python exception.
|
||||||
|
// REQUIRES: `tensor_list` is a python list of EagerTensors
|
||||||
|
// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
|
||||||
|
// tensors in `tensor_list`.
|
||||||
|
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
|
||||||
|
|
||||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import copy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
|
@ -216,5 +217,93 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||||
_create_tensor("test string")
|
_create_tensor("test string")
|
||||||
|
|
||||||
|
|
||||||
|
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def testListOfThree(self):
|
||||||
|
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||||
|
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
||||||
|
t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
|
||||||
|
self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
|
||||||
|
|
||||||
|
r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
|
||||||
|
self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
|
||||||
|
|
||||||
|
def testEmptyTensorList(self):
|
||||||
|
a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
|
||||||
|
self.assertTrue(isinstance(a, ops.EagerTensor))
|
||||||
|
self.assertEqual(0, a.numpy().size)
|
||||||
|
|
||||||
|
def testTensorListContainsNonTensors(self):
|
||||||
|
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
r"Expected a list of EagerTensors but element 1 has type \"str\""):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
r"Expected a list of EagerTensors but element 0 has type \"int\""):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
|
||||||
|
|
||||||
|
def testTensorListNotList(self):
|
||||||
|
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
r"tensor_list argument must be a list. Got \"EagerTensor\""):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError,
|
||||||
|
r"tensor_list argument must be a list. Got \"tuple\""):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice((t1,), -2)
|
||||||
|
|
||||||
|
def testNegativeSliceDim(self):
|
||||||
|
t1 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
r"Slice dimension must be non-negative. Got -2"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
|
||||||
|
|
||||||
|
def testSliceDimOutOfRange(self):
|
||||||
|
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||||
|
t2 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||||
|
t3 = _create_tensor(2, dtype=dtypes.int32)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
IndexError,
|
||||||
|
r"Slice dimension \(2\) must be smaller than rank of all tensors, "
|
||||||
|
"but tensor at index 0 has rank 2"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
IndexError,
|
||||||
|
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||||
|
"but tensor at index 0 has rank 1"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
IndexError,
|
||||||
|
r"Slice dimension \(1\) must be smaller than rank of all tensors, "
|
||||||
|
"but tensor at index 1 has rank 1"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
IndexError,
|
||||||
|
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||||
|
"but tensor at index 0 has rank 0"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
IndexError,
|
||||||
|
r"Slice dimension \(0\) must be smaller than rank of all tensors, "
|
||||||
|
"but tensor at index 2 has rank 0"):
|
||||||
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||||
|
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
|
@ -102,12 +103,23 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
||||||
|
|
||||||
concat_dim = op.inputs[dim_index]
|
concat_dim = op.inputs[dim_index]
|
||||||
input_values = op.inputs[start_value_index:end_value_index]
|
input_values = op.inputs[start_value_index:end_value_index]
|
||||||
|
|
||||||
|
out_grads = []
|
||||||
|
if isinstance(grad, ops.Tensor):
|
||||||
|
if context.in_eager_mode():
|
||||||
|
# Using mod here for convenience since concat_dim is already verified
|
||||||
|
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||||
|
non_neg_concat_dim = (
|
||||||
|
concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
|
||||||
|
# All inputs are guaranteed to be EagerTensors in eager mode
|
||||||
|
sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
|
||||||
|
non_neg_concat_dim)
|
||||||
|
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
|
||||||
|
else:
|
||||||
# Using mod here for convenience since concat_dim is already verified
|
# Using mod here for convenience since concat_dim is already verified
|
||||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||||
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
||||||
|
|
||||||
out_grads = []
|
|
||||||
if isinstance(grad, ops.Tensor):
|
|
||||||
# Get the inputs' tensor shapes
|
# Get the inputs' tensor shapes
|
||||||
sizes = _ExtractInputShapes(input_values)
|
sizes = _ExtractInputShapes(input_values)
|
||||||
# The magic number of 16 was found through benchmarking a range of sizes
|
# The magic number of 16 was found through benchmarking a range of sizes
|
||||||
|
|
@ -128,6 +140,9 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
||||||
out_grads.append(array_ops.slice(grad, begin, size))
|
out_grads.append(array_ops.slice(grad, begin, size))
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
elif isinstance(grad, ops.IndexedSlices):
|
elif isinstance(grad, ops.IndexedSlices):
|
||||||
|
# Using mod here for convenience since concat_dim is already verified
|
||||||
|
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||||
|
non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
|
||||||
concat_dim_static = tensor_util.constant_value(concat_dim)
|
concat_dim_static = tensor_util.constant_value(concat_dim)
|
||||||
if concat_dim_static is None:
|
if concat_dim_static is None:
|
||||||
raise ValueError("Can only compute IndexedSlices gradient with "
|
raise ValueError("Can only compute IndexedSlices gradient with "
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
||||||
%rename("%s") TFE_DeleteContextOptions;
|
%rename("%s") TFE_DeleteContextOptions;
|
||||||
|
%rename("%s") TFE_Py_TensorShapeSlice;
|
||||||
|
|
||||||
%{
|
%{
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user