speechbrain.decoders.seq2seq module
Decoding methods for seq2seq autoregressive models.
- Authors
Adel Moumen 2022
Ju-Chieh Chou 2020
Peter Plantinga 2020
Mirco Ravanelli 2020
Sung-Lin Yeh 2020
Summary
Classes:
S2SBaseSearcher | S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq model.
S2SBeamSearcher | This class implements the beam-search algorithm for the seq2seq model.
S2SGreedySearcher | This class implements the general forward-pass of greedy decoding approach.
S2SRNNBeamSearchLM | This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.
S2SRNNBeamSearchTransformerLM | This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with a Transformer LM.
S2SRNNBeamSearcher | This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
S2SRNNGreedySearcher | This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
S2STransformerBeamSearch | This class implements the beam search decoding for Transformer.
S2SWhisperBeamSearch | This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
S2SWhisperGreedySearch | This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
Functions:
batch_filter_seq2seq_output | Calls filter_seq2seq_output on each of the batch_size predictions.
filter_seq2seq_output | Filter the output until the first eos occurs (exclusive).
inflate_tensor | This function inflates the tensor by a factor of times along dimension dim.
mask_by_condition | This function masks elements of the tensor with fill_value wherever cond is False.
Reference
- class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
Bases:
Module
S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq model.
- Parameters
bos_index (int) – The index of the beginning-of-sequence (bos) token.
eos_index (int) – The index of end-of-sequence token.
min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.
max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.
- Returns
predictions – Outputs as Python list of lists, with “ragged” dimensions; padding has been removed.
scores – The sum of log probabilities (and possibly additional heuristic scores) for each prediction.
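The two ratios bound the number of decoding steps relative to the encoder output length. The following is a minimal sketch of that arithmetic, assuming each bound is the product of the encoder length and the ratio, truncated to an integer. The helper name decode_step_bounds is hypothetical and not part of SpeechBrain.

```python
def decode_step_bounds(enc_len, min_decode_ratio, max_decode_ratio):
    """Return (min_steps, max_steps) for an encoder output of enc_len frames."""
    # Assumption: each bound is the truncated product of length and ratio.
    min_steps = int(enc_len * min_decode_ratio)
    max_steps = int(enc_len * max_decode_ratio)
    return min_steps, max_steps


print(decode_step_bounds(6, 0, 1))        # (0, 6)
print(decode_step_bounds(200, 0.0, 0.5))  # (0, 100)
```

With min_decode_ratio=0 and max_decode_ratio=1, as in the examples on this page, decoding may stop immediately and never runs longer than the encoder sequence.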
- forward(enc_states, wav_len)[source]
This method should implement the forward algorithm of the decoding method.
- Parameters
enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended).
wav_len (torch.Tensor) – The speechbrain-style relative length.
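A relative length expresses each sequence as a fraction of the padded (longest) sequence in the batch. The sketch below shows one plausible way to recover absolute frame counts from it, assuming simple rounding. The helper relative_to_absolute is hypothetical and works on plain Python lists rather than tensors.

```python
def relative_to_absolute(wav_len, num_frames):
    """Convert relative lengths in [0, 1] to absolute frame counts."""
    # Assumption: absolute length = round(relative length * padded frame count).
    return [int(round(r * num_frames)) for r in wav_len]


# Two utterances padded to 6 encoder frames; the second is half as long.
print(relative_to_absolute([1.0, 0.5], 6))  # [6, 3]
```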
- forward_step(inp_tokens, memory, enc_states, enc_lens)[source]
This method should implement one forward step of the autoregressive model.
- Parameters
inp_tokens (torch.Tensor) – The input tensor of the current timestep.
memory (No limit) – The memory variables input for this timestep (e.g., RNN hidden states).
enc_states (torch.Tensor) – The encoder states to be attended.
enc_lens (torch.Tensor) – The actual length of each enc_states sequence.
- Returns
log_probs (torch.Tensor) – Log-probabilities of the current timestep output.
memory (No limit) – The memory variables generated in this timestep (e.g., RNN hidden states).
attn (torch.Tensor) – The attention weights, used to compute penalties during search.
- reset_mem(batch_size, device)[source]
This method should implement the resetting of memory variables for the seq2seq model, e.g., initializing a zero vector as the initial hidden state.
- Parameters
batch_size (int) – The size of the batch.
device (torch.device) – The device to put the initial variables.
- Returns
memory – The initial memory variable.
- Return type
No limit
- lm_forward_step(inp_tokens, memory)[source]
This method should implement one forward step of the language model.
- Parameters
inp_tokens (torch.Tensor) – The input tensor of the current timestep.
memory (No limit) – The memory variables input for this timestep (e.g., RNN hidden states).
- Returns
log_probs (torch.Tensor) – Log-probabilities of the current timestep output.
memory (No limit) – The memory variables generated in this timestep. (e.g., RNN hidden states).
- reset_lm_mem(batch_size, device)[source]
This method should implement the resetting of memory variables in the language model, e.g., initializing a zero vector as the initial hidden state.
- Parameters
batch_size (int) – The size of the batch.
device (torch.device) – The device to put the initial variables.
- Returns
memory – The initial memory variable.
- Return type
No limit
- class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
Bases:
S2SBaseSearcher
This class implements the general forward-pass of greedy decoding approach. See also S2SBaseSearcher().
- forward(enc_states, wav_len)[source]
This method performs a greedy search.
- Parameters
enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended).
wav_len (torch.Tensor) – The speechbrain-style relative length.
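Greedy search picks the single most likely token at every step and stops at eos. The sketch below illustrates the idea for one sequence in pure Python. It is not the SpeechBrain implementation: step_fn is a hypothetical stand-in for forward_step, and the toy model is invented for the demonstration.

```python
import math


def greedy_decode(step_fn, bos_index, eos_index, max_steps):
    """Greedy search for a single sequence.

    step_fn maps (previous token, memory) to (log-probabilities, new memory).
    """
    tokens, score = [], 0.0
    inp, memory = bos_index, None
    for _ in range(max_steps):
        log_probs, memory = step_fn(inp, memory)
        # Keep only the most likely token at this step.
        inp = max(range(len(log_probs)), key=log_probs.__getitem__)
        if inp == eos_index:
            break  # the eos score is not accumulated in this sketch
        tokens.append(inp)
        score += log_probs[inp]
    return tokens, score


def toy_step(prev_token, memory):
    # Toy model over 4 tokens: favours 1, then 2, then eos (index 3).
    step = 0 if memory is None else memory
    favourite = [1, 2, 3][min(step, 2)]
    log_probs = [math.log(0.7) if t == favourite else math.log(0.1)
                 for t in range(4)]
    return log_probs, step + 1


hyp, score = greedy_decode(toy_step, bos_index=0, eos_index=3, max_steps=10)
print(hyp)  # [1, 2]
```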
- class speechbrain.decoders.seq2seq.S2SWhisperGreedySearch(model, language_token=50259, bos_token=50258, task_token=50359, timestamp_token=50363, **kwargs)[source]
Bases:
S2SGreedySearcher
This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
- Parameters
model (HuggingFaceWhisper) – The Whisper model.
**kwargs – see S2SBaseSearcher, arguments are directly passed.
- set_language_token(language_token)[source]
Set the language token to be used for the decoder input.
- set_timestamp_token(timestamp_token)[source]
Set the timestamp token to be used for the decoder input.
- set_decoder_input_tokens(decoder_input_tokens)[source]
decoder_input_tokens are the tokens used as input to the decoder. They are directly taken from the tokenizer.prefix_tokens attribute.
decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
- reset_mem(batch_size, device)[source]
This method sets the first tokens to be decoder_input_tokens during search.
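The prefix layout described above can be sketched with plain lists. The default values are the ones shown in the constructor signature. Both helper names are hypothetical, and the idea that every sequence in the batch starts from the same prefix is an assumption about how reset_mem behaves.

```python
def build_decoder_input_tokens(bos_token=50258, language_token=50259,
                               task_token=50359, timestamp_token=50363):
    # Order follows the documented layout of decoder_input_tokens.
    return [bos_token, language_token, task_token, timestamp_token]


def initial_prefix(batch_size, decoder_input_tokens):
    # Assumption: each sequence in the batch starts from an identical prefix.
    return [list(decoder_input_tokens) for _ in range(batch_size)]


prefix = build_decoder_input_tokens()
print(prefix)                          # [50258, 50259, 50359, 50363]
print(len(initial_prefix(2, prefix)))  # 2
```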
- class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, **kwargs)[source]
Bases:
S2SGreedySearcher
This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher() and S2SGreedySearcher().
- Parameters
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
**kwargs – see S2SBaseSearcher, arguments are directly passed.
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders.seq2seq import S2SRNNGreedySearcher
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> searcher = S2SRNNGreedySearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=4,
...     eos_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)
- reset_mem(batch_size, device)[source]
When doing greedy search, keep hidden state (hs) and context vector (c) as memory.
- class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, topk=1, return_log_probs=False, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, length_rewarding=0, coverage_penalty=0.0, lm_weight=0.0, lm_modules=None, ctc_weight=0.0, blank_index=0, ctc_score_mode='full', ctc_window_size=0, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]
Bases:
S2SBaseSearcher
This class implements the beam-search algorithm for the seq2seq model. See also S2SBaseSearcher().
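For orientation, the core idea of beam search can be sketched in a few lines of pure Python: at each step every alive hypothesis is expanded with every token, and only the beam_size best expansions survive. This is an illustrative sketch, not the SpeechBrain implementation. It omits length normalization, LM and CTC scoring, and the eos threshold, and step_fn is a hypothetical scorer.

```python
import math


def beam_search(step_fn, bos_index, eos_index, beam_size, max_steps):
    """Return the best (tokens, score) found by a simple beam search.

    step_fn(prefix) returns next-token log-probabilities for a prefix.
    """
    beams = [([bos_index], 0.0)]  # alive hypotheses: (tokens, summed log-prob)
    finished = []
    for _ in range(max_steps):
        candidates = []
        for tokens, score in beams:
            for tok, lp in enumerate(step_fn(tokens)):
                candidates.append((tokens + [tok], score + lp))
        # Keep only the beam_size best expansions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == eos_index:
                finished.append((tokens[1:-1], score))  # strip bos and eos
            else:
                beams.append((tokens, score))
        if not beams:
            break
    # Fall back to unfinished beams if nothing reached eos.
    pool = finished or [(t[1:], s) for t, s in beams]
    return max(pool, key=lambda c: c[1])


def toy_step(prefix):
    # Toy vocabulary: 0 = bos, 1 = a regular token, 2 = eos.
    probs = [0.1, 0.6, 0.3] if len(prefix) == 1 else [0.1, 0.1, 0.8]
    return [math.log(p) for p in probs]


best_tokens, best_score = beam_search(toy_step, 0, 2, beam_size=2, max_steps=5)
print(best_tokens)  # [1]
```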
- Parameters
bos_index (int) – The index of beginning-of-sequence token.
eos_index (int) – The index of end-of-sequence token.
min_decode_ratio (float) – The ratio of minimum decoding steps to length of encoder states.
max_decode_ratio (float) – The ratio of maximum decoding steps to length of encoder states.
beam_size (int) – The width of beam.
topk (int) – The number of hypotheses to return. (default: 1)
return_log_probs (bool) – Whether to return log-probabilities. (default: False)
using_eos_threshold (bool) – Whether to use eos threshold. (default: True)
eos_threshold (float) – The threshold coefficient for eos token (default: 1.5). See 3.1.2 in reference: https://arxiv.org/abs/1904.02619
length_normalization (bool) – Whether to divide the scores by the length. (default: True)
length_rewarding (float) – The coefficient of length rewarding (γ). log P(y|x) + λ log P_LM(y) + γ*len(y). (default: 0.0)
coverage_penalty (float) – The coefficient of coverage penalty (η). log P(y|x) + λ log P_LM(y) + γ*len(y) + η*coverage(x,y). (default: 0.0) Reference: https://arxiv.org/pdf/1612.02695.pdf, https://arxiv.org/pdf/1808.10792.pdf
lm_weight (float) – The weight of LM when performing beam search (λ). log P(y|x) + λ log P_LM(y). (default: 0.0)
ctc_weight (float) – The weight of CTC probabilities when performing beam search (λ). (1-λ) log P(y|x) + λ log P_CTC(y|x). (default: 0.0)
blank_index (int) – The index of the blank token.
ctc_score_mode (str) – CTC prefix scoring on a “partial” token or the “full” token. (default: “full”)
ctc_window_size (int) – Compute the CTC scores over the time frames using windowing based on attention peaks. If 0, no windowing is applied. (default: 0)
using_max_attn_shift (bool) – Whether using the max_attn_shift constraint. (default: False)
max_attn_shift (int) – Beam search will block beams whose attention shifts by more than max_attn_shift. (default: 60) Reference: https://arxiv.org/abs/1904.02619
minus_inf (float) – The value of minus infinity used to block some paths of the search. (default: -1e20)
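The scoring formulas quoted in the parameter descriptions can be written out directly. The sketch below mirrors them term by term. The function names are hypothetical, the inputs are plain floats rather than tensors, and how SpeechBrain combines these terms internally (for example, the order in which CTC interpolation and length normalization are applied) is not specified here.

```python
def joint_score(log_p_s2s, log_p_lm=0.0, length=0, coverage=0.0,
                lm_weight=0.0, length_rewarding=0.0, coverage_penalty=0.0):
    """log P(y|x) + lambda * log P_LM(y) + gamma * len(y) + eta * coverage(x, y)."""
    return (log_p_s2s
            + lm_weight * log_p_lm
            + length_rewarding * length
            + coverage_penalty * coverage)


def ctc_interpolate(log_p_s2s, log_p_ctc, ctc_weight):
    """(1 - lambda) * log P(y|x) + lambda * log P_CTC(y|x)."""
    return (1.0 - ctc_weight) * log_p_s2s + ctc_weight * log_p_ctc


def length_normalize(score, length):
    # Assumption: normalization is a plain division by the hypothesis length.
    return score / length


s = joint_score(-4.0, log_p_lm=-6.0, length=5,
                lm_weight=0.5, length_rewarding=0.1)
print(s)                                  # -6.5
print(ctc_interpolate(-4.0, -8.0, 0.25))  # -5.0
```

With all weights left at their defaults of 0.0, joint_score reduces to the plain seq2seq log-probability.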
- permute_mem(memory, index)[source]
This method permutes the seq2seq model memory to synchronize the memory index with the current output.
- Parameters
memory (No limit) – The memory variable to be permuted.
index (torch.Tensor) – The index of the previous path.
- Returns
The permuted memory variable.
- permute_lm_mem(memory, index)[source]
This method permutes the language model memory to synchronize the memory index with the current output.
- Parameters
memory (No limit) – The memory variable to be permuted.
index (torch.Tensor) – The index of the previous path.
- Returns
The permuted memory variable.
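After each beam-search step the surviving hypotheses may descend from different parents, so per-hypothesis memory has to be reordered to match. The sketch below shows the idea on a plain list. The helper name is hypothetical, and with tensors the same reordering is typically done along the batch dimension, for instance with torch.index_select.

```python
def permute_memory(memory, index):
    """Reorder per-hypothesis memory so entry i follows its parent hypothesis."""
    # index[i] is the position, in the previous step, of the hypothesis
    # that the i-th surviving hypothesis was expanded from.
    return [memory[i] for i in index]


hidden = ["h0", "h1", "h2"]
# Hypotheses 0 and 1 both continue old hypothesis 2; hypothesis 2 continues old 0.
print(permute_memory(hidden, [2, 2, 0]))  # ['h2', 'h2', 'h0']
```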
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, ctc_linear=None, temperature=1.0, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher(), S2SBeamSearcher().
- Parameters
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
temperature (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.
**kwargs – see S2SBeamSearcher, arguments are directly passed.
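The effect of the temperature factor can be seen in a small pure-Python softmax. This is a sketch assuming the temperature divides the logits before the softmax. The function name is hypothetical and not part of SpeechBrain.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    # Dividing logits by T > 1 flattens the distribution; T < 1 sharpens it.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
soft = softmax_with_temperature(logits, temperature=2.0)
sharp = softmax_with_temperature(logits, temperature=0.5)
print(max(soft) < max(sharp))  # True
```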
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders.seq2seq import S2SRNNBeamSearcher
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> ctc_lin = sb.nnet.linear.Linear(n_neurons=5, input_size=7)
>>> searcher = S2SRNNBeamSearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     ctc_linear=ctc_lin,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearchLM(embedding, decoder, linear, language_model, temperature_lm=1.0, **kwargs)[source]
Bases:
S2SRNNBeamSearcher
This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM. See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
- Parameters
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
language_model (torch.nn.Module) – A language model.
temperature_lm (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.
**kwargs – Arguments to pass to S2SBeamSearcher.
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders.seq2seq import S2SRNNBeamSearchLM
>>> from speechbrain.lobes.models.RNNLM import RNNLM
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> lm = RNNLM(output_neurons=5, return_hidden=True)
>>> searcher = S2SRNNBeamSearchLM(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     language_model=lm,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     lm_weight=0.5,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearchTransformerLM(embedding, decoder, linear, language_model, temperature_lm=1.0, **kwargs)[source]
Bases:
S2SRNNBeamSearcher
This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with a Transformer LM. See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
- Parameters
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
language_model (torch.nn.Module) – A language model.
temperature_lm (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.
**kwargs – Arguments to pass to S2SBeamSearcher.
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders.seq2seq import S2SRNNBeamSearchTransformerLM
>>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> lm = TransformerLM(5, 512, 8, 1, 0, 1024, activation=torch.nn.GELU)
>>> searcher = S2SRNNBeamSearchTransformerLM(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     language_model=lm,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     lm_weight=0.5,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)
- speechbrain.decoders.seq2seq.inflate_tensor(tensor, times, dim)[source]
This function inflates the tensor by a factor of times along dimension dim.
- Parameters
tensor (torch.Tensor) – The tensor to be inflated.
times (int) – The number of times to inflate the tensor.
dim (int) – The dimension along which to inflate.
- Returns
The inflated tensor.
- Return type
torch.Tensor
Example
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> new_tensor = inflate_tensor(tensor, 2, dim=0)
>>> new_tensor
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.]])
- speechbrain.decoders.seq2seq.mask_by_condition(tensor, cond, fill_value)[source]
This function masks elements of the tensor with fill_value wherever cond is False.
- Parameters
tensor (torch.Tensor) – The tensor to be masked.
cond (torch.BoolTensor) – This tensor has to be the same size as tensor. Each element represents whether to keep the value in tensor.
fill_value (float) – The value to fill in the masked element.
- Returns
The masked tensor.
- Return type
torch.Tensor
Example
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
>>> mask_by_condition(tensor, cond, 0)
tensor([[1., 2., 0.],
        [4., 0., 0.]])
- class speechbrain.decoders.seq2seq.S2STransformerBeamSearch(modules, temperature=1.0, temperature_lm=1.0, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for Transformer. See also S2SBaseSearcher(), S2SBeamSearcher().
- Parameters
model (torch.nn.Module) – The model to use for decoding.
linear (torch.nn.Module) – A linear output layer.
**kwargs – Arguments to pass to S2SBeamSearcher
Example
>>> # see recipes/LibriSpeech/ASR_transformer/experiment.py
- class speechbrain.decoders.seq2seq.S2SWhisperBeamSearch(module, temperature=1.0, temperature_lm=1.0, language_token=50259, bos_token=50258, task_token=50359, timestamp_token=50363, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
- Parameters
module (list) – A list containing the following modules:
model (torch.nn.Module) – A Whisper model. It should have a decode() method.
ctc_lin (torch.nn.Module, optional) – A linear output layer for CTC.
**kwargs – Arguments to pass to S2SBeamSearcher
- set_decoder_input_tokens(decoder_input_tokens)[source]
decoder_input_tokens are the tokens used as input to the decoder. They are directly taken from the tokenizer.prefix_tokens attribute.
decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
- reset_mem(batch_size, device)[source]
This method sets the first tokens to be decoder_input_tokens during search.
- speechbrain.decoders.seq2seq.batch_filter_seq2seq_output(prediction, eos_id=-1)[source]
Calls filter_seq2seq_output on each of the batch_size predictions.
- Parameters
prediction (list of torch.Tensor) – A list containing the output ints predicted by the seq2seq system.
eos_id (int, string) – The id of the eos.
- Returns
The output predicted by seq2seq model.
- Return type
list
Example
>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
>>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
>>> predictions
[[1, 2, 3], [2, 3]]
- speechbrain.decoders.seq2seq.filter_seq2seq_output(string_pred, eos_id=-1)[source]
Filter the output until the first eos occurs (exclusive).
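- Parameters
string_pred (list) – A list containing the output strings or ints predicted by the seq2seq system.
eos_id (int, string) – The id of the eos.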
- Returns
The output predicted by seq2seq model.
- Return type
list
Example
>>> string_pred = ['a','b','c','d','eos','e']
>>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
>>> string_out
['a', 'b', 'c', 'd']
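The filtering behaviour shown in the two examples above can be sketched in pure Python. This is an illustrative reimplementation on plain lists, not the library code. The function names are hypothetical, and returning the whole sequence when no eos is present is an assumption.

```python
def filter_until_eos(sequence, eos_id=-1):
    """Return the items before the first eos_id (exclusive)."""
    sequence = list(sequence)
    try:
        return sequence[:sequence.index(eos_id)]
    except ValueError:
        # Assumption: without an eos the sequence is returned unchanged.
        return sequence


def batch_filter_until_eos(batch, eos_id=-1):
    # Apply the single-sequence filter to every prediction in the batch.
    return [filter_until_eos(seq, eos_id) for seq in batch]


print(filter_until_eos(['a', 'b', 'c', 'd', 'eos', 'e'], eos_id='eos'))
# ['a', 'b', 'c', 'd']
print(batch_filter_until_eos([[1, 2, 3, 4], [2, 3, 4, 5, 6]], eos_id=4))
# [[1, 2, 3], [2, 3]]
```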