mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Apply parts of pyupgrade to torch (starting with the safest changes). This PR only does two things: removes the need to inherit from object and removes unused future imports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308 Approved by: https://github.com/ezyang, https://github.com/albanD
494 lines
17 KiB
Python
494 lines
17 KiB
Python
## @package beam_search
|
|
# Module caffe2.python.models.seq2seq.beam_search
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple
|
|
from caffe2.python import core
|
|
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
|
|
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
|
|
|
|
|
|
class BeamSearchForwardOnly:
    """
    Forward-only beam search for seq2seq models, built as a recurrent step
    network that is embedded into the enclosing model's net.

    Two helper record types describe the recurrent structure of decoding:

    StateConfig:
        initial_value: blob providing value of state at first step_model
        state_prev_link: LinkConfig describing how recurrent step receives
            input from global state blob in each step
        state_link: LinkConfig describing how step writes (produces new state)
            to global state blob in each step

    LinkConfig:
        blob: blob connecting global state blob to step application
        offset: offset from beginning of global blob for link in time dimension
        window: width of global blob to read/write in time dimension
    """

    LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])

    StateConfig = namedtuple(
        'StateConfig',
        ['initial_value', 'state_prev_link', 'state_link'],
    )

    def __init__(
        self,
        beam_size,
        model,
        eos_token_id,
        go_token_id=seq2seq_util.GO_ID,
        post_eos_penalty=None,
    ):
        """
        Create the step model and register the blobs every step consumes.

        Args:
            beam_size: number of hypotheses kept per step; used as `k` for
                the TopK ops and in several reshape/fill shapes in `apply`.
            model: enclosing model helper; its `param_init_net` and `net`
                receive the ops created outside of the step net, and it is
                passed as `param_model` to the step model.
            eos_token_id: value used to fill the 'eos_token' blob when a
                post-EOS penalty is configured.
            go_token_id: value used to fill 'initial_tokens' when `apply`
                is not given its own `go_token_id`.
            post_eos_penalty: when not None, converted with `float()` and
                added in `apply` to the log-probs of hypotheses whose
                previous token equals `eos_token_id`.
        """
        self.beam_size = beam_size
        self.model = model
        self.go_token_id = go_token_id
        self.eos_token_id = eos_token_id
        self.post_eos_penalty = post_eos_penalty
        self.step_model = Seq2SeqModelHelper(
            name='step_model',
            param_model=self.model,
        )

        step_net = self.step_model.net

        # Blobs fed into each step from the recurrent state.
        step_inputs = step_net.AddExternalInputs(
            'timestep',
            'scores_t_prev',
            'tokens_t_prev',
            'hypo_t_prev',
            'attention_t_prev',
        )
        (
            self.timestep,
            self.scores_t_prev,
            self.tokens_t_prev,
            self.hypo_t_prev,
            self.attention_t_prev,
        ) = step_inputs

        # Previous tokens are exposed to callers as an INT32 [1, -1] blob.
        prev_tokens_int32 = step_net.Cast(
            self.tokens_t_prev,
            'tokens_t_prev_int32',
            to=core.DataType.INT32,
        )
        self.tokens_t_prev_int32_flattened, _ = step_net.Reshape(
            [prev_tokens_int32],
            [prev_tokens_int32, 'input_t_int32_old_shape'],
            shape=[1, -1],
        )

    def get_step_model(self):
        """Return the Seq2SeqModelHelper that holds the per-step net."""
        return self.step_model

    def get_previous_tokens(self):
        """Return the previous-step tokens blob (INT32, reshaped to [1, -1])."""
        return self.tokens_t_prev_int32_flattened

    def get_timestep(self):
        """Return the 'timestep' external-input blob of the step net."""
        return self.timestep

    # TODO: make attentions a generic state
    # data_dependencies is a list of blobs that the operator should wait for
    # before beginning execution. This ensures that ops are run in the correct
    # order when the RecurrentNetwork op is embedded in a DAGNet, for ex.
    def apply(
        self,
        inputs,
        length,
        log_probs,
        attentions,
        state_configs,
        data_dependencies,
        word_rewards=None,
        possible_translation_tokens=None,
        go_token_id=None,
    ):
        """
        Add the beam-search step ops and the enclosing RecurrentNetwork op.

        Args:
            inputs: encoder inputs blob; it is flattened and used only as
                the shape source for the zero-filled initial attention.
            length: blob used (via `input_as_shape`) to shape the fake
                input that drives the RecurrentNetwork op.
            log_probs: step-net blob of per-hypothesis token log-probs
                fed into the first TopK.
            attentions: step-net blob gathered by the chosen hypothesis
                indices to produce the attention state.
            state_configs: iterable of StateConfig for the decoder's own
                recurrent states; each state is re-gathered so it follows
                the hypotheses selected at this step.
            data_dependencies: list of blobs appended to the
                RecurrentNetwork inputs (see the comment above).
            word_rewards: optional blob gathered by the best token ids and
                added to the candidate scores.
            possible_translation_tokens: optional blob; when truthy, the
                TopK indices are mapped through it to obtain token ids.
            go_token_id: optional blob copied into 'initial_tokens'; when
                falsy, `self.go_token_id` is used as a constant instead.

        Returns:
            Tuple of four blobs on the model net:
            (output_token_beam_list, output_prev_index_beam_list,
             output_score_beam_list, output_attention_weights_beam_list).
            The first two are INT32 casts of the accumulated token and
            hypothesis states, the last two are aliases of the accumulated
            score and attention states.
        """
        beam_size = self.beam_size
        init_net = self.model.param_init_net
        model_net = self.model.net
        step_net = self.step_model.net
        dtype_int32 = core.DataType.INT32
        dtype_float = core.DataType.FLOAT

        zero_int32 = init_net.ConstantFill(
            [],
            'ZERO',
            shape=[1],
            value=0,
            dtype=dtype_int32,
        )
        is_first_step = step_net.EQ(
            [zero_int32, self.timestep],
            'on_initial_step',
        )

        if self.post_eos_penalty is not None:
            eos_ids = init_net.ConstantFill(
                [],
                'eos_token',
                shape=[beam_size],
                value=self.eos_token_id,
                dtype=dtype_int32,
            )
            configured_penalty = init_net.ConstantFill(
                [],
                'finished_penalty',
                shape=[1],
                value=float(self.post_eos_penalty),
                dtype=dtype_float,
            )
            zero_float = init_net.ConstantFill(
                [],
                'ZERO_FLOAT',
                shape=[1],
                value=0.0,
                dtype=dtype_float,
            )
            # Selects between a zero and the configured penalty based on
            # whether this is the initial step.
            step_penalty = step_net.Conditional(
                [is_first_step, zero_float, configured_penalty],
                'possible_finished_penalty',
            )

            prev_tokens_flat = step_net.FlattenToVec(
                self.tokens_t_prev,
                'tokens_t_flat',
            )
            prev_tokens_flat_int = step_net.Cast(
                prev_tokens_flat,
                'tokens_t_flat_int',
                to=dtype_int32,
            )

            prev_is_eos = step_net.EQ(
                [prev_tokens_flat_int, eos_ids],
                'predecessor_is_eos',
            )
            prev_is_eos_float = step_net.Cast(
                prev_is_eos,
                'predecessor_is_eos_float',
                to=dtype_float,
            )
            eos_penalty = step_net.Mul(
                [prev_is_eos_float, step_penalty],
                'predecessor_is_eos_penalty',
                broadcast=1,
            )

            log_probs = step_net.Add(
                [log_probs, eos_penalty],
                'log_probs_penalized',
                broadcast=1,
                axis=0,
            )

        # [beam_size, beam_size]
        top_scores, top_tokens = step_net.TopK(
            log_probs,
            ['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
            k=beam_size,
        )
        if possible_translation_tokens:
            # [beam_size, beam_size]
            top_tokens = step_net.Gather(
                [possible_translation_tokens, top_tokens],
                ['best_tokens_per_hypo']
            )

        # [beam_size]
        prev_scores, _ = step_net.Reshape(
            self.scores_t_prev,
            ['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
            shape=[beam_size],
        )
        # [beam_size, beam_size]
        candidate_scores = step_net.Add(
            [top_scores, prev_scores],
            'output_scores',
            broadcast=1,
            axis=0,
        )
        if word_rewards is not None:
            # [beam_size, beam_size]
            rewards_for_top_tokens = step_net.Gather(
                [word_rewards, top_tokens],
                'word_rewards_for_best_tokens_per_hypo',
            )
            # [beam_size, beam_size]
            candidate_scores = step_net.Add(
                [candidate_scores, rewards_for_top_tokens],
                'output_scores',
            )
        # [beam_size * beam_size]
        flat_scores, _ = step_net.Reshape(
            [candidate_scores],
            [candidate_scores, 'output_scores_old_shape'],
            shape=[-1],
        )
        minus_one_int32 = init_net.ConstantFill(
            [],
            'MINUS_ONE_INT32',
            value=-1,
            shape=[1],
            dtype=dtype_int32,
        )
        beam_size_blob = init_net.ConstantFill(
            [],
            'beam_size',
            shape=[1],
            value=beam_size,
            dtype=dtype_int32,
        )

        # current_beam_size (predecessor states from previous step)
        # is 1 on first step (so we just need beam_size scores),
        # and beam_size subsequently (so we need all beam_size * beam_size
        # scores)
        slice_end = step_net.Conditional(
            [is_first_step, beam_size_blob, minus_one_int32],
            ['slice_end'],
        )

        # [current_beam_size * beam_size]
        flat_scores_slice = step_net.Slice(
            [flat_scores, zero_int32, slice_end],
            'output_scores_flattened_slice',
        )
        # [1, current_beam_size * beam_size]
        flat_scores_slice, _ = step_net.Reshape(
            flat_scores_slice,
            [
                flat_scores_slice,
                'output_scores_flattened_slice_old_shape',
            ],
            shape=[1, -1],
        )
        # [1, beam_size]
        scores_t, best_indices = step_net.TopK(
            flat_scores_slice,
            ['scores_t', 'best_indices'],
            k=beam_size,
        )
        beam_size_int64 = init_net.Cast(
            beam_size_blob,
            'BEAM_SIZE_64',
            to=core.DataType.INT64,
        )
        # [1, beam_size]
        # Dividing a flat candidate index by beam_size yields the index of
        # the hypothesis the candidate extends.
        hypo_indices = step_net.Div(
            [best_indices, beam_size_int64],
            'hypo_t_int32',
            broadcast=1,
        )
        hypo_t = step_net.Cast(
            hypo_indices,
            'hypo_t',
            to=dtype_float,
        )

        # [beam_size, encoder_length, 1]
        attention_t = step_net.Gather(
            [attentions, hypo_indices],
            'attention_t',
        )
        # [1, beam_size, encoder_length]
        attention_t, _ = step_net.Reshape(
            attention_t,
            [attention_t, 'attention_t_old_shape'],
            shape=[1, beam_size, -1],
        )
        # [beam_size * beam_size]
        flat_top_tokens, _ = step_net.Reshape(
            top_tokens,
            [
                'best_tokens_per_hypo_flatten',
                'best_tokens_per_hypo_old_shape',
            ],
            shape=[-1],
        )
        tokens_t_int32 = step_net.Gather(
            [flat_top_tokens, best_indices],
            'tokens_t_int32',
        )
        tokens_t = step_net.Cast(
            tokens_t_int32,
            'tokens_t',
            to=dtype_float,
        )

        # Re-point every caller-provided state at a blob gathered by the
        # chosen hypothesis indices, so states follow the surviving beams.
        regathered_configs = []
        for caller_config in state_configs:
            out_link = caller_config.state_link
            flat_state, _ = step_net.Reshape(
                out_link.blob,
                [
                    out_link.blob,
                    out_link.blob + '_old_shape',
                ],
                shape=[beam_size, -1],
            )
            chosen_state = step_net.Gather(
                [flat_state, hypo_indices],
                str(out_link.blob) + '_chosen_per_hypo',
            )
            regathered_configs.append(
                self.StateConfig(
                    initial_value=caller_config.initial_value,
                    state_prev_link=caller_config.state_prev_link,
                    state_link=self.LinkConfig(
                        blob=chosen_state,
                        offset=out_link.offset,
                        window=out_link.window,
                    )
                )
            )
        state_configs = regathered_configs

        initial_scores = init_net.ConstantFill(
            [],
            'initial_scores',
            shape=[1],
            value=0.0,
            dtype=dtype_float,
        )
        if go_token_id:
            initial_tokens = model_net.Copy(
                [go_token_id],
                'initial_tokens',
            )
        else:
            initial_tokens = init_net.ConstantFill(
                [],
                'initial_tokens',
                shape=[1],
                value=float(self.go_token_id),
                dtype=dtype_float,
            )

        initial_hypo = init_net.ConstantFill(
            [],
            'initial_hypo',
            shape=[1],
            value=0.0,
            dtype=dtype_float,
        )
        encoder_inputs_flattened, _ = model_net.Reshape(
            inputs,
            ['encoder_inputs_flattened', 'encoder_inputs_old_shape'],
            shape=[-1],
        )
        init_attention = model_net.ConstantFill(
            encoder_inputs_flattened,
            'init_attention',
            value=0.0,
            dtype=dtype_float,
        )

        # Beam search's own recurrent states, in the order the outputs are
        # later unpacked: scores, tokens, hypothesis indices, attention.
        # Each reads from offset 0 and writes to offset 1 with window 1.
        search_states = [
            (initial_scores, self.scores_t_prev, scores_t),
            (initial_tokens, self.tokens_t_prev, tokens_t),
            (initial_hypo, self.hypo_t_prev, hypo_t),
            (init_attention, self.attention_t_prev, attention_t),
        ]
        state_configs = state_configs + [
            self.StateConfig(
                initial_value=initial_blob,
                state_prev_link=self.LinkConfig(prev_blob, 0, 1),
                state_link=self.LinkConfig(next_blob, 1, 1),
            )
            for initial_blob, prev_blob, next_blob in search_states
        ]

        fake_input = model_net.ConstantFill(
            length,
            'beam_search_fake_input',
            input_as_shape=True,
            extra_shape=[beam_size, 1],
            value=0.0,
            dtype=dtype_float,
        )
        all_inputs = (
            [fake_input] +
            self.step_model.params +
            [config.initial_value for config in state_configs] +
            data_dependencies
        )

        forward_links = []
        recurrent_states = []
        for config in state_configs:
            state_name = str(config.state_prev_link.blob) + '_states'
            recurrent_states.append(state_name)
            # The read link first, then the write link, both against the
            # same global state blob.
            for link in (config.state_prev_link, config.state_link):
                forward_links.append(
                    (link.blob, state_name, link.offset, link.window)
                )
        link_internal, link_external, link_offset, link_window = (
            zip(*forward_links)
        )

        all_outputs = [
            str(step_blob) + '_all'
            for step_blob in [scores_t, tokens_t, hypo_t, attention_t]
        ]
        alias_sources = [
            str(prev_blob) + '_states'
            for prev_blob in [
                self.scores_t_prev,
                self.tokens_t_prev,
                self.hypo_t_prev,
                self.attention_t_prev,
            ]
        ]
        results = model_net.RecurrentNetwork(
            all_inputs,
            all_outputs + ['step_workspaces'],
            param=[all_inputs.index(p) for p in self.step_model.params],
            alias_src=alias_sources,
            alias_dst=all_outputs,
            alias_offset=[0] * 4,
            recurrent_states=recurrent_states,
            initial_recurrent_state_ids=[
                all_inputs.index(config.initial_value)
                for config in state_configs
            ],
            link_internal=[str(blob) for blob in link_internal],
            link_external=[str(blob) for blob in link_external],
            link_offset=link_offset,
            link_window=link_window,
            backward_link_internal=[],
            backward_link_external=[],
            backward_link_offset=[],
            step_net=self.step_model.net.Proto(),
            timestep=str(self.timestep),
            outputs_with_grads=[],
            enable_rnn_executor=1,
            rnn_executor_debug=0
        )
        score_t_all, tokens_t_all, hypo_t_all, attention_t_all = results[:4]

        output_token_beam_list = model_net.Cast(
            tokens_t_all,
            'output_token_beam_list',
            to=dtype_int32,
        )
        output_prev_index_beam_list = model_net.Cast(
            hypo_t_all,
            'output_prev_index_beam_list',
            to=dtype_int32,
        )
        output_score_beam_list = model_net.Alias(
            score_t_all,
            'output_score_beam_list',
        )
        output_attention_weights_beam_list = model_net.Alias(
            attention_t_all,
            'output_attention_weights_beam_list',
        )

        return (
            output_token_beam_list,
            output_prev_index_beam_list,
            output_score_beam_list,
            output_attention_weights_beam_list,
        )
|