Small changes to op framework.

PiperOrigin-RevId: 163361071
This commit is contained in:
A. Unique TensorFlower 2017-07-27 10:49:17 -07:00 committed by TensorFlower Gardener
parent 86ca3506f5
commit ae3119d16b
3 changed files with 45 additions and 31 deletions

View File

@ -195,7 +195,12 @@ def _MakeShape(v, arg_name):
str(v))
break
return v
return tensor_shape.as_shape(v).as_proto()
try:
return tensor_shape.as_shape(v).as_proto()
except TypeError as e:
raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
except ValueError as e:
raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
def _MakeTensor(v, arg_name):
@ -266,6 +271,7 @@ class OpDefLibrary(object):
def __init__(self):
self._ops = {}
# pylint: disable=invalid-name
def add_op(self, op_def):
"""Register an OpDef. May call apply_op with the name afterwards."""
if not isinstance(op_def, op_def_pb2.OpDef):
@ -318,6 +324,20 @@ class OpDefLibrary(object):
TypeError: On some errors.
ValueError: On some errors.
"""
output_structure, is_stateful, op = self._apply_op_helper(
op_type_name, name, **keywords)
if output_structure:
outputs = op.outputs
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
if isinstance(res, list) and not res and is_stateful:
return op
else:
return res
else:
return op
def _apply_op_helper(self, op_type_name, name=None, **keywords):
"""Implementation of apply_op that returns output_structure, op."""
op_info = self._ops.get(op_type_name, None)
if op_info is None:
raise RuntimeError("Unrecognized Op name " + op_type_name)
@ -617,8 +637,8 @@ class OpDefLibrary(object):
if input_arg.is_ref:
if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access
raise TypeError(
("'%s' Op requires that input '%s' be a mutable tensor " +
"(e.g.: a tf.Variable)") % (op_type_name, input_name))
("'%s' Op requires that input '%s' be a mutable tensor "
"(e.g.: a tf.Variable)") % (op_type_name, input_name))
input_types.extend(types)
else:
input_types.extend(base_types)
@ -765,12 +785,6 @@ class OpDefLibrary(object):
op = g.create_op(op_type_name, inputs, output_types, name=scope,
input_types=input_types, attrs=attr_protos,
op_def=op_def)
if output_structure:
outputs = op.outputs
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
if isinstance(res, list) and not res and op_def.is_stateful:
return op
else:
return res
else:
return op
return output_structure, op_def.is_stateful, op
# pylint: enable=invalid-name

View File

@ -682,28 +682,26 @@ void GenPythonOp::AddDocStringOutputs() {
}
void GenPythonOp::AddBody(const string& prefix) {
AddBodyNoReturn(prefix);
strings::StrAppend(&result_, prefix, "return _result\n");
}
void GenPythonOp::AddBodyNoReturn(const string& prefix) {
string return_prefix =
const string apply_prefix =
strings::StrCat(prefix, "_result = _op_def_lib.apply_op(");
string return_args = strings::StrCat("\"", op_def_.name(), "\", ");
for (size_t i = 0; i < param_names_.size(); ++i) {
strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i],
", ");
}
strings::StrAppend(&return_args, "name=name)");
strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (.
WordWrap(return_prefix, return_args, kRightMargin), "\n");
AddBodyNoReturn(apply_prefix);
if (num_outs_ > 1) {
strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
"Output._make(_result)\n");
}
strings::StrAppend(&result_, prefix, "return _result\n");
}
void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
string args = strings::StrCat("\"", op_def_.name(), "\", ");
for (size_t i = 0; i < param_names_.size(); ++i) {
strings::StrAppend(&args, param_names_[i], "=", param_names_[i], ", ");
}
strings::StrAppend(&args, "name=name)");
strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (.
WordWrap(apply_prefix, args, kRightMargin), "\n");
}
} // namespace python_op_gen_internal

View File

@ -38,6 +38,8 @@ string AttrValueToPython(const string& type, const AttrValue& value,
void GenerateLowerCaseOpName(const string& str, string* result);
string DataTypeToPython(DataType dtype, const string& dtype_module);
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const string& function_name);
@ -59,11 +61,11 @@ class GenPythonOp {
void AddOutputGlobals();
void AddDocStringOutputs();
void AddBody(const string& prefix);
void AddBodyNoReturn(const string& prefix);
void AddBodyNoReturn(const string& apply_prefix);
// From constructor arguments
const OpDef& op_def_;
const string& function_name_;
const string function_name_;
const int num_outs_;
// Return value from Code() is prelude_ + result_.