mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add Tensor._rank() getter
It appears to speed up SPINN model by about 1%, which is not much, but this method is very simple and easier to use than len(tensor._shape_tuple()) PiperOrigin-RevId: 173703259
This commit is contained in:
parent
d7cffe9c03
commit
7c4e98eb4a
|
|
@ -377,6 +377,15 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
|
||||||
return shape;
|
return shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Getter for `_rank`.
|
||||||
|
static PyObject* EagerTensor_rank(EagerTensor* self) {
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
|
||||||
|
#else
|
||||||
|
return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
|
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
|
||||||
Py_INCREF(self->handle_data);
|
Py_INCREF(self->handle_data);
|
||||||
return self->handle_data;
|
return self->handle_data;
|
||||||
|
|
@ -470,6 +479,7 @@ static PyMethodDef EagerTensor_methods[] = {
|
||||||
PyDoc_STR("_datatype_enum")},
|
PyDoc_STR("_datatype_enum")},
|
||||||
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
{"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
|
||||||
PyDoc_STR("_shape_tuple")},
|
PyDoc_STR("_shape_tuple")},
|
||||||
|
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
|
||||||
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
|
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
|
||||||
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
|
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
|
||||||
{nullptr, nullptr},
|
{nullptr, nullptr},
|
||||||
|
|
|
||||||
|
|
@ -383,6 +383,14 @@ class Tensor(_TensorLike):
|
||||||
return None
|
return None
|
||||||
return tuple(shape)
|
return tuple(shape)
|
||||||
|
|
||||||
|
def _rank(self):
|
||||||
|
"""Integer rank of this Tensor, if known, else None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer rank or None
|
||||||
|
"""
|
||||||
|
return self._shape.ndims
|
||||||
|
|
||||||
def get_shape(self):
|
def get_shape(self):
|
||||||
"""Alias of Tensor.shape."""
|
"""Alias of Tensor.shape."""
|
||||||
return self.shape
|
return self.shape
|
||||||
|
|
@ -664,6 +672,18 @@ class _EagerTensorBase(Tensor):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _rank(self):
|
||||||
|
"""Integer rank of this Tensor.
|
||||||
|
|
||||||
|
Unlike regular Tensors, the rank is always known for EagerTensors.
|
||||||
|
|
||||||
|
This is more performant than len(self._shape_tuple())
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer rank
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
|
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user