diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 662c2c679c8..e85bba11cd1 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -195,7 +195,12 @@ def _MakeShape(v, arg_name): str(v)) break return v - return tensor_shape.as_shape(v).as_proto() + try: + return tensor_shape.as_shape(v).as_proto() + except TypeError as e: + raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e)) + except ValueError as e: + raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e)) def _MakeTensor(v, arg_name): @@ -266,6 +271,7 @@ class OpDefLibrary(object): def __init__(self): self._ops = {} + # pylint: disable=invalid-name def add_op(self, op_def): """Register an OpDef. May call apply_op with the name afterwards.""" if not isinstance(op_def, op_def_pb2.OpDef): @@ -318,6 +324,20 @@ class OpDefLibrary(object): TypeError: On some errors. ValueError: On some errors. """ + output_structure, is_stateful, op = self._apply_op_helper( + op_type_name, name, **keywords) + if output_structure: + outputs = op.outputs + res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) + if isinstance(res, list) and not res and is_stateful: + return op + else: + return res + else: + return op + + def _apply_op_helper(self, op_type_name, name=None, **keywords): + """Implementation of apply_op that returns output_structure, op.""" op_info = self._ops.get(op_type_name, None) if op_info is None: raise RuntimeError("Unrecognized Op name " + op_type_name) @@ -617,8 +637,8 @@ class OpDefLibrary(object): if input_arg.is_ref: if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access raise TypeError( - ("'%s' Op requires that input '%s' be a mutable tensor " + - "(e.g.: a tf.Variable)") % (op_type_name, input_name)) + ("'%s' Op requires that input '%s' be a mutable tensor " + "(e.g.: a tf.Variable)") % (op_type_name, input_name)) input_types.extend(types) else: input_types.extend(base_types) @@ -765,12 +785,6 @@ class OpDefLibrary(object): op = g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) - if output_structure: - outputs = op.outputs - res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) - if isinstance(res, list) and not res and op_def.is_stateful: - return op - else: - return res - else: - return op + return output_structure, op_def.is_stateful, op + +# pylint: enable=invalid-name diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 090436aebf7..e7aaaeb2338 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -682,28 +682,26 @@ void GenPythonOp::AddDocStringOutputs() { } void GenPythonOp::AddBody(const string& prefix) { - AddBodyNoReturn(prefix); - strings::StrAppend(&result_, prefix, "return _result\n"); -} - -void GenPythonOp::AddBodyNoReturn(const string& prefix) { - string return_prefix = + const string apply_prefix = strings::StrCat(prefix, "_result = _op_def_lib.apply_op("); - string return_args = strings::StrCat("\"", op_def_.name(), "\", "); - for (size_t i = 0; i < param_names_.size(); ++i) { - strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i], - ", "); - } - strings::StrAppend(&return_args, "name=name)"); - - strings::StrAppend(&result_, - // Wrap the arguments, and indent to the (. - WordWrap(return_prefix, return_args, kRightMargin), "\n"); - + AddBodyNoReturn(apply_prefix); if (num_outs_ > 1) { strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(), "Output._make(_result)\n"); } + strings::StrAppend(&result_, prefix, "return _result\n"); +} + +void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) { + string args = strings::StrCat("\"", op_def_.name(), "\", "); + for (size_t i = 0; i < param_names_.size(); ++i) { + strings::StrAppend(&args, param_names_[i], "=", param_names_[i], ", "); + } + strings::StrAppend(&args, "name=name)"); + + strings::StrAppend(&result_, + // Wrap the arguments, and indent to the (. + WordWrap(apply_prefix, args, kRightMargin), "\n"); } } // namespace python_op_gen_internal diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index d588f362d82..92237ac81a2 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -38,6 +38,8 @@ string AttrValueToPython(const string& type, const AttrValue& value, void GenerateLowerCaseOpName(const string& str, string* result); +string DataTypeToPython(DataType dtype, const string& dtype_module); + class GenPythonOp { public: GenPythonOp(const OpDef& op_def, const string& function_name); @@ -59,11 +61,11 @@ class GenPythonOp { void AddOutputGlobals(); void AddDocStringOutputs(); void AddBody(const string& prefix); - void AddBodyNoReturn(const string& prefix); + void AddBodyNoReturn(const string& apply_prefix); // From constructor arguments const OpDef& op_def_; - const string& function_name_; + const string function_name_; const int num_outs_; // Return value from Code() is prelude_ + result_.