mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
partially exposing the _set_attr and _get_attr method in python
PiperOrigin-RevId: 174113043
This commit is contained in:
parent
8e732a3124
commit
f97e7c69b8
|
|
@ -24,6 +24,19 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
|
||||||
graph->graph.AddControlEdge(&input->node, &op->node);
|
graph->graph.AddControlEdge(&input->node, &op->node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||||
|
TF_Buffer* attr_value_proto, TF_Status* status) {
|
||||||
|
AttrValue attr_val;
|
||||||
|
if (!attr_val.ParseFromArray(attr_value_proto->data,
|
||||||
|
attr_value_proto->length)) {
|
||||||
|
status->status =
|
||||||
|
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
|
||||||
|
}
|
||||||
|
|
||||||
|
mutex_lock l(graph->mu);
|
||||||
|
op->node.AddAttr(attr_name, attr_val);
|
||||||
|
}
|
||||||
|
|
||||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
||||||
mutex_lock l(graph->mu);
|
mutex_lock l(graph->mu);
|
||||||
op->node.set_requested_device(device);
|
op->node.set_requested_device(device);
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,11 @@ namespace tensorflow {
|
||||||
|
|
||||||
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||||
|
|
||||||
|
// Changes an attr value in the node_def Protocol Buffer and sets a status upon
|
||||||
|
// completion.
|
||||||
|
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
||||||
|
TF_Buffer* attr_value_proto, TF_Status* status);
|
||||||
|
|
||||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
||||||
|
|
||||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
|
|
|
||||||
|
|
@ -341,6 +341,16 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
|
||||||
%rename("_TF_SetConfig") TF_SetConfig;
|
%rename("_TF_SetConfig") TF_SetConfig;
|
||||||
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
|
||||||
|
|
||||||
|
// Create temporary int64_t to pass to TF_OperationGetAttrInt
|
||||||
|
%typemap(in, numinputs=0) int64_t* value (int64_t val) {
|
||||||
|
$1 = &val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert value to Python int
|
||||||
|
%typemap(argout) int64_t* value {
|
||||||
|
$result = PyInt_FromLong(*$1);
|
||||||
|
}
|
||||||
|
|
||||||
%include "tensorflow/c/c_api.h"
|
%include "tensorflow/c/c_api.h"
|
||||||
%include "tensorflow/c/python_api.h"
|
%include "tensorflow/c/python_api.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2056,6 +2056,19 @@ class Operation(object):
|
||||||
self._traceback,
|
self._traceback,
|
||||||
include_func_start_lineno=True)
|
include_func_start_lineno=True)
|
||||||
|
|
||||||
|
def _set_attr(self, attr_name, attr_value):
|
||||||
|
"""Private method used to set an attribute in the node_def."""
|
||||||
|
if not _USE_C_API:
|
||||||
|
assert "_set_attr not supported with _USE_C_API == False"
|
||||||
|
return
|
||||||
|
buf = c_api.TF_NewBufferFromString(
|
||||||
|
compat.as_bytes(attr_value.SerializeToString()))
|
||||||
|
try:
|
||||||
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
|
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access
|
||||||
|
finally:
|
||||||
|
c_api.TF_DeleteBuffer(buf)
|
||||||
|
|
||||||
def get_attr(self, name):
|
def get_attr(self, name):
|
||||||
"""Returns the value of the attr of this op with the given `name`.
|
"""Returns the value of the attr of this op with the given `name`.
|
||||||
|
|
||||||
|
|
@ -2068,6 +2081,20 @@ class Operation(object):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If this op does not have an attr with the given `name`.
|
ValueError: If this op does not have an attr with the given `name`.
|
||||||
"""
|
"""
|
||||||
|
if _USE_C_API:
|
||||||
|
try:
|
||||||
|
# TODO(b/65162920): remove this try/except block when all attrs are
|
||||||
|
# implemented to use the _set_attr method instead of node_def.attr.
|
||||||
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
|
metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
|
||||||
|
if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
|
||||||
|
return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
|
||||||
|
except errors.InvalidArgumentError:
|
||||||
|
# Colocation ops are failing to find attrs begininning with "_*". They
|
||||||
|
# should fall through to the not-CAPI logic until the attribute is set
|
||||||
|
# via the C-API always.
|
||||||
|
pass
|
||||||
|
|
||||||
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
|
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
|
||||||
if name not in self._node_def.attr:
|
if name not in self._node_def.attr:
|
||||||
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
|
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
|
||||||
|
|
|
||||||
|
|
@ -357,36 +357,58 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||||
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
|
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
|
||||||
|
|
||||||
def testGetAttr(self):
|
def testGetAttr(self):
|
||||||
# TODO(skyewm): implement get_attr with C API
|
# TODO(b/65162920): implement all tests for get_attr with C API
|
||||||
if ops._USE_C_API: return
|
if ops._USE_C_API:
|
||||||
|
op = test_ops.int_attr().op
|
||||||
|
self.assertEqual(op.get_attr("foo"), 1)
|
||||||
|
|
||||||
list_value = attr_value_pb2.AttrValue.ListValue()
|
op_str = test_ops.string_list_attr(a=["z"], b="y")
|
||||||
list_value.type.append(types_pb2.DT_STRING)
|
self.assertEqual(op_str.get_attr("a"), [b"z"])
|
||||||
list_value.type.append(types_pb2.DT_DOUBLE)
|
self.assertEqual(op_str.get_attr("b"), b"y")
|
||||||
op = ops.Operation(
|
|
||||||
ops._NodeDef(
|
|
||||||
"None",
|
|
||||||
"op1",
|
|
||||||
attrs={
|
|
||||||
"value": attr_value_pb2.AttrValue(i=32),
|
|
||||||
"dtype": attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
|
|
||||||
"list": attr_value_pb2.AttrValue(list=list_value),
|
|
||||||
"func": attr_value_pb2.AttrValue(
|
|
||||||
func=attr_value_pb2.NameAttrList())
|
|
||||||
}), ops.Graph(), [], [dtypes.int32])
|
|
||||||
self.assertEqual(32, op.get_attr("value"))
|
|
||||||
self.assertEqual("", op.get_attr("func").name)
|
|
||||||
|
|
||||||
d = op.get_attr("dtype")
|
else:
|
||||||
# First check that d is a DType, because the assertEquals will
|
list_value = attr_value_pb2.AttrValue.ListValue()
|
||||||
# work no matter what since DType overrides __eq__
|
|
||||||
self.assertIsInstance(d, dtypes.DType)
|
|
||||||
self.assertEqual(dtypes.int32, d)
|
|
||||||
|
|
||||||
l = op.get_attr("list")
|
list_value.type.append(types_pb2.DT_STRING)
|
||||||
for x in l:
|
list_value.type.append(types_pb2.DT_DOUBLE)
|
||||||
self.assertIsInstance(x, dtypes.DType)
|
op = ops.Operation(
|
||||||
self.assertEqual([dtypes.string, dtypes.double], l)
|
ops._NodeDef(
|
||||||
|
"None",
|
||||||
|
"op1",
|
||||||
|
attrs={
|
||||||
|
"value":
|
||||||
|
attr_value_pb2.AttrValue(i=32),
|
||||||
|
"dtype":
|
||||||
|
attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
|
||||||
|
"list":
|
||||||
|
attr_value_pb2.AttrValue(list=list_value),
|
||||||
|
"func":
|
||||||
|
attr_value_pb2.AttrValue(
|
||||||
|
func=attr_value_pb2.NameAttrList())
|
||||||
|
}), ops.Graph(), [], [dtypes.int32])
|
||||||
|
self.assertEqual(32, op.get_attr("value"))
|
||||||
|
self.assertEqual("", op.get_attr("func").name)
|
||||||
|
|
||||||
|
d = op.get_attr("dtype")
|
||||||
|
# First check that d is a DType, because the assertEquals will
|
||||||
|
# work no matter what since DType overrides __eq__
|
||||||
|
self.assertIsInstance(d, dtypes.DType)
|
||||||
|
self.assertEqual(dtypes.int32, d)
|
||||||
|
|
||||||
|
l = op.get_attr("list")
|
||||||
|
for x in l:
|
||||||
|
self.assertIsInstance(x, dtypes.DType)
|
||||||
|
self.assertEqual([dtypes.string, dtypes.double], l)
|
||||||
|
|
||||||
|
# TODO(b/65162920): remove this test when users who are directly mutating the
|
||||||
|
# node_def have been updated to proper usage.
|
||||||
|
def testSetAttr(self):
|
||||||
|
if not ops._USE_C_API:
|
||||||
|
return
|
||||||
|
op = test_ops.int_attr().op
|
||||||
|
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
|
||||||
|
# TODO(skyewm): add node_def check
|
||||||
|
self.assertEqual(op.get_attr("foo"), 2)
|
||||||
|
|
||||||
# TODO(nolivia): test all error cases
|
# TODO(nolivia): test all error cases
|
||||||
def testAddControlInput(self):
|
def testAddControlInput(self):
|
||||||
|
|
|
||||||
|
|
@ -331,4 +331,14 @@ REGISTER_OP("OpWithDefaultAttr")
|
||||||
REGISTER_OP("OpWithFutureDefaultAttr")
|
REGISTER_OP("OpWithFutureDefaultAttr")
|
||||||
.SetShapeFn(shape_inference::UnknownShape);
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
|
REGISTER_OP("IntAttr")
|
||||||
|
.Output("out: int64")
|
||||||
|
.Attr("foo: int = 1")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
|
REGISTER_OP("StringListAttr")
|
||||||
|
.Attr("a: list(string)")
|
||||||
|
.Attr("b: string")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user