Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.

PiperOrigin-RevId: 165704074
This commit is contained in:
Alexandre Passos 2017-08-18 07:39:41 -07:00 committed by TensorFlower Gardener
parent 573b303ac8
commit a6729325a3
5 changed files with 34 additions and 48 deletions

View File

@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase):
def testInvalidInputDataType(self):
# Fill requires the first input to be an int32 tensor.
with self.assertRaisesRegexp(
TypeError,
'Expected tensor with type tf.int32 not tf.int64'):
with self.assertRaisesRegexp(ValueError, 'int64'):
array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1))
def testOutputOnHostMemory(self):

View File

@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() {
const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn,
"to_eager_tensor(", param, ", ", dtype, ")\n");
strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn,
"to_tensor(", param, ", ", dtype, ")\n");
}
}

View File

@ -24,8 +24,6 @@ import numpy as np
# ops.py.
# pylint: disable=unused-import
from tensorflow.python.framework.ops import _tensor_from_handle
from tensorflow.python.framework.ops import convert_n_to_eager_tensor
from tensorflow.python.framework.ops import convert_to_eager_tensor
from tensorflow.python.framework.ops import EagerTensor as Tensor
# pylint: enable=unused-import

View File

@ -41,6 +41,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape):
def _eager_fill(dims, value):
"""Eager-only version of Fill op; requires value is an eager Tensor."""
attr_t = value.dtype.as_datatype_enum
dims = ops.convert_to_eager_tensor(dims, dtypes.int32)
dims = convert_to_eager_tensor(dims, dtypes.int32)
inputs_flat = [dims, value]
attrs = ("T", attr_t)
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
return result
def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), ops.EagerTensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
# Handle converting ResourceVariable to Tensor.
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
# general scheme in place.
try:
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
except AttributeError:
pass
return ops.EagerTensor(t, dtype=dtype)
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""
if not context.in_graph_mode():
if shape is None:
return ops.convert_to_eager_tensor(value, dtype)
t = ops.convert_to_eager_tensor(value, dtype)
return convert_to_eager_tensor(value, dtype)
t = convert_to_eager_tensor(value, dtype)
shape = tensor_shape.as_shape(shape)
if shape == t.shape:
return t

View File

@ -876,29 +876,6 @@ class EagerTensor(Tensor):
raise NotImplementedError("eval not supported for Eager Tensors.")
# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and
# other custom conversion functions.
def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), EagerTensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
# Handle converting ResourceVariable to Tensor.
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
# general scheme in place.
try:
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
except AttributeError:
pass
return EagerTensor(t, dtype=dtype)
def convert_n_to_eager_tensor(values, dtype):
"""Converts the given `values` to a list of `EagerTensor`."""
return [convert_to_eager_tensor(t, dtype) for t in values]
def _tensor_from_handle(handle):
"""'Private' constructor for the Tensor object.
@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values,
"""
if not isinstance(values, collections.Sequence):
raise TypeError("values must be a list.")
if context.in_graph_mode():
ret = []
for i, value in enumerate(values):
n = None if name is None else "%s_%d" % (name, i)
ret.append(
internal_convert_to_tensor(
value,
dtype=dtype,
name=n,
as_ref=as_ref,
preferred_dtype=preferred_dtype))
return ret
else:
# TODO(josh11b): handle preferred_dtype, as_ref
return convert_n_to_eager_tensor(values, dtype=dtype)
ret = []
for i, value in enumerate(values):
n = None if name is None else "%s_%d" % (name, i)
ret.append(
internal_convert_to_tensor(
value,
dtype=dtype,
name=n,
as_ref=as_ref,
preferred_dtype=preferred_dtype))
return ret
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):