mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.
PiperOrigin-RevId: 165704074
This commit is contained in:
parent
573b303ac8
commit
a6729325a3
|
|
@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase):
|
|||
|
||||
def testInvalidInputDataType(self):
|
||||
# Fill requires the first input to be an int32 tensor.
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
'Expected tensor with type tf.int32 not tf.int64'):
|
||||
with self.assertRaisesRegexp(ValueError, 'int64'):
|
||||
array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1))
|
||||
|
||||
def testOutputOnHostMemory(self):
|
||||
|
|
|
|||
|
|
@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() {
|
|||
const string fn = arg.number_attr().empty() ? "" : "n_";
|
||||
const string dtype =
|
||||
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
|
||||
strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn,
|
||||
"to_eager_tensor(", param, ", ", dtype, ")\n");
|
||||
strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn,
|
||||
"to_tensor(", param, ", ", dtype, ")\n");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ import numpy as np
|
|||
# ops.py.
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.framework.ops import _tensor_from_handle
|
||||
from tensorflow.python.framework.ops import convert_n_to_eager_tensor
|
||||
from tensorflow.python.framework.ops import convert_to_eager_tensor
|
||||
from tensorflow.python.framework.ops import EagerTensor as Tensor
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from autograd import core as ag_core
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
|
|
@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape):
|
|||
def _eager_fill(dims, value):
|
||||
"""Eager-only version of Fill op; requires value is an eager Tensor."""
|
||||
attr_t = value.dtype.as_datatype_enum
|
||||
dims = ops.convert_to_eager_tensor(dims, dtypes.int32)
|
||||
dims = convert_to_eager_tensor(dims, dtypes.int32)
|
||||
inputs_flat = [dims, value]
|
||||
attrs = ("T", attr_t)
|
||||
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
|
||||
return result
|
||||
|
||||
|
||||
def convert_to_eager_tensor(t, dtype=None):
|
||||
"""Converts the given `value` to an `EagerTensor`."""
|
||||
if isinstance(ag_core.getval(t), ops.EagerTensor):
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
|
||||
return t
|
||||
# Handle converting ResourceVariable to Tensor.
|
||||
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
|
||||
# general scheme in place.
|
||||
try:
|
||||
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
pass
|
||||
return ops.EagerTensor(t, dtype=dtype)
|
||||
|
||||
|
||||
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
||||
"""Creates a constant tensor.
|
||||
|
||||
|
|
@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
|||
"""
|
||||
if not context.in_graph_mode():
|
||||
if shape is None:
|
||||
return ops.convert_to_eager_tensor(value, dtype)
|
||||
t = ops.convert_to_eager_tensor(value, dtype)
|
||||
return convert_to_eager_tensor(value, dtype)
|
||||
t = convert_to_eager_tensor(value, dtype)
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
if shape == t.shape:
|
||||
return t
|
||||
|
|
|
|||
|
|
@ -876,29 +876,6 @@ class EagerTensor(Tensor):
|
|||
raise NotImplementedError("eval not supported for Eager Tensors.")
|
||||
|
||||
|
||||
# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and
|
||||
# other custom conversion functions.
|
||||
def convert_to_eager_tensor(t, dtype=None):
|
||||
"""Converts the given `value` to an `EagerTensor`."""
|
||||
if isinstance(ag_core.getval(t), EagerTensor):
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
|
||||
return t
|
||||
# Handle converting ResourceVariable to Tensor.
|
||||
# TODO(josh11b): get rid of this explicit ugly conversion once we have a more
|
||||
# general scheme in place.
|
||||
try:
|
||||
return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
pass
|
||||
return EagerTensor(t, dtype=dtype)
|
||||
|
||||
|
||||
def convert_n_to_eager_tensor(values, dtype):
|
||||
"""Converts the given `values` to a list of `EagerTensor`."""
|
||||
return [convert_to_eager_tensor(t, dtype) for t in values]
|
||||
|
||||
|
||||
def _tensor_from_handle(handle):
|
||||
"""'Private' constructor for the Tensor object.
|
||||
|
||||
|
|
@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values,
|
|||
"""
|
||||
if not isinstance(values, collections.Sequence):
|
||||
raise TypeError("values must be a list.")
|
||||
if context.in_graph_mode():
|
||||
ret = []
|
||||
for i, value in enumerate(values):
|
||||
n = None if name is None else "%s_%d" % (name, i)
|
||||
ret.append(
|
||||
internal_convert_to_tensor(
|
||||
value,
|
||||
dtype=dtype,
|
||||
name=n,
|
||||
as_ref=as_ref,
|
||||
preferred_dtype=preferred_dtype))
|
||||
return ret
|
||||
else:
|
||||
# TODO(josh11b): handle preferred_dtype, as_ref
|
||||
return convert_n_to_eager_tensor(values, dtype=dtype)
|
||||
ret = []
|
||||
for i, value in enumerate(values):
|
||||
n = None if name is None else "%s_%d" % (name, i)
|
||||
ret.append(
|
||||
internal_convert_to_tensor(
|
||||
value,
|
||||
dtype=dtype,
|
||||
name=n,
|
||||
as_ref=as_ref,
|
||||
preferred_dtype=preferred_dtype))
|
||||
return ret
|
||||
|
||||
|
||||
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user