mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add Tensor._rank() getter
It appears to speed up SPINN model by about 1%, which is not much, but this method is very simple and easier to use than len(tensor._shape_tuple()) PiperOrigin-RevId: 173703259
This commit is contained in:
parent
d7cffe9c03
commit
7c4e98eb4a
|
|
@ -377,6 +377,15 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
|||
return shape;
|
||||
}
|
||||
|
||||
// Getter for `_rank`.
|
||||
static PyObject* EagerTensor_rank(EagerTensor* self) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
|
||||
#else
|
||||
return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
|
||||
Py_INCREF(self->handle_data);
|
||||
return self->handle_data;
|
||||
|
|
@ -470,6 +479,7 @@ static PyMethodDef EagerTensor_methods[] = {
|
|||
PyDoc_STR("_datatype_enum")},
|
||||
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
||||
PyDoc_STR("_shape_tuple")},
|
||||
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
|
||||
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
|
||||
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
|
||||
{nullptr, nullptr},
|
||||
|
|
|
|||
|
|
@ -383,6 +383,14 @@ class Tensor(_TensorLike):
|
|||
return None
|
||||
return tuple(shape)
|
||||
|
||||
def _rank(self):
|
||||
"""Integer rank of this Tensor, if known, else None.
|
||||
|
||||
Returns:
|
||||
Integer rank or None
|
||||
"""
|
||||
return self._shape.ndims
|
||||
|
||||
def get_shape(self):
|
||||
"""Alias of Tensor.shape."""
|
||||
return self.shape
|
||||
|
|
@ -664,6 +672,18 @@ class _EagerTensorBase(Tensor):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _rank(self):
|
||||
"""Integer rank of this Tensor.
|
||||
|
||||
Unlike regular Tensors, the rank is always known for EagerTensors.
|
||||
|
||||
This is more performant than len(self._shape_tuple())
|
||||
|
||||
Returns:
|
||||
Integer rank
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user