mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix BN for test phase
This commit is contained in:
parent
b861d9a264
commit
00c493864e
|
|
@ -470,9 +470,15 @@ class CNNModelHelper(object):
|
|||
self.biases.append(bias)
|
||||
blob_outs = [blob_out, blob_out + "_rm", blob_out + "_riv",
|
||||
blob_out + "_sm", blob_out + "_siv"]
|
||||
blob_outputs = self.net.SpatialBN(
|
||||
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs,
|
||||
order=self.order, **kwargs)
|
||||
if kwargs['is_test']:
|
||||
blob_outputs = self.net.SpatialBN(
|
||||
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out],
|
||||
order=self.order, **kwargs)
|
||||
return blob_outputs
|
||||
else:
|
||||
blob_outputs = self.net.SpatialBN(
|
||||
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs,
|
||||
order=self.order, **kwargs)
|
||||
# Return the output
|
||||
return blob_outputs[0]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user