Add Tensor._rank() getter

It appears to speed up SPINN model by about 1%, which is not much, but
this method is very simple and easier to use than len(tensor._shape_tuple())

PiperOrigin-RevId: 173703259
This commit is contained in:
Igor Ganichev 2017-10-27 12:20:47 -07:00 committed by TensorFlower Gardener
parent d7cffe9c03
commit 7c4e98eb4a
2 changed files with 30 additions and 0 deletions

View File

@ -377,6 +377,15 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
return shape;
}
// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
#if PY_MAJOR_VERSION < 3
return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
#else
return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
#endif
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
Py_INCREF(self->handle_data);
return self->handle_data;
@ -470,6 +479,7 @@ static PyMethodDef EagerTensor_methods[] = {
PyDoc_STR("_datatype_enum")},
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
PyDoc_STR("_shape_tuple")},
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
{nullptr, nullptr},

View File

@ -383,6 +383,14 @@ class Tensor(_TensorLike):
return None
return tuple(shape)
def _rank(self):
"""Integer rank of this Tensor, if known, else None.
Returns:
Integer rank or None
"""
return self._shape.ndims
def get_shape(self):
"""Alias of Tensor.shape."""
return self.shape
@ -664,6 +672,18 @@ class _EagerTensorBase(Tensor):
"""
raise NotImplementedError()
def _rank(self):
"""Integer rank of this Tensor.
Unlike regular Tensors, the rank is always known for EagerTensors.
This is more performant than len(self._shape_tuple())
Returns:
Integer rank
"""
raise NotImplementedError()
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
raise NotImplementedError()