Internal cleanup.
PiperOrigin-RevId: 170922297
parent d0c76cd188
commit 5123f29718
@@ -324,13 +324,19 @@ def imperative_grad(
     result.append(_aggregate_grads(g))
   return result


+_op_attr_type_cache = {}
+
+
 def op_attr_type(op_type, attr_name):
-  with errors.raise_exception_on_not_ok_status() as status:
-    h = context.context()._handle  # pylint: disable=protected-access
-    op = pywrap_tensorflow.TFE_NewOp(h, op_type, status)
-    attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status)
-  return attr_type
+  try:
+    return _op_attr_type_cache[(op_type, attr_name)]
+  except KeyError:
+    with errors.raise_exception_on_not_ok_status() as status:
+      h = context.context()._handle  # pylint: disable=protected-access
+      op = pywrap_tensorflow.TFE_NewOp(h, op_type, status)
+      attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status)
+    _op_attr_type_cache[(op_type, attr_name)] = attr_type
+    return attr_type


 def make_attr(attr_type, value):
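Note: the hunk above memoizes op_attr_type. The attribute type for a given (op_type, attr_name) pair is fixed by the op definition, so the result of the TFE_OpGetAttrType lookup through the C API is stored in a module-level dict and reused on later calls, and the cache never needs invalidation. A minimal sketch of the same pattern, with the C-API lookup replaced by a hypothetical expensive_lookup callable (not a TensorFlow symbol):

_attr_type_cache = {}

def cached_attr_type(op_type, attr_name, expensive_lookup):
  key = (op_type, attr_name)
  try:
    return _attr_type_cache[key]      # cache hit: skip the expensive call
  except KeyError:
    value = expensive_lookup(op_type, attr_name)  # cache miss: compute once
    _attr_type_cache[key] = value
    return value

# Example with a stand-in lookup function.
print(cached_attr_type("MatMul", "transpose_a", lambda o, a: "bool"))
print(cached_attr_type("MatMul", "transpose_a", lambda o, a: "bool"))  # served from cache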
@@ -277,7 +277,9 @@ class BatchNormalization(base.Layer):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, one_minus_decay]) as scope:
       with ops.colocate_with(variable):
-        update_delta = (variable.read_value() - value) * one_minus_decay
+        update_delta = math_ops.multiply(
+            math_ops.subtract(variable.read_value(), value),
+            one_minus_decay)
         if isinstance(variable, resource_variable_ops.ResourceVariable):
           # state_ops.assign_sub does an extra read_variable_op after the
           # assign. We avoid that here.
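Note: this hunk only swaps Python operator syntax for explicit math_ops calls; the moving-average update itself is unchanged. The variable moves toward value by (variable - value) * (1 - decay), i.e. new = decay * old + (1 - decay) * value. A small sketch of that update against the public TensorFlow API, assuming variable is a tf.Variable and value a tensor of the same shape (names here are illustrative, not the layer's own code):

import tensorflow as tf

def assign_moving_average(variable, value, one_minus_decay):
  # variable <- variable - (variable - value) * (1 - decay)
  #           = decay * variable + (1 - decay) * value
  update_delta = tf.multiply(tf.subtract(variable.read_value(), value),
                             one_minus_decay)
  return variable.assign_sub(update_delta)

v = tf.Variable([1.0, 2.0])
assign_moving_average(v, tf.constant([3.0, 4.0]), one_minus_decay=0.1)
# v is now [1.2, 2.2]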
@@ -79,15 +79,16 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):

   def _ExtractInputShapes(inputs):
     """Extract the shapes of a set of input tensors."""
+    if not context.in_graph_mode():
+      return array_ops.shape_n(inputs)
     sizes = []
     fully_known = True
     for x in inputs:
       input_shape = array_ops.shape(x)
-      if context.in_graph_mode():
-        if not isinstance(input_shape,
-                          ops.Tensor) or input_shape.op.type != "Const":
-          fully_known = False
-          break
+      if not isinstance(input_shape,
+                        ops.Tensor) or input_shape.op.type != "Const":
+        fully_known = False
+        break
       sizes.append(input_shape)

     if fully_known:
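Note: the change above short-circuits shape extraction when not building a graph. Under eager execution there are no constant-folded shape nodes to inspect, so the helper returns array_ops.shape_n(inputs) directly, and the Const check now applies only on the graph path. A rough, simplified sketch of the resulting control flow, with in_eager_mode standing in for the context.in_graph_mode() test (not a real TensorFlow symbol):

import tensorflow as tf

def extract_input_shapes(inputs, in_eager_mode):
  if in_eager_mode:
    # Eagerly, compute every shape at once with a single shape_n call.
    return tf.shape_n(inputs), True
  sizes = []
  fully_known = True
  for x in inputs:
    input_shape = tf.shape(x)
    # In graph mode, tf.shape constant-folds to a Const node when the static
    # shape is fully defined; anything else means a runtime-only shape.
    if not isinstance(input_shape, tf.Tensor) or input_shape.op.type != "Const":
      fully_known = False
      break
    sizes.append(input_shape)
  return sizes, fully_known

shapes, known = extract_input_shapes([tf.zeros([2, 3]), tf.zeros([4])],
                                     in_eager_mode=True)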
@@ -460,16 +460,25 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):

 @ops.RegisterGradient("Conv2D")
 def _Conv2DGrad(op, grad):
-  return [nn_ops.conv2d_backprop_input(
-      array_ops.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr("strides"),
-      op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
-      op.get_attr("data_format")),
+  strides = op.get_attr("strides")
+  padding = op.get_attr("padding")
+  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
+  data_format = op.get_attr("data_format")
+  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])
+  return [nn_ops.conv2d_backprop_input(shape_0,
+                                       op.inputs[1],
+                                       grad,
+                                       strides,
+                                       padding,
+                                       use_cudnn_on_gpu,
+                                       data_format),
           nn_ops.conv2d_backprop_filter(op.inputs[0],
-                                        array_ops.shape(op.inputs[1]), grad,
-                                        op.get_attr("strides"),
-                                        op.get_attr("padding"),
-                                        op.get_attr("use_cudnn_on_gpu"),
-                                        op.get_attr("data_format"))]
+                                        shape_1,
+                                        grad,
+                                        strides,
+                                        padding,
+                                        use_cudnn_on_gpu,
+                                        data_format)]


 @ops.RegisterGradient("DepthwiseConv2dNative")
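Note: the Conv2D gradient rewrite does two things. It hoists the four op.get_attr lookups into locals so each attribute is read once instead of once per backprop op, and it fetches both input shapes with a single array_ops.shape_n call instead of two separate array_ops.shape calls. A tiny sketch of the shape_n consolidation on its own, with hypothetical tensors (shape_n requires its inputs to share a dtype, as the conv input and filter do):

import tensorflow as tf

x = tf.zeros([8, 32, 32, 3])   # stand-in for the conv input
f = tf.zeros([3, 3, 3, 16])    # stand-in for the conv filter

# Before: two separate Shape ops.
shape_x, shape_f = tf.shape(x), tf.shape(f)

# After: one ShapeN op returns both shapes at once.
shape_x, shape_f = tf.shape_n([x, f])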