mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
matmul uses shape_tuple internally
PiperOrigin-RevId: 170938790
This commit is contained in:
parent
ad37fa81fd
commit
0c8dbc1fda
|
|
@ -1843,11 +1843,12 @@ def matmul(a,
|
|||
|
||||
a = ops.convert_to_tensor(a, name="a")
|
||||
b = ops.convert_to_tensor(b, name="b")
|
||||
a_shape = a.get_shape()
|
||||
b_shape = b.get_shape()
|
||||
# TODO(apassos) remove _shape_tuple here when it is not needed.
|
||||
a_shape = a._shape_tuple() # pylint: disable=protected-access
|
||||
b_shape = b._shape_tuple() # pylint: disable=protected-access
|
||||
if (not a_is_sparse and not b_is_sparse) and (
|
||||
(a_shape.ndims is None or a_shape.ndims > 2) and
|
||||
(b_shape.ndims is None or b_shape.ndims > 2)):
|
||||
(a_shape is None or len(a_shape) > 2) and
|
||||
(b_shape is None or len(b_shape) > 2)):
|
||||
# BatchMatmul does not support transpose, so we conjugate the matrix and
|
||||
# use adjoint instead. Conj() is a noop for real matrices.
|
||||
if transpose_a:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user