partially exposing the _set_attr and _get_attr method in python

PiperOrigin-RevId: 174113043
This commit is contained in:
Olivia Nordquist 2017-10-31 16:43:46 -07:00 committed by TensorFlower Gardener
parent 8e732a3124
commit f97e7c69b8
6 changed files with 114 additions and 27 deletions

View File

@ -24,6 +24,19 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
graph->graph.AddControlEdge(&input->node, &op->node);
}
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status) {
AttrValue attr_val;
if (!attr_val.ParseFromArray(attr_value_proto->data,
attr_value_proto->length)) {
status->status =
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
}
mutex_lock l(graph->mu);
op->node.AddAttr(attr_name, attr_val);
}
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);

View File

@ -25,6 +25,11 @@ namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
// Changes an attr value in the node_def Protocol Buffer and sets a status upon
// completion.
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status);
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,

View File

@ -341,6 +341,16 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
// Create temporary int64_t to pass to TF_OperationGetAttrInt
%typemap(in, numinputs=0) int64_t* value (int64_t val) {
$1 = &val;
}
// Convert value to Python int
%typemap(argout) int64_t* value {
$result = PyInt_FromLong(*$1);
}
%include "tensorflow/c/c_api.h"
%include "tensorflow/c/python_api.h"

View File

@ -2056,6 +2056,19 @@ class Operation(object):
self._traceback,
include_func_start_lineno=True)
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
if not _USE_C_API:
assert "_set_attr not supported with _USE_C_API == False"
return
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access
finally:
c_api.TF_DeleteBuffer(buf)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@ -2068,6 +2081,20 @@ class Operation(object):
Raises:
ValueError: If this op does not have an attr with the given `name`.
"""
if _USE_C_API:
try:
# TODO(b/65162920): remove this try/except block when all attrs are
# implemented to use the _set_attr method instead of node_def.attr.
with errors.raise_exception_on_not_ok_status() as status:
metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
except errors.InvalidArgumentError:
# Colocation ops are failing to find attrs begininning with "_*". They
# should fall through to the not-CAPI logic until the attribute is set
# via the C-API always.
pass
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
if name not in self._node_def.attr:
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))

View File

@ -357,10 +357,18 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
def testGetAttr(self):
# TODO(skyewm): implement get_attr with C API
if ops._USE_C_API: return
# TODO(b/65162920): implement all tests for get_attr with C API
if ops._USE_C_API:
op = test_ops.int_attr().op
self.assertEqual(op.get_attr("foo"), 1)
op_str = test_ops.string_list_attr(a=["z"], b="y")
self.assertEqual(op_str.get_attr("a"), [b"z"])
self.assertEqual(op_str.get_attr("b"), b"y")
else:
list_value = attr_value_pb2.AttrValue.ListValue()
list_value.type.append(types_pb2.DT_STRING)
list_value.type.append(types_pb2.DT_DOUBLE)
op = ops.Operation(
@ -368,10 +376,14 @@ class OperationTest(test_util.TensorFlowTestCase):
"None",
"op1",
attrs={
"value": attr_value_pb2.AttrValue(i=32),
"dtype": attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
"list": attr_value_pb2.AttrValue(list=list_value),
"func": attr_value_pb2.AttrValue(
"value":
attr_value_pb2.AttrValue(i=32),
"dtype":
attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
"list":
attr_value_pb2.AttrValue(list=list_value),
"func":
attr_value_pb2.AttrValue(
func=attr_value_pb2.NameAttrList())
}), ops.Graph(), [], [dtypes.int32])
self.assertEqual(32, op.get_attr("value"))
@ -388,6 +400,16 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertIsInstance(x, dtypes.DType)
self.assertEqual([dtypes.string, dtypes.double], l)
# TODO(b/65162920): remove this test when users who are directly mutating the
# node_def have been updated to proper usage.
def testSetAttr(self):
if not ops._USE_C_API:
return
op = test_ops.int_attr().op
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
# TODO(skyewm): add node_def check
self.assertEqual(op.get_attr("foo"), 2)
# TODO(nolivia): test all error cases
def testAddControlInput(self):
# The C API dedups redundant control edges, pure Python does not

View File

@ -331,4 +331,14 @@ REGISTER_OP("OpWithDefaultAttr")
REGISTER_OP("OpWithFutureDefaultAttr")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("IntAttr")
.Output("out: int64")
.Attr("foo: int = 1")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("StringListAttr")
.Attr("a: list(string)")
.Attr("b: string")
.SetShapeFn(shape_inference::UnknownShape);
} // end namespace tensorflow