speechbrain.decoders.seq2seq module
Decoding methods for seq2seq autoregressive models.
- Authors
Adel Moumen 2022, 2023, 2024
Ju-Chieh Chou 2020
Peter Plantinga 2020
Mirco Ravanelli 2020
Sung-Lin Yeh 2020
Summary
Classes:
AlivedHypotheses | This class handles the data for the hypotheses during decoding.
S2SBaseSearcher | S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq models.
S2SBeamSearcher | This class implements the beam-search algorithm for the seq2seq model.
S2SGreedySearcher | This class implements the general forward pass of the greedy decoding approach.
S2SHFTextBasedBeamSearcher | This class implements the beam search decoding for the text-based HF seq2seq models, such as mBART or NLLB.
S2SRNNBeamSearcher | This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
S2SRNNGreedySearcher | This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
S2STransformerBeamSearcher | This class implements the beam search decoding for Transformer.
S2STransformerGreedySearcher | This class implements the greedy decoding for Transformer.
S2SWhisperBeamSearcher | This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
S2SWhisperGreedySearcher | This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
Reference
- class speechbrain.decoders.seq2seq.AlivedHypotheses(alived_seq, alived_log_probs, sequence_scores)[source]
Bases:
Module
This class handles the data for the hypotheses during decoding.
- Parameters:
alived_seq (torch.Tensor) – The sequence of tokens for each hypothesis.
alived_log_probs (torch.Tensor) – The log probabilities of each token for each hypothesis.
sequence_scores (torch.Tensor) – The sum of log probabilities for each hypothesis.
- class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
Bases:
Module
S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq model.
- Parameters:
bos_index (int) – The index of the beginning-of-sequence (bos) token.
eos_index (int) – The index of end-of-sequence (eos) token.
min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.
max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.
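As a rough illustration of how the two ratios bound the decoding loop, the sketch below converts them into absolute step counts for a given number of encoder frames. The helper name `decode_step_bounds` is hypothetical and the truncation to an integer is an assumption; it is not taken from the SpeechBrain source.

```python
from typing import Tuple


def decode_step_bounds(
    enc_len: int, min_decode_ratio: float, max_decode_ratio: float
) -> Tuple[int, int]:
    """Convert decode ratios into (min_steps, max_steps).

    Hypothetical helper for illustration only: it assumes the bounds are
    the ratio times the encoder length, truncated to an integer.
    """
    if min_decode_ratio > max_decode_ratio:
        raise ValueError("min_decode_ratio must not exceed max_decode_ratio")
    min_steps = int(enc_len * min_decode_ratio)
    max_steps = int(enc_len * max_decode_ratio)
    return min_steps, max_steps


# With 6 encoder frames and ratios 0 and 1, decoding may stop at once
# and runs for at most 6 steps.
print(decode_step_bounds(enc_len=6, min_decode_ratio=0.0, max_decode_ratio=1.0))
```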
- forward(enc_states, wav_len)[source]
This method should implement the forward algorithm of the decoding method.
- Parameters:
enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended).
wav_len (torch.Tensor) – The speechbrain-style relative length.
- Returns:
hyps – The predicted tokens, as a list of lists or, if return_topk is True, a Tensor of shape (batch, topk, max length of token_id sequences).
top_lengths – The length of each topk sequence in the batch.
top_scores – The final scores of the topk hypotheses.
top_log_probs – The log probabilities of each hypothesis.
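The `wav_len` argument uses relative lengths: each entry is the fraction of the padded time axis that holds real data, so the longest item in a batch has length 1.0. A minimal sketch of turning these fractions into absolute frame counts follows. The function name `relative_to_absolute` is hypothetical and the rounding rule is an assumption.

```python
from typing import List, Sequence


def relative_to_absolute(wav_len: Sequence[float], max_frames: int) -> List[int]:
    """Map relative lengths in [0, 1] to absolute frame counts.

    Illustrative only: rounding to the nearest integer is an assumption.
    """
    return [int(round(rel * max_frames)) for rel in wav_len]


# A batch padded to 8 encoder frames where the second item is 75% as long.
print(relative_to_absolute([1.0, 0.75], max_frames=8))
```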
- forward_step(inp_tokens, memory, enc_states, enc_lens)[source]
This method should implement one forward step of the autoregressive model.
- Parameters:
inp_tokens (torch.Tensor) – The input tensor of the current step.
memory (No limit) – The memory variables input for this step. (ex. RNN hidden states).
enc_states (torch.Tensor) – The encoder states to be attended.
enc_lens (torch.Tensor) – The actual length of each enc_states sequence.
- Returns:
log_probs (torch.Tensor) – Log-probabilities of the current step output.
memory (No limit) – The memory variables generated in this step. (ex. RNN hidden states).
attn (torch.Tensor) – The attention weight for doing penalty.
- reset_mem(batch_size, device)[source]
This method should implement the resetting of memory variables for the seq2seq model. E.g., initializing zero vector as initial hidden states.
- Parameters:
batch_size (int) – The size of the batch.
device (torch.device) – The device to put the initial variables.
- Returns:
memory – The initial memory variable.
- Return type:
No limit
- class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
Bases:
S2SBaseSearcher
This class implements the general forward pass of the greedy decoding approach. See also S2SBaseSearcher().
- forward(enc_states, wav_len)[source]
This method performs a greedy search.
- Parameters:
enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended).
wav_len (torch.Tensor) – The speechbrain-style relative length.
- Returns:
hyps (List[List[int]]) – List containing the hypotheses.
top_lengths (torch.Tensor (batch)) – This tensor contains the length of each hypothesis.
top_scores (torch.Tensor (batch)) – The score of each hypothesis.
top_log_probs (torch.Tensor (batch, max length of token_id sequences)) – The log probabilities of each hypothesis.
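To make the idea of greedy search concrete, here is a toy, framework-free sketch for a single utterance: at every step the most likely token is picked and fed back in, until the eos token wins or the step budget runs out. The `step_fn` callback and the lookup table are hypothetical stand-ins for a real `forward_step`; batching, memory handling and the minimum-length constraint are deliberately left out.

```python
from typing import Callable, List, Sequence, Tuple


def greedy_decode(
    step_fn: Callable[[int], Sequence[float]],
    bos_index: int,
    eos_index: int,
    max_steps: int,
) -> Tuple[List[int], List[float]]:
    """Greedy decoding for one sequence (illustrative sketch only)."""
    tokens: List[int] = []
    log_probs: List[float] = []
    prev = bos_index
    for _ in range(max_steps):
        step_log_probs = step_fn(prev)
        # Pick the single most likely next token.
        best = max(range(len(step_log_probs)), key=step_log_probs.__getitem__)
        if best == eos_index:
            break
        tokens.append(best)
        log_probs.append(step_log_probs[best])
        prev = best
    return tokens, log_probs


# Hypothetical model: next-token log-probabilities depend only on the
# previous token. Index 0 is bos and index 1 is eos.
TABLE = {
    0: [-5.0, -5.0, -0.1, -3.0],
    2: [-5.0, -5.0, -4.0, -0.2],
    3: [-5.0, -0.05, -3.0, -3.0],
}

hyp, hyp_log_probs = greedy_decode(TABLE.__getitem__, bos_index=0, eos_index=1, max_steps=6)
print(hyp, sum(hyp_log_probs))
```

The sum of the per-token log-probabilities plays the role of the hypothesis score in this sketch.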
- class speechbrain.decoders.seq2seq.S2STransformerGreedySearcher(modules, temperature=0.0, **kwargs)[source]
Bases:
S2SGreedySearcher
This class implements the greedy decoding for Transformer.
- Parameters:
modules (list) – A list containing the following modules, in order:
- model (torch.nn.Module) – A TransformerASR model.
- seq_lin (torch.nn.Module) – A linear output layer for the seq2seq model.
temperature (float) – Temperature to use during decoding.
**kwargs – Arguments to pass to S2SGreedySearcher
- class speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
Bases:
S2SGreedySearcher
This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
- Parameters:
model (HuggingFaceWhisper) – The Whisper model.
temperature (float) – The temperature to use during decoding.
use_kv_cache (bool (default: True)) – Whether to use key-value cache.
suppress_blank (bool (default: True)) – This will suppress blank outputs.
suppress_tokens (str or list (default: "-1")) – List of token ids (or comma-separated token ids) to suppress. "-1" will suppress a set of symbols as defined in model.non_speech_tokens().
sample_len (int (default: None)) – Maximum number of tokens to sample.
prefix (str or list (default: None)) – Prefix to add to the input tokens. See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt (str or list (default: None)) – Prompt to add to the input tokens. See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – see S2SBaseSearcher, arguments are directly passed.
- property get_tokens_to_suppress
Get the tokens to suppress during decoding if self.config.suppress_tokens is None.
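The `suppress_tokens` argument accepts either a list of ids or a comma-separated string, with -1 acting as a placeholder for the model's non-speech tokens. The sketch below shows one plausible way to resolve such a value into a concrete set of ids. The function `resolve_suppress_tokens` and the example ids are hypothetical, and the real searcher may handle further special tokens.

```python
from typing import Iterable, Sequence, Tuple, Union


def resolve_suppress_tokens(
    suppress_tokens: Union[str, Sequence[int]],
    non_speech_tokens: Iterable[int],
) -> Tuple[int, ...]:
    """Resolve a suppress_tokens setting into sorted, unique token ids.

    Illustrative sketch: -1 is replaced by `non_speech_tokens`.
    """
    if isinstance(suppress_tokens, str):
        ids = [int(tok) for tok in suppress_tokens.split(",") if tok.strip()]
    else:
        ids = list(suppress_tokens)
    if -1 in ids:
        # Drop the placeholder and add the non-speech symbols instead.
        ids = [tok for tok in ids if tok >= 0]
        ids.extend(non_speech_tokens)
    return tuple(sorted(set(ids)))


# Hypothetical ids standing in for the output of model.non_speech_tokens().
print(resolve_suppress_tokens("-1", non_speech_tokens=[7, 3]))
print(resolve_suppress_tokens("5, 9", non_speech_tokens=[7, 3]))
```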
- class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, temperature=0.0, **kwargs)[source]
Bases:
S2SGreedySearcher
This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher() and S2SGreedySearcher().
- Parameters:
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
temperature (float) – The temperature to use during decoding.
**kwargs – see S2SBaseSearcher, arguments are directly passed.
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders import S2SRNNGreedySearcher
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> searcher = S2SRNNGreedySearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=0,
...     eos_index=1,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, scorer=None, return_topk=False, topk=1, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]
Bases:
S2SBaseSearcher
This class implements the beam-search algorithm for the seq2seq model. See also S2SBaseSearcher().
- Parameters:
bos_index (int) – The index of beginning-of-sequence token.
eos_index (int) – The index of end-of-sequence token.
min_decode_ratio (float) – The ratio of minimum decoding steps to length of encoder states.
max_decode_ratio (float) – The ratio of maximum decoding steps to length of encoder states.
beam_size (int) – The width of beam.
scorer (speechbrain.decoders.scorer.ScorerBuilder) – Scorer instance. Default: None.
return_topk (bool) – Whether to return topk hypotheses. The topk hypotheses will be padded to the same length. Default: False.
topk (int) – If return_topk is True, then return topk hypotheses. Default: 1.
using_eos_threshold (bool) – Whether to use eos threshold. Default: True.
eos_threshold (float) – The threshold coefficient for eos token. Default: 1.5. See 3.1.2 in reference: https://arxiv.org/abs/1904.02619
length_normalization (bool) – Whether to divide the scores by the length. Default: True.
using_max_attn_shift (bool) – Whether using the max_attn_shift constraint. Default: False.
max_attn_shift (int) – Beam search will block the beams whose attention peak shifts by more than max_attn_shift. Default: 60. Reference: https://arxiv.org/abs/1904.02619
minus_inf (float) – The value of minus infinity to block some path of the search. Default: -1e20.
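The eos threshold can be read as a guard against ending a hypothesis too early: eos is only accepted when its log-probability is not too far below that of the best competing token. The sketch below follows that reading of section 3.1.2 of the referenced paper. The function name `eos_allowed` is hypothetical, and the exact comparison used inside the searcher (for instance whether eos itself is included in the maximum) is an assumption.

```python
from typing import Sequence


def eos_allowed(
    log_probs: Sequence[float], eos_index: int, eos_threshold: float = 1.5
) -> bool:
    """Return True if eos may be emitted at this step (illustrative sketch).

    Log-probabilities are negative, so multiplying the best competing
    log-probability by a threshold > 1 gives a lower bar that eos must clear.
    """
    eos_log_prob = log_probs[eos_index]
    best_other = max(lp for i, lp in enumerate(log_probs) if i != eos_index)
    return eos_log_prob > eos_threshold * best_other


# eos (index 1) is far less likely than the best token: rejected.
print(eos_allowed([-0.2, -2.5, -3.0], eos_index=1))
# eos is close to the best token: accepted.
print(eos_allowed([-1.0, -1.2, -3.0], eos_index=1))
```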
- init_hypotheses()[source]
This method initializes the AlivedHypotheses object.
- Returns:
The alived hypotheses filled with the initial values.
- Return type:
AlivedHypotheses
- init_beam_search_data(enc_states, wav_len)[source]
Initialize the beam search data.
- Parameters:
enc_states (torch.Tensor) – The encoder states to be attended.
wav_len (torch.Tensor) – The actual length of each enc_states sequence.
- Returns:
alived_hyps (AlivedHypotheses) – The alived hypotheses.
inp_tokens (torch.Tensor) – The input tensor of the current step.
log_probs (torch.Tensor) – The log-probabilities of the current step output.
eos_hyps_and_log_probs_scores (list) – Generated hypotheses (the ones that have reached eos) and log probs scores.
memory (No limit) – The memory variables generated in this step.
scorer_memory (No limit) – The memory variables generated in this step.
attn (torch.Tensor) – The attention weight.
prev_attn_peak (torch.Tensor) – The previous attention peak place.
enc_states (torch.Tensor) – The encoder states to be attended.
enc_lens (torch.Tensor) – The actual length of each enc_states sequence.
- search_step(alived_hyps, inp_tokens, log_probs, eos_hyps_and_log_probs_scores, memory, scorer_memory, attn, prev_attn_peak, enc_states, enc_lens, step)[source]
A search step for the next most likely tokens.
- Parameters:
alived_hyps (AlivedHypotheses) – The alived hypotheses.
inp_tokens (torch.Tensor) – The input tensor of the current step.
log_probs (torch.Tensor) – The log-probabilities of the current step output.
eos_hyps_and_log_probs_scores (list) – Generated hypotheses (the ones that have reached eos) and log probs scores.
memory (No limit) – The memory variables input for this step. (ex. RNN hidden states).
scorer_memory (No limit) – The memory variables input for this step. (ex. RNN hidden states).
attn (torch.Tensor) – The attention weight.
prev_attn_peak (torch.Tensor) – The previous attention peak place.
enc_states (torch.Tensor) – The encoder states to be attended.
enc_lens (torch.Tensor) – The actual length of each enc_states sequence.
step (int) – The current decoding step.
- Returns:
alived_hyps (AlivedHypotheses) – The alived hypotheses.
inp_tokens (torch.Tensor) – The input tensor of the current step.
log_probs (torch.Tensor) – The log-probabilities of the current step output.
eos_hyps_and_log_probs_scores (list) – Generated hypotheses (the ones that have reached eos) and log probs scores.
memory (No limit) – The memory variables generated in this step.
scorer_memory (No limit) – The memory variables generated in this step.
attn (torch.Tensor) – The attention weight.
prev_attn_peak (torch.Tensor) – The previous attention peak place.
scores (torch.Tensor) – The scores of the current step output.
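Conceptually, a search step extends every alive hypothesis with every candidate token, adds the new log-probability to the running score, and keeps only the best `beam_size` extensions. The toy sketch below shows that core idea for a single utterance using plain lists. The function `beam_step` is hypothetical and omits eos handling, scorers, attention constraints and batching.

```python
import heapq
from typing import List, Sequence, Tuple

Beam = Tuple[List[int], float]  # (token sequence, cumulative log-probability)


def beam_step(
    beams: Sequence[Beam],
    step_log_probs: Sequence[Sequence[float]],
    beam_size: int,
) -> List[Beam]:
    """Expand each beam by one token and keep the best `beam_size` results."""
    candidates = []
    for (tokens, score), log_probs in zip(beams, step_log_probs):
        for token, log_prob in enumerate(log_probs):
            candidates.append((score + log_prob, tokens + [token]))
    best = heapq.nlargest(beam_size, candidates, key=lambda cand: cand[0])
    return [(tokens, score) for score, tokens in best]


# Two alive beams and a two-token vocabulary (values are made up).
beams = [([0], -0.5), ([1], -1.0)]
step_log_probs = [[-0.1, -2.0], [-0.3, -0.4]]
new_beams = beam_step(beams, step_log_probs, beam_size=2)
print([tokens for tokens, _ in new_beams])
```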
- forward(enc_states, wav_len)[source]
Applies beam search and returns the predicted tokens.
- Parameters:
enc_states (torch.Tensor) – The encoder states to be attended.
wav_len (torch.Tensor) – The actual length of each enc_states sequence.
- Returns:
hyps (list) – The predicted tokens.
best_lens (torch.Tensor) – The length of each predicted token sequence.
best_scores (torch.Tensor) – The score of each predicted token sequence.
best_log_probs (torch.Tensor) – The log probabilities of each predicted token sequence.
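Because every extra token adds a negative log-probability, raw summed scores favour short hypotheses. The sketch below shows how dividing by the length, as the `length_normalization` option describes, can change which hypothesis is ranked first. The function `rank_hypotheses` and the numbers are hypothetical, and the exact divisor used by the searcher is an assumption.

```python
from typing import List, Sequence, Tuple

Hyp = Tuple[List[int], float]  # (token sequence, summed log-probability)


def rank_hypotheses(hyps: Sequence[Hyp], length_normalization: bool = True) -> List[Hyp]:
    """Sort finished hypotheses from best to worst (illustrative sketch)."""

    def final_score(hyp: Hyp) -> float:
        tokens, score = hyp
        if length_normalization and tokens:
            return score / len(tokens)
        return score

    return sorted(hyps, key=final_score, reverse=True)


hyps = [([3], -1.0), ([3, 4, 5, 6], -2.0)]
# Raw scores prefer the one-token hypothesis (-1.0 > -2.0).
print(rank_hypotheses(hyps, length_normalization=False)[0][0])
# Per-token scores prefer the longer one (-0.5 > -1.0).
print(rank_hypotheses(hyps, length_normalization=True)[0][0])
```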
- permute_mem(memory, index)[source]
This method permutes the seq2seq model memory to synchronize the memory index with the current output.
- Parameters:
memory (No limit) – The memory variable to be permuted.
index (torch.Tensor) – The index of the previous path.
- Returns:
The permuted memory variable.
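After a search step, the surviving beams may all descend from the same predecessor or from predecessors in a different order, so any per-beam state has to be reordered to match. The sketch below illustrates that reordering with plain lists; `permute_memory` is a hypothetical helper. With tensor-valued memory, the same effect is typically obtained with `torch.index_select(memory, dim=0, index=index)`.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def permute_memory(memory: Sequence[T], index: Sequence[int]) -> List[T]:
    """Reorder per-beam state so entry i follows the predecessor index[i]."""
    return [memory[i] for i in index]


# Three beams; the new beams 0 and 1 both continue old beam 2,
# and the new beam 2 continues old beam 0.
memory = ["h0", "h1", "h2"]
print(permute_memory(memory, index=[2, 2, 0]))
```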
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, temperature=1.0, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher(), S2SBeamSearcher().
- Parameters:
embedding (torch.nn.Module) – An embedding layer.
decoder (torch.nn.Module) – Attentional RNN decoder.
linear (torch.nn.Module) – A linear output layer.
temperature (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.
**kwargs – see S2SBeamSearcher, arguments are directly passed.
Example
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders import S2SRNNBeamSearcher
>>> vocab_size = 5
>>> emb = torch.nn.Embedding(vocab_size, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3)
>>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size)
>>> scorer = sb.decoders.scorer.ScorerBuilder(
...     full_scorers=[coverage_scorer],
...     partial_scorers=[],
...     weights=dict(coverage=1.5),
... )
>>> searcher = S2SRNNBeamSearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=4,
...     eos_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     scorer=scorer,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> hyps, _, _, _ = searcher(enc, wav_len)
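The effect of the `temperature` parameter described above can be checked with a few lines of plain Python: dividing the logits by T before the softmax flattens the distribution when T>1 and sharpens it when T<1. This is a generic sketch of temperature scaling, not the searcher's own code, and the function name is hypothetical.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax over logits divided by a positive temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    shift = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```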
- class speechbrain.decoders.seq2seq.S2STransformerBeamSearcher(modules, temperature=1.0, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for Transformer. See also S2SBaseSearcher(), S2SBeamSearcher().
- Parameters:
modules (list) – A list containing the following modules, in order:
- model (torch.nn.Module) – A Transformer model.
- seq_lin (torch.nn.Module) – A linear output layer.
temperature (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.
**kwargs – Arguments to pass to S2SBeamSearcher
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.decoders import S2STransformerBeamSearcher
>>> batch_size = 8
>>> n_channels = 6
>>> input_size = 40
>>> d_model = 128
>>> tgt_vocab = 140
>>> src = torch.rand([batch_size, n_channels, input_size])
>>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels])
>>> net = TransformerASR(
...     tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> searcher = S2STransformerBeamSearcher(
...     modules=[net, lin],
...     bos_index=1,
...     eos_index=2,
...     min_decode_ratio=0.0,
...     max_decode_ratio=1.0,
...     using_eos_threshold=False,
...     beam_size=7,
...     temperature=1.15,
... )
>>> enc, dec = net.forward(src, tgt)
>>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size))
- class speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
Bases:
S2SBeamSearcher
This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.
The beam search is stateful, meaning that some variables are stored in the searcher. If you want to reuse the searcher in different contexts, you should make sure that the variables are updated accordingly.
- Parameters:
module (list) – A list containing the following module:
- model (torch.nn.Module) – A Whisper model. It should have a decode() method.
temperature (float) – The temperature to use during decoding.
use_kv_cache (bool (default: True)) – Whether to use key-value cache.
suppress_blank (bool (default: True)) – This will suppress blank outputs.
suppress_tokens (str or list (default: "-1")) – List of token ids (or comma-separated token ids) to suppress. "-1" will suppress a set of symbols as defined in model.non_speech_tokens().
sample_len (int (default: None)) – Maximum number of tokens to sample.
prefix (str or list (default: None)) – Prefix to add to the input tokens. See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt (str or list (default: None)) – Prompt to add to the input tokens. See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – see S2SBeamSearcher, arguments are directly passed.
- property get_tokens_to_suppress
Get the tokens to suppress during decoding if self.config.suppress_tokens is None.
- class speechbrain.decoders.seq2seq.S2SHFTextBasedBeamSearcher(modules, vocab_size, **kwargs)[source]
Bases:
S2STransformerBeamSearcher
This class implements the beam search decoding for the text-based HF seq2seq models, such as mBART or NLLB. It is NOT significantly different from S2STransformerBeamSearcher, which is why it inherits from it. The main difference arises when one wishes to use the lm_head of the text-based HF model directly rather than creating a new projection layer (self.fc = None).
- Parameters:
modules (list) – A list containing the following modules, in order:
- model (torch.nn.Module) – A Transformer model.
- seq_lin (torch.nn.Module) – A linear output layer. Normally set to None for this use case.
vocab_size (int) – The dimension of the lm_head.
**kwargs – Arguments to pass to S2SBeamSearcher