mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
partially exposing the _set_attr and _get_attr method in python
PiperOrigin-RevId: 174113043
This commit is contained in:
parent
8e732a3124
commit
f97e7c69b8
|
|
@ -24,6 +24,19 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
|
|||
graph->graph.AddControlEdge(&input->node, &op->node);
|
||||
}
|
||||
|
||||
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
TF_Buffer* attr_value_proto, TF_Status* status) {
|
||||
AttrValue attr_val;
|
||||
if (!attr_val.ParseFromArray(attr_value_proto->data,
|
||||
attr_value_proto->length)) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
|
||||
}
|
||||
|
||||
mutex_lock l(graph->mu);
|
||||
op->node.AddAttr(attr_name, attr_val);
|
||||
}
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
||||
mutex_lock l(graph->mu);
|
||||
op->node.set_requested_device(device);
|
||||
|
|
|
|||
|
|
@ -25,6 +25,11 @@ namespace tensorflow {
|
|||
|
||||
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||
|
||||
// Changes an attr value in the node_def Protocol Buffer and sets a status upon
|
||||
// completion.
|
||||
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||
TF_Buffer* attr_value_proto, TF_Status* status);
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
||||
|
||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
|
|
|
|||
|
|
@ -341,6 +341,16 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
|
|||
%rename("_TF_SetConfig") TF_SetConfig;
|
||||
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
||||
|
||||
// Create temporary int64_t to pass to TF_OperationGetAttrInt
|
||||
%typemap(in, numinputs=0) int64_t* value (int64_t val) {
|
||||
$1 = &val;
|
||||
}
|
||||
|
||||
// Convert value to Python int
|
||||
%typemap(argout) int64_t* value {
|
||||
$result = PyInt_FromLong(*$1);
|
||||
}
|
||||
|
||||
%include "tensorflow/c/c_api.h"
|
||||
%include "tensorflow/c/python_api.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -2056,6 +2056,19 @@ class Operation(object):
|
|||
self._traceback,
|
||||
include_func_start_lineno=True)
|
||||
|
||||
def _set_attr(self, attr_name, attr_value):
|
||||
"""Private method used to set an attribute in the node_def."""
|
||||
if not _USE_C_API:
|
||||
assert "_set_attr not supported with _USE_C_API == False"
|
||||
return
|
||||
buf = c_api.TF_NewBufferFromString(
|
||||
compat.as_bytes(attr_value.SerializeToString()))
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
||||
|
||||
def get_attr(self, name):
|
||||
"""Returns the value of the attr of this op with the given `name`.
|
||||
|
||||
|
|
@ -2068,6 +2081,20 @@ class Operation(object):
|
|||
Raises:
|
||||
ValueError: If this op does not have an attr with the given `name`.
|
||||
"""
|
||||
if _USE_C_API:
|
||||
try:
|
||||
# TODO(b/65162920): remove this try/except block when all attrs are
|
||||
# implemented to use the _set_attr method instead of node_def.attr.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
|
||||
if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
|
||||
return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
|
||||
except errors.InvalidArgumentError:
|
||||
# Colocation ops are failing to find attrs begininning with "_*". They
|
||||
# should fall through to the not-CAPI logic until the attribute is set
|
||||
# via the C-API always.
|
||||
pass
|
||||
|
||||
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
|
||||
if name not in self._node_def.attr:
|
||||
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
|
||||
|
|
|
|||
|
|
@ -357,36 +357,58 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
|
||||
|
||||
def testGetAttr(self):
|
||||
# TODO(skyewm): implement get_attr with C API
|
||||
if ops._USE_C_API: return
|
||||
# TODO(b/65162920): implement all tests for get_attr with C API
|
||||
if ops._USE_C_API:
|
||||
op = test_ops.int_attr().op
|
||||
self.assertEqual(op.get_attr("foo"), 1)
|
||||
|
||||
list_value = attr_value_pb2.AttrValue.ListValue()
|
||||
list_value.type.append(types_pb2.DT_STRING)
|
||||
list_value.type.append(types_pb2.DT_DOUBLE)
|
||||
op = ops.Operation(
|
||||
ops._NodeDef(
|
||||
"None",
|
||||
"op1",
|
||||
attrs={
|
||||
"value": attr_value_pb2.AttrValue(i=32),
|
||||
"dtype": attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
|
||||
"list": attr_value_pb2.AttrValue(list=list_value),
|
||||
"func": attr_value_pb2.AttrValue(
|
||||
func=attr_value_pb2.NameAttrList())
|
||||
}), ops.Graph(), [], [dtypes.int32])
|
||||
self.assertEqual(32, op.get_attr("value"))
|
||||
self.assertEqual("", op.get_attr("func").name)
|
||||
op_str = test_ops.string_list_attr(a=["z"], b="y")
|
||||
self.assertEqual(op_str.get_attr("a"), [b"z"])
|
||||
self.assertEqual(op_str.get_attr("b"), b"y")
|
||||
|
||||
d = op.get_attr("dtype")
|
||||
# First check that d is a DType, because the assertEquals will
|
||||
# work no matter what since DType overrides __eq__
|
||||
self.assertIsInstance(d, dtypes.DType)
|
||||
self.assertEqual(dtypes.int32, d)
|
||||
else:
|
||||
list_value = attr_value_pb2.AttrValue.ListValue()
|
||||
|
||||
l = op.get_attr("list")
|
||||
for x in l:
|
||||
self.assertIsInstance(x, dtypes.DType)
|
||||
self.assertEqual([dtypes.string, dtypes.double], l)
|
||||
list_value.type.append(types_pb2.DT_STRING)
|
||||
list_value.type.append(types_pb2.DT_DOUBLE)
|
||||
op = ops.Operation(
|
||||
ops._NodeDef(
|
||||
"None",
|
||||
"op1",
|
||||
attrs={
|
||||
"value":
|
||||
attr_value_pb2.AttrValue(i=32),
|
||||
"dtype":
|
||||
attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
|
||||
"list":
|
||||
attr_value_pb2.AttrValue(list=list_value),
|
||||
"func":
|
||||
attr_value_pb2.AttrValue(
|
||||
func=attr_value_pb2.NameAttrList())
|
||||
}), ops.Graph(), [], [dtypes.int32])
|
||||
self.assertEqual(32, op.get_attr("value"))
|
||||
self.assertEqual("", op.get_attr("func").name)
|
||||
|
||||
d = op.get_attr("dtype")
|
||||
# First check that d is a DType, because the assertEquals will
|
||||
# work no matter what since DType overrides __eq__
|
||||
self.assertIsInstance(d, dtypes.DType)
|
||||
self.assertEqual(dtypes.int32, d)
|
||||
|
||||
l = op.get_attr("list")
|
||||
for x in l:
|
||||
self.assertIsInstance(x, dtypes.DType)
|
||||
self.assertEqual([dtypes.string, dtypes.double], l)
|
||||
|
||||
# TODO(b/65162920): remove this test when users who are directly mutating the
|
||||
# node_def have been updated to proper usage.
|
||||
def testSetAttr(self):
|
||||
if not ops._USE_C_API:
|
||||
return
|
||||
op = test_ops.int_attr().op
|
||||
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
|
||||
# TODO(skyewm): add node_def check
|
||||
self.assertEqual(op.get_attr("foo"), 2)
|
||||
|
||||
# TODO(nolivia): test all error cases
|
||||
def testAddControlInput(self):
|
||||
|
|
|
|||
|
|
@ -331,4 +331,14 @@ REGISTER_OP("OpWithDefaultAttr")
|
|||
REGISTER_OP("OpWithFutureDefaultAttr")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("IntAttr")
|
||||
.Output("out: int64")
|
||||
.Attr("foo: int = 1")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("StringListAttr")
|
||||
.Attr("a: list(string)")
|
||||
.Attr("b: string")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user