mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Small changes to op framework.
PiperOrigin-RevId: 163361071
This commit is contained in:
parent
86ca3506f5
commit
ae3119d16b
|
|
@ -195,7 +195,12 @@ def _MakeShape(v, arg_name):
|
|||
str(v))
|
||||
break
|
||||
return v
|
||||
return tensor_shape.as_shape(v).as_proto()
|
||||
try:
|
||||
return tensor_shape.as_shape(v).as_proto()
|
||||
except TypeError as e:
|
||||
raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
|
||||
except ValueError as e:
|
||||
raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
|
||||
|
||||
|
||||
def _MakeTensor(v, arg_name):
|
||||
|
|
@ -266,6 +271,7 @@ class OpDefLibrary(object):
|
|||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def add_op(self, op_def):
|
||||
"""Register an OpDef. May call apply_op with the name afterwards."""
|
||||
if not isinstance(op_def, op_def_pb2.OpDef):
|
||||
|
|
@ -318,6 +324,20 @@ class OpDefLibrary(object):
|
|||
TypeError: On some errors.
|
||||
ValueError: On some errors.
|
||||
"""
|
||||
output_structure, is_stateful, op = self._apply_op_helper(
|
||||
op_type_name, name, **keywords)
|
||||
if output_structure:
|
||||
outputs = op.outputs
|
||||
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
|
||||
if isinstance(res, list) and not res and is_stateful:
|
||||
return op
|
||||
else:
|
||||
return res
|
||||
else:
|
||||
return op
|
||||
|
||||
def _apply_op_helper(self, op_type_name, name=None, **keywords):
|
||||
"""Implementation of apply_op that returns output_structure, op."""
|
||||
op_info = self._ops.get(op_type_name, None)
|
||||
if op_info is None:
|
||||
raise RuntimeError("Unrecognized Op name " + op_type_name)
|
||||
|
|
@ -617,8 +637,8 @@ class OpDefLibrary(object):
|
|||
if input_arg.is_ref:
|
||||
if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access
|
||||
raise TypeError(
|
||||
("'%s' Op requires that input '%s' be a mutable tensor " +
|
||||
"(e.g.: a tf.Variable)") % (op_type_name, input_name))
|
||||
("'%s' Op requires that input '%s' be a mutable tensor "
|
||||
"(e.g.: a tf.Variable)") % (op_type_name, input_name))
|
||||
input_types.extend(types)
|
||||
else:
|
||||
input_types.extend(base_types)
|
||||
|
|
@ -765,12 +785,6 @@ class OpDefLibrary(object):
|
|||
op = g.create_op(op_type_name, inputs, output_types, name=scope,
|
||||
input_types=input_types, attrs=attr_protos,
|
||||
op_def=op_def)
|
||||
if output_structure:
|
||||
outputs = op.outputs
|
||||
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
|
||||
if isinstance(res, list) and not res and op_def.is_stateful:
|
||||
return op
|
||||
else:
|
||||
return res
|
||||
else:
|
||||
return op
|
||||
return output_structure, op_def.is_stateful, op
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
|
|
|||
|
|
@ -682,28 +682,26 @@ void GenPythonOp::AddDocStringOutputs() {
|
|||
}
|
||||
|
||||
void GenPythonOp::AddBody(const string& prefix) {
|
||||
AddBodyNoReturn(prefix);
|
||||
strings::StrAppend(&result_, prefix, "return _result\n");
|
||||
}
|
||||
|
||||
void GenPythonOp::AddBodyNoReturn(const string& prefix) {
|
||||
string return_prefix =
|
||||
const string apply_prefix =
|
||||
strings::StrCat(prefix, "_result = _op_def_lib.apply_op(");
|
||||
string return_args = strings::StrCat("\"", op_def_.name(), "\", ");
|
||||
for (size_t i = 0; i < param_names_.size(); ++i) {
|
||||
strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i],
|
||||
", ");
|
||||
}
|
||||
strings::StrAppend(&return_args, "name=name)");
|
||||
|
||||
strings::StrAppend(&result_,
|
||||
// Wrap the arguments, and indent to the (.
|
||||
WordWrap(return_prefix, return_args, kRightMargin), "\n");
|
||||
|
||||
AddBodyNoReturn(apply_prefix);
|
||||
if (num_outs_ > 1) {
|
||||
strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
|
||||
"Output._make(_result)\n");
|
||||
}
|
||||
strings::StrAppend(&result_, prefix, "return _result\n");
|
||||
}
|
||||
|
||||
void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
|
||||
string args = strings::StrCat("\"", op_def_.name(), "\", ");
|
||||
for (size_t i = 0; i < param_names_.size(); ++i) {
|
||||
strings::StrAppend(&args, param_names_[i], "=", param_names_[i], ", ");
|
||||
}
|
||||
strings::StrAppend(&args, "name=name)");
|
||||
|
||||
strings::StrAppend(&result_,
|
||||
// Wrap the arguments, and indent to the (.
|
||||
WordWrap(apply_prefix, args, kRightMargin), "\n");
|
||||
}
|
||||
|
||||
} // namespace python_op_gen_internal
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ string AttrValueToPython(const string& type, const AttrValue& value,
|
|||
|
||||
void GenerateLowerCaseOpName(const string& str, string* result);
|
||||
|
||||
string DataTypeToPython(DataType dtype, const string& dtype_module);
|
||||
|
||||
class GenPythonOp {
|
||||
public:
|
||||
GenPythonOp(const OpDef& op_def, const string& function_name);
|
||||
|
|
@ -59,11 +61,11 @@ class GenPythonOp {
|
|||
void AddOutputGlobals();
|
||||
void AddDocStringOutputs();
|
||||
void AddBody(const string& prefix);
|
||||
void AddBodyNoReturn(const string& prefix);
|
||||
void AddBodyNoReturn(const string& apply_prefix);
|
||||
|
||||
// From constructor arguments
|
||||
const OpDef& op_def_;
|
||||
const string& function_name_;
|
||||
const string function_name_;
|
||||
const int num_outs_;
|
||||
|
||||
// Return value from Code() is prelude_ + result_.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user