Mirror of https://github.com/zebrajr/pytorch.git · synced 2025-12-07 12:21:27 +01:00
Summary: The `2to3` tool has a fixer called `future` that you can target specifically to remove these redundant `from __future__` imports; the `caffe2` directory has the most of them: ```2to3 -f future -w caffe2```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45033
Reviewed By: seemethere
Differential Revision: D23808648
Pulled By: bugra
fbshipit-source-id: 38971900f0fe43ab44a9168e57f2307580d36a38
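For illustration, this is all the `future` fixer does: it deletes `from __future__` imports, which are no-ops on Python 3, and leaves everything else untouched. The snippet below is a hypothetical before/after, not taken from this commit:

```python
# Before `2to3 -f future -w caffe2`:
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np

# After: only the __future__ import line is removed.
import numpy as np
```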
109 lines · 3.8 KiB · Python
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, dyndep, test_util

# Load the warp-ctc operator library so the CTC op is registered.
dyndep.InitOpsLibrary('@/caffe2/caffe2/contrib/warpctc:ctc_ops')
workspace.GlobalInit(["python"])
def softmax(w):
    # Numerically stable softmax over the last axis: subtracting the
    # per-row max leaves the result unchanged but avoids overflow in exp.
    maxes = np.amax(w, axis=-1, keepdims=True)
    e = np.exp(w - maxes)
    dist = e / np.sum(e, axis=-1, keepdims=True)
    return dist
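
# Editor's aside (hypothetical sketch, not part of the original file): a quick
# self-check of the stability trick above. Shifting by the row max is exact,
# since softmax(w) == softmax(w - c) for any constant c, and it keeps np.exp
# from overflowing on large logits.
def _softmax_stability_demo():
    w = np.array([[1000.0, 1001.0, 1002.0]])
    dist = softmax(w)                   # naive np.exp(w) would overflow to inf
    assert np.isfinite(dist).all()      # finite thanks to the max shift
    assert np.allclose(dist.sum(axis=-1), 1.0)  # rows still normalize to 1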

class CTCOpsTest(test_util.TestCase):

    def verify_cost(self, device_option, is_test, skip_input_lengths=False):
        alphabet_size = 5
        N = 1  # batch size
        T = 2  # number of timesteps

        # Pre-softmax activations of shape (T, N, alphabet_size).
        inputs = np.asarray(
            [
                [[0.1, 0.6, 0.1, 0.1, 0.1]],
                [[0.1, 0.1, 0.6, 0.1, 0.1]],
            ]
        ).reshape(T, N, alphabet_size).astype(np.float32)

        labels = np.asarray([1, 2]).astype(np.int32).reshape(T)
        label_lengths = np.asarray([2]).astype(np.int32).reshape(N)
        input_lengths = np.asarray([T]).astype(np.int32)

        net = core.Net("test-net")
        input_blobs = ["inputs", "labels", "label_lengths"]
        if not skip_input_lengths:
            input_blobs.append("input_lengths")
        output_blobs = ["costs", "workspace"] if is_test \
            else ["inputs_grad_to_be_copied", "costs", "workspace"]
        net.CTC(input_blobs,
                output_blobs,
                is_test=is_test,
                device_option=device_option)
        if not is_test:
            net.AddGradientOperators(["costs"])
        self.ws.create_blob("inputs").feed(inputs, device_option=device_option)
        self.ws.create_blob("labels").feed(labels)
        self.ws.create_blob("label_lengths").feed(label_lengths)
        if not skip_input_lengths:
            self.ws.create_blob("input_lengths").feed(input_lengths)
        self.ws.run(net)

        # The only valid alignment is label 1 at t=0 and label 2 at t=1.
        probs = softmax(inputs)
        expected = probs[0, 0, 1] * probs[1, 0, 2]
        self.assertEqual(self.ws.blobs["costs"].fetch().shape, (N,))
        self.assertEqual(self.ws.blobs["costs"].fetch().dtype, np.float32)
        cost = self.ws.blobs["costs"].fetch()[0]
        print(cost)
        self.assertAlmostEqual(np.exp(-cost), expected)
        if not is_test:
            # Make sure inputs_grad was added by AddGradientOperators and that
            # it equals the inputs_grad_to_be_copied blob returned by the CTC op.
            assert np.array_equal(
                self.ws.blobs["inputs_grad"].fetch(),
                self.ws.blobs["inputs_grad_to_be_copied"].fetch()
            )
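
    # Editor's aside (hypothetical helper, not in the original file): with
    # T == len(labels) == 2, exactly one CTC alignment is valid -- label 1 at
    # t=0 and label 2 at t=1, with no room for blanks -- so the loss reduces
    # to a single product, matching the `expected` value checked above.
    def _expected_ctc_cost(self, inputs):
        probs = softmax(inputs)
        # Probability of the only alignment for batch item 0.
        p_path = probs[0, 0, 1] * probs[1, 0, 2]
        return -np.log(p_path)  # CTC cost is the negative log-likelihood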

    def test_ctc_cost_cpu(self):
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU),
            is_test=False)
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU),
            is_test=False, skip_input_lengths=True)

    def test_ctc_cost_gpu(self):
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                    device_id=0),
            is_test=False)
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                    device_id=0),
            is_test=False,
            skip_input_lengths=True)

    def test_ctc_forward_only_cpu(self):
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU),
            is_test=True)
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU),
            is_test=True,
            skip_input_lengths=True)

    def test_ctc_forward_only_gpu(self):
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                    device_id=0),
            is_test=True)
        self.verify_cost(
            caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                    device_id=0),
            is_test=True,
            skip_input_lengths=True)
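
# Editor's aside (an assumption, not present in this listing): the file ends
# without a __main__ guard. A conventional way to run these tests directly,
# mirroring other caffe2 test files, would be:
if __name__ == "__main__":
    import unittest
    unittest.main()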