Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.

PiperOrigin-RevId: 165704074
This commit is contained in:
Alexandre Passos 2017-08-18 07:39:41 -07:00 committed by TensorFlower Gardener
parent 573b303ac8
commit a6729325a3
5 changed files with 34 additions and 48 deletions

View File

@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase):
def testInvalidInputDataType(self): def testInvalidInputDataType(self):
# Fill requires the first input to be an int32 tensor. # Fill requires the first input to be an int32 tensor.
with self.assertRaisesRegexp( with self.assertRaisesRegexp(ValueError, 'int64'):
TypeError,
'Expected tensor with type tf.int32 not tf.int64'):
array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1)) array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1))
def testOutputOnHostMemory(self): def testOutputOnHostMemory(self):

View File

@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() {
const string fn = arg.number_attr().empty() ? "" : "n_"; const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype = const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn, strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn,
"to_eager_tensor(", param, ", ", dtype, ")\n"); "to_tensor(", param, ", ", dtype, ")\n");
} }
} }

View File

@ -24,8 +24,6 @@ import numpy as np
# ops.py. # ops.py.
# pylint: disable=unused-import # pylint: disable=unused-import
from tensorflow.python.framework.ops import _tensor_from_handle from tensorflow.python.framework.ops import _tensor_from_handle
from tensorflow.python.framework.ops import convert_n_to_eager_tensor
from tensorflow.python.framework.ops import convert_to_eager_tensor
from tensorflow.python.framework.ops import EagerTensor as Tensor from tensorflow.python.framework.ops import EagerTensor as Tensor
# pylint: enable=unused-import # pylint: enable=unused-import

View File

@ -41,6 +41,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from autograd import core as ag_core
import numpy as np import numpy as np
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape):
def _eager_fill(dims, value): def _eager_fill(dims, value):
"""Eager-only version of Fill op; requires value is an eager Tensor.""" """Eager-only version of Fill op; requires value is an eager Tensor."""
attr_t = value.dtype.as_datatype_enum attr_t = value.dtype.as_datatype_enum
dims = ops.convert_to_eager_tensor(dims, dtypes.int32) dims = convert_to_eager_tensor(dims, dtypes.int32)
inputs_flat = [dims, value] inputs_flat = [dims, value]
attrs = ("T", attr_t) attrs = ("T", attr_t)
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
return result return result
def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), ops.EagerTensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
# Handle converting ResourceVariable to Tensor.
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
# general scheme in place.
try:
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
except AttributeError:
pass
return ops.EagerTensor(t, dtype=dtype)
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor. """Creates a constant tensor.
@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
""" """
if not context.in_graph_mode(): if not context.in_graph_mode():
if shape is None: if shape is None:
return ops.convert_to_eager_tensor(value, dtype) return convert_to_eager_tensor(value, dtype)
t = ops.convert_to_eager_tensor(value, dtype) t = convert_to_eager_tensor(value, dtype)
shape = tensor_shape.as_shape(shape) shape = tensor_shape.as_shape(shape)
if shape == t.shape: if shape == t.shape:
return t return t

View File

@ -876,29 +876,6 @@ class EagerTensor(Tensor):
raise NotImplementedError("eval not supported for Eager Tensors.") raise NotImplementedError("eval not supported for Eager Tensors.")
# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and
# other custom conversion functions.
def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), EagerTensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
# Handle converting ResourceVariable to Tensor.
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
# general scheme in place.
try:
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
except AttributeError:
pass
return EagerTensor(t, dtype=dtype)
def convert_n_to_eager_tensor(values, dtype):
"""Converts the given `values` to a list of `EagerTensor`."""
return [convert_to_eager_tensor(t, dtype) for t in values]
def _tensor_from_handle(handle): def _tensor_from_handle(handle):
"""'Private' constructor for the Tensor object. """'Private' constructor for the Tensor object.
@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values,
""" """
if not isinstance(values, collections.Sequence): if not isinstance(values, collections.Sequence):
raise TypeError("values must be a list.") raise TypeError("values must be a list.")
if context.in_graph_mode(): ret = []
ret = [] for i, value in enumerate(values):
for i, value in enumerate(values): n = None if name is None else "%s_%d" % (name, i)
n = None if name is None else "%s_%d" % (name, i) ret.append(
ret.append( internal_convert_to_tensor(
internal_convert_to_tensor( value,
value, dtype=dtype,
dtype=dtype, name=n,
name=n, as_ref=as_ref,
as_ref=as_ref, preferred_dtype=preferred_dtype))
preferred_dtype=preferred_dtype)) return ret
return ret
else:
# TODO(josh11b): handle preferred_dtype, as_ref
return convert_n_to_eager_tensor(values, dtype=dtype)
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):