matmul uses shape_tuple internally

PiperOrigin-RevId: 170938790
This commit is contained in:
Alexandre Passos 2017-10-03 17:06:45 -07:00 committed by TensorFlower Gardener
parent ad37fa81fd
commit 0c8dbc1fda

View File

@ -1843,11 +1843,12 @@ def matmul(a,
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
a_shape = a.get_shape()
b_shape = b.get_shape()
# TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access
b_shape = b._shape_tuple() # pylint: disable=protected-access
if (not a_is_sparse and not b_is_sparse) and (
(a_shape.ndims is None or a_shape.ndims > 2) and
(b_shape.ndims is None or b_shape.ndims > 2)):
(a_shape is None or len(a_shape) > 2) and
(b_shape is None or len(b_shape) > 2)):
# BatchMatmul does not support transpose, so we conjugate the matrix and
# use adjoint instead. Conj() is a noop for real matrices.
if transpose_a: