speechbrain.decoders.seq2seq module

Decoding methods for autoregressive seq2seq models.

Authors
  • Adel Moumen 2022, 2023

  • Ju-Chieh Chou 2020

  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Sung-Lin Yeh 2020

Summary

Classes:

AlivedHypotheses

This class handles the data for the hypotheses during decoding.

S2SBaseSearcher

S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq models.

S2SBeamSearcher

This class implements the beam-search algorithm for the seq2seq model.

S2SGreedySearcher

This class implements the general forward pass of the greedy decoding approach.

S2SHFTextBasedBeamSearcher

This class implements the beam search decoding for the text-based HF seq2seq models, such as mBART or NLLB.

S2SRNNBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).

S2SRNNGreedySearcher

This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).

S2STransformerBeamSearcher

This class implements the beam search decoding for Transformer.

S2STransformerGreedySearch

This class implements the greedy decoding for Transformer.

S2SWhisperBeamSearch

This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.

S2SWhisperGreedySearch

This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.

Reference

class speechbrain.decoders.seq2seq.AlivedHypotheses(alived_seq, alived_log_probs, sequence_scores)[source]

Bases: Module

This class handles the data for the hypotheses during decoding.

Parameters:
  • alived_seq (torch.Tensor) – The sequence of tokens for each hypothesis.

  • alived_log_probs (torch.Tensor) – The log probabilities of each token for each hypothesis.

  • sequence_scores (torch.Tensor) – The sum of log probabilities for each hypothesis.

training: bool
class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]

Bases: Module

S2SBaseSearcher class to be inherited by other decoding approaches for seq2seq models.

Parameters:
  • bos_index (int) – The index of the beginning-of-sequence (bos) token.

  • eos_index (int) – The index of end-of-sequence (eos) token.

  • min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.

  • max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.

Returns:

  • hyps – The predicted tokens, as a list of lists or, if return_topk is True, a Tensor of shape (batch, topk, max length of token_id sequences).

  • top_lengths – The length of each topk sequence in the batch.

  • top_scores – The final scores of the topk hypotheses.

  • top_log_probs – The log probabilities of each hypothesis.
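As a rough illustration of how min_decode_ratio and max_decode_ratio bound the search, the sketch below converts the two ratios into step counts. It assumes each bound is the number of encoder time steps multiplied by the ratio and truncated to an integer; decode_step_bounds is a hypothetical helper written for this page, not part of the module.

```python
def decode_step_bounds(n_enc_steps, min_decode_ratio, max_decode_ratio):
    """Turn decode ratios into (min_steps, max_steps).

    Assumption of this sketch: each bound is the encoder length times
    the ratio, truncated to an integer.
    """
    if min_decode_ratio > max_decode_ratio:
        raise ValueError("min_decode_ratio must not exceed max_decode_ratio")
    min_steps = int(n_enc_steps * min_decode_ratio)
    max_steps = int(n_enc_steps * max_decode_ratio)
    return min_steps, max_steps


# 6 encoder frames with ratios 0 and 1, as in the examples on this page.
bounds = decode_step_bounds(6, 0.0, 1.0)
```

With these values the decoder may stop immediately and may run for at most as many steps as there are encoder frames.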

forward(enc_states, wav_len)[source]

This method should implement the forward algorithm of the decoding method.

Parameters:
  • enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended).

  • wav_len (torch.Tensor) – The speechbrain-style relative length.
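The relative length can be read as the fraction of the padded batch length that each item actually occupies. The sketch below shows one way to turn it into absolute step counts, using plain Python lists instead of tensors. relative_to_absolute is a hypothetical helper, and rounding to the nearest integer is an assumption of this sketch.

```python
def relative_to_absolute(rel_lens, max_len):
    """Convert relative lengths (fractions of max_len) to step counts."""
    return [int(round(r * max_len)) for r in rel_lens]


# Batch of two encoder outputs padded to 6 steps; the second item is
# only half as long as the first.
abs_lens = relative_to_absolute([1.0, 0.5], max_len=6)
```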

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

This method should implement one forward step of the autoregressive model.

Parameters:
  • inp_tokens (torch.Tensor) – The input tensor of the current step.

  • memory (No limit) – The memory variables input for this step (e.g., RNN hidden states).

  • enc_states (torch.Tensor) – The encoder states to be attended.

  • enc_lens (torch.Tensor) – The actual length of each enc_states sequence.

Returns:

  • log_probs (torch.Tensor) – Log-probabilities of the current step output.

  • memory (No limit) – The memory variables generated in this step (e.g., RNN hidden states).

  • attn (torch.Tensor) – The attention weight for doing penalty.

reset_mem(batch_size, device)[source]

This method should implement the resetting of memory variables for the seq2seq model, e.g., initializing a zero vector as the initial hidden state.

Parameters:
  • batch_size (int) – The size of the batch.

  • device (torch.device) – The device on which to place the initial variables.

Returns:

memory – The initial memory variable.

Return type:

No limit

change_max_decoding_length(min_decode_steps, max_decode_steps)[source]

Set the minimum/maximum length the decoder can take.

set_n_out()[source]

Set the number of output tokens. Override this function if the fc layer is embedded in the model, e.g., Whisper.

training: bool
class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]

Bases: S2SBaseSearcher

This class implements the general forward pass of the greedy decoding approach. See also S2SBaseSearcher().

forward(enc_states, wav_len)[source]

This method performs a greedy search.

Parameters:
  • enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended).

  • wav_len (torch.Tensor) – The speechbrain-style relative length.

Returns:

  • hyps (list) – The hypotheses.

  • top_lengths (torch.Tensor (batch)) – The length of each hypothesis in the batch.

  • top_scores (torch.Tensor (batch)) – The final scores of the hypotheses.

  • top_log_probs (torch.Tensor (batch, max length of token_id sequences)) – The log probabilities of each hypothesis.
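To make the idea of greedy search concrete, here is a minimal single-sequence sketch in plain Python. It is not the library implementation: step_fn stands in for forward_step, toy_step is an invented toy model, and leaving the eos token out of the hypothesis and the score is a choice made only for this sketch.

```python
import math


def greedy_search(step_fn, bos_index, eos_index, max_steps):
    """Decode one sequence by always taking the most likely token.

    step_fn(token, memory) -> (log_probs, memory) is a stand-in for
    forward_step; its signature is an assumption of this sketch.
    """
    token, memory = bos_index, None
    hyp, score = [], 0.0
    for _ in range(max_steps):
        log_probs, memory = step_fn(token, memory)
        # Argmax over the vocabulary.
        token = max(range(len(log_probs)), key=log_probs.__getitem__)
        if token == eos_index:
            break  # eos is neither stored nor scored in this sketch
        hyp.append(token)
        score += log_probs[token]
    return hyp, score


def toy_step(token, memory):
    """Toy 4-token model: bos (0) -> 2 -> 3 -> eos (1)."""
    preferred = {0: 2, 2: 3, 3: 1}[token]
    probs = [0.1] * 4
    probs[preferred] = 0.7
    return [math.log(p) for p in probs], memory


hyp, score = greedy_search(toy_step, bos_index=0, eos_index=1, max_steps=10)
```

The loop stops either when eos is selected or when max_steps is reached, which mirrors the role of the maximum decoding length described above.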

training: bool
class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, **kwargs)[source]

Bases: S2SGreedySearcher

This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher() and S2SGreedySearcher().

Parameters:
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

  • linear (torch.nn.Module) – A linear output layer.

  • **kwargs – see S2SBaseSearcher, arguments are directly passed.

Example

>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.decoders import S2SRNNGreedySearcher
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> searcher = S2SRNNGreedySearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=0,
...     eos_index=1,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
reset_mem(batch_size, device)[source]

When doing greedy search, keep hidden state (hs) and context vector (c) as memory.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented greedy searcher.

training: bool
class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, scorer=None, return_topk=False, topk=1, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]

Bases: S2SBaseSearcher

This class implements the beam-search algorithm for the seq2seq model. See also S2SBaseSearcher().

Parameters:
  • bos_index (int) – The index of beginning-of-sequence token.

  • eos_index (int) – The index of end-of-sequence token.

  • min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.

  • max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.

  • beam_size (int) – The width of beam.

  • scorer (speechbrain.decoders.scorer.ScorerBuilder) – Scorer instance. Default: None.

  • return_topk (bool) – Whether to return topk hypotheses. The topk hypotheses will be padded to the same length. Default: False.

  • topk (int) – If return_topk is True, then return topk hypotheses. Default: 1.

  • using_eos_threshold (bool) – Whether to use eos threshold. Default: True.

  • eos_threshold (float) – The threshold coefficient for eos token. Default: 1.5. See 3.1.2 in reference: https://arxiv.org/abs/1904.02619

  • length_normalization (bool) – Whether to divide the scores by the length. Default: True.

  • using_max_attn_shift (bool) – Whether using the max_attn_shift constraint. Default: False.

  • max_attn_shift (int) – Beam search blocks the beams whose attention shifts by more than max_attn_shift. Default: 60. Reference: https://arxiv.org/abs/1904.02619

  • minus_inf (float) – The value of minus infinity to block some path of the search. Default: -1e20.
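The effect of using_eos_threshold and eos_threshold can be sketched as follows. The sketch assumes that eos is accepted at a step only when its log-probability exceeds eos_threshold times the best log-probability of that step; the reference cited above gives the actual definition. eos_allowed is a hypothetical helper and works on plain Python lists.

```python
import math


def eos_allowed(log_probs, eos_index, eos_threshold=1.5):
    """Return True if eos may be emitted at this step.

    Assumption of this sketch: eos is accepted only when its
    log-probability is greater than eos_threshold times the best
    log-probability of the step. Log-probabilities are negative, so a
    larger threshold is more permissive.
    """
    return log_probs[eos_index] > eos_threshold * max(log_probs)


# eos is token 2 in both cases.
close_to_best = [math.log(p) for p in (0.05, 0.50, 0.45)]
far_from_best = [math.log(p) for p in (0.05, 0.94, 0.01)]

accept = eos_allowed(close_to_best, eos_index=2)
reject = eos_allowed(far_from_best, eos_index=2)
```

In the first case eos is nearly as likely as the best token and passes the check. In the second it is far less likely and is blocked, which keeps the search from ending a hypothesis prematurely.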

init_hypotheses()[source]

This method initializes the AlivedHypotheses object.

Returns:

The alived hypotheses filled with the initial values.

Return type:

AlivedHypotheses

init_beam_search_data(enc_states, wav_len)[source]

Initialize the beam search data.

Parameters:
  • enc_states (torch.Tensor) – The encoder states to be attended.

  • wav_len (torch.Tensor) – The actual length of each enc_states sequence.

Returns:

  • alived_hyps (AlivedHypotheses) – The alived hypotheses.

  • inp_tokens (torch.Tensor) – The input tensor of the current step.

  • log_probs (torch.Tensor) – The log-probabilities of the current step output.

  • eos_hyps_and_log_probs_scores (list) – Generated hypotheses (those that have reached eos) and log probs scores.

  • memory (No limit) – The memory variables generated in this step.

  • scorer_memory (No limit) – The memory variables generated in this step.

  • attn (torch.Tensor) – The attention weight.

  • prev_attn_peak (torch.Tensor) – The previous attention peak place.

  • enc_states (torch.Tensor) – The encoder states to be attended.

  • enc_lens (torch.Tensor) – The actual length of each enc_states sequence.

search_step(alived_hyps, inp_tokens, log_probs, eos_hyps_and_log_probs_scores, memory, scorer_memory, attn, prev_attn_peak, enc_states, enc_lens, step)[source]

A search step for the next most likely tokens.

Parameters:
  • alived_hyps (AlivedHypotheses) – The alived hypotheses.

  • inp_tokens (torch.Tensor) – The input tensor of the current step.

  • log_probs (torch.Tensor) – The log-probabilities of the current step output.

  • eos_hyps_and_log_probs_scores (list) – Generated hypotheses (those that have reached eos) and log probs scores.

  • memory (No limit) – The memory variables input for this step (e.g., RNN hidden states).

  • scorer_memory (No limit) – The memory variables input for this step (e.g., RNN hidden states).

  • attn (torch.Tensor) – The attention weight.

  • prev_attn_peak (torch.Tensor) – The previous attention peak place.

  • enc_states (torch.Tensor) – The encoder states to be attended.

  • enc_lens (torch.Tensor) – The actual length of each enc_states sequence.

  • step (int) – The current decoding step.

Returns:

  • alived_hyps (AlivedHypotheses) – The alived hypotheses.

  • inp_tokens (torch.Tensor) – The input tensor of the current step.

  • log_probs (torch.Tensor) – The log-probabilities of the current step output.

  • eos_hyps_and_log_probs_scores (list) – Generated hypotheses (those that have reached eos) and log probs scores.

  • memory (No limit) – The memory variables generated in this step.

  • scorer_memory (No limit) – The memory variables generated in this step.

  • attn (torch.Tensor) – The attention weight.

  • prev_attn_peak (torch.Tensor) – The previous attention peak place.

  • scores (torch.Tensor) – The scores of the current step output.
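The core of a search step, extending every live hypothesis by one token and keeping only the best beam_size candidates, can be sketched in plain Python. This is a simplified illustration: it omits eos handling, external scorers, attention-shift pruning and batching, and expand_beams is a hypothetical helper.

```python
import heapq


def expand_beams(beams, step_log_probs, beam_size):
    """Extend every live hypothesis by one token and keep the best ones.

    beams: list of (tokens, score) pairs.
    step_log_probs: one list of per-token log-probabilities per beam.
    """
    candidates = []
    for (tokens, score), log_probs in zip(beams, step_log_probs):
        for tok, lp in enumerate(log_probs):
            candidates.append((tokens + [tok], score + lp))
    # Keep the beam_size candidates with the highest cumulative score.
    return heapq.nlargest(beam_size, candidates, key=lambda c: c[1])


beams = [([0], -0.1), ([1], -0.5)]
step_log_probs = [[-0.2, -1.6], [-0.1, -3.0]]
new_beams = expand_beams(beams, step_log_probs, beam_size=2)
```

Here the two surviving hypotheses are [0, 0] and [1, 0]; the other two candidates are pruned because their cumulative scores are lower.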

forward(enc_states, wav_len)[source]

Applies beamsearch and returns the predicted tokens.

Parameters:
  • enc_states (torch.Tensor) – The encoder states to be attended.

  • wav_len (torch.Tensor) – The actual length of each enc_states sequence.

Returns:

  • hyps (list) – The predicted tokens.

  • best_lens (torch.Tensor) – The length of each predicted token sequence.

  • best_scores (torch.Tensor) – The score of each predicted token sequence.

  • best_log_probs (torch.Tensor) – The log probabilities of each predicted token sequence.

permute_mem(memory, index)[source]

This method permutes the seq2seq model memory to synchronize the memory index with the current output.

Parameters:
  • memory (No limit) – The memory variable to be permuted.

  • index (torch.Tensor) – The index of the previous path.

Returns:

The permuted memory variable.
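Why memory needs permuting: after pruning, the beam at position i may continue a hypothesis that previously lived at a different position, so its decoder state has to follow it. The sketch below shows the reordering on a plain Python list. With tensors, an operation such as torch.index_select along the hypothesis dimension would typically play this role, but that is an assumption and not a description of the library code. permute_memory is a hypothetical helper.

```python
def permute_memory(memory, index):
    """Reorder per-hypothesis memory so that row i holds the state of
    the hypothesis that new beam i was extended from."""
    return [memory[i] for i in index]


# Three beams. After pruning, new beams 0 and 1 both continue old
# beam 2, and new beam 2 continues old beam 0.
memory = ["state_a", "state_b", "state_c"]
permuted = permute_memory(memory, index=[2, 2, 0])
```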

training: bool
class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, temperature=1.0, **kwargs)[source]

Bases: S2SBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher(), S2SBeamSearcher().

Parameters:
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

  • linear (torch.nn.Module) – A linear output layer.

  • temperature (float) – Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1.

  • **kwargs – see S2SBeamSearcher, arguments are directly passed.
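The effect of temperature on the output distribution can be checked with a small stdlib-only sketch. It assumes the logits are divided by the temperature before the softmax, which is the usual formulation; softmax_with_temperature is a hypothetical helper.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a stable way."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max to avoid overflow in exp
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, temperature=0.5)
plain = softmax_with_temperature(logits, temperature=1.0)
soft = softmax_with_temperature(logits, temperature=2.0)
```

The probability of the top token is largest for T=0.5 and smallest for T=2.0, matching the description of sharper distributions for T<1 and softer ones for T>1.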

Example

>>> import torch
>>> import speechbrain as sb
>>> vocab_size = 5
>>> emb = torch.nn.Embedding(vocab_size, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3)
>>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size)
>>> scorer = sb.decoders.scorer.ScorerBuilder(
...     full_scorers = [coverage_scorer],
...     partial_scorers = [],
...     weights= dict(coverage=1.5)
... )
>>> searcher = S2SRNNBeamSearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=4,
...     eos_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     scorer=scorer,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> hyps, _, _, _ = searcher(enc, wav_len)
reset_mem(batch_size, device)[source]

Needed to reset the memory during beamsearch.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented beamsearcher.

permute_mem(memory, index)[source]

Memory permutation during beamsearch.

training: bool
class speechbrain.decoders.seq2seq.S2STransformerBeamSearcher(modules, temperature=1.0, **kwargs)[source]

Bases: S2SBeamSearcher

This class implements the beam search decoding for Transformer. See also S2SBaseSearcher(), S2SBeamSearcher().

Parameters:
  • modules (list containing the following:) –

    model (torch.nn.Module) – A Transformer model.

    seq_lin (torch.nn.Module) – A linear output layer.

  • temperature (float) – Temperature to use during decoding. Default: 1.0.

  • **kwargs – Arguments to pass to S2SBeamSearcher.

Example

>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.decoders import S2STransformerBeamSearcher
>>> batch_size=8
>>> n_channels=6
>>> input_size=40
>>> d_model=128
>>> tgt_vocab=140
>>> src = torch.rand([batch_size, n_channels, input_size])
>>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels])
>>> net = TransformerASR(
...    tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> searcher = S2STransformerBeamSearcher(
...     modules=[net, lin],
...     bos_index=1,
...     eos_index=2,
...     min_decode_ratio=0.0,
...     max_decode_ratio=1.0,
...     using_eos_threshold=False,
...     beam_size=7,
...     temperature=1.15,
... )
>>> enc, dec = net.forward(src, tgt)
>>> hyps, _, _, _  = searcher(enc, torch.ones(batch_size))
reset_mem(batch_size, device)[source]

Needed to reset the memory during beamsearch.

permute_mem(memory, index)[source]

Memory permutation during beamsearch.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented beamsearcher.

training: bool
class speechbrain.decoders.seq2seq.S2SWhisperGreedySearch(model, language_token=50259, bos_token=50258, task_token=50359, timestamp_token=50363, max_length=448, **kwargs)[source]

Bases: S2SGreedySearcher

This class implements the greedy decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.

Parameters:
  • model (HuggingFaceWhisper) – The Whisper model.

  • language_token (int) – The language token to be used for the decoder input.

  • bos_token (int) – The beginning of sentence token to be used for the decoder input.

  • task_token (int) – The task token to be used for the decoder input.

  • timestamp_token (int) – The timestamp token to be used for the decoder input.

  • max_length (int) – The maximum decoding steps to perform. The Whisper model has a maximum length of 448.

  • **kwargs – see S2SBaseSearcher, arguments are directly passed.

set_language_token(language_token)[source]

Set the language token to be used for the decoder input.

set_bos_token(bos_token)[source]

Set the bos token to be used for the decoder input.

set_task_token(task_token)[source]

Set the task token to be used for the decoder input.

set_timestamp_token(timestamp_token)[source]

Set the timestamp token to be used for the decoder input.

set_decoder_input_tokens(decoder_input_tokens)[source]

decoder_input_tokens are the tokens used as input to the decoder. They are directly taken from the tokenizer.prefix_tokens attribute. decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]
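The prefix described above can be written out with the default token ids from the class signature on this page. The sketch uses plain Python lists, and initial_decoder_input is a hypothetical helper showing how the same prefix might be repeated for every item in a batch.

```python
# Default token ids taken from the class signature on this page.
bos_token = 50258
language_token = 50259
task_token = 50359
timestamp_token = 50363

decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]


def initial_decoder_input(batch_size):
    """Hypothetical helper: repeat the prefix once per batch item."""
    return [list(decoder_input_tokens) for _ in range(batch_size)]


prefix_batch = initial_decoder_input(2)
```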

reset_mem(batch_size, device)[source]

This method sets the first tokens to decoder_input_tokens during search.

permute_mem(memory, index)[source]

Memory permutation during beamsearch.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented greedy searcher.

change_max_decoding_length(min_decode_steps, max_decode_steps)[source]

Set the minimum/maximum length the decoder can take.

training: bool
class speechbrain.decoders.seq2seq.S2STransformerGreedySearch(modules, temperature=1.0, **kwargs)[source]

Bases: S2SGreedySearcher

This class implements the greedy decoding for Transformer.

Parameters:
  • modules (list containing the following:) –

    model (torch.nn.Module) – A TransformerASR model.

    seq_lin (torch.nn.Module) – A linear output layer for the seq2seq model.

  • temperature (float) – Temperature to use during decoding.

  • **kwargs – Arguments to pass to S2SGreedySearcher

reset_mem(batch_size, device)[source]

Needed to reset the memory during greedy search.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented greedy searcher.

training: bool
class speechbrain.decoders.seq2seq.S2SWhisperBeamSearch(module, temperature=1.0, language_token=50259, bos_token=50258, task_token=50359, timestamp_token=50363, max_length=448, **kwargs)[source]

Bases: S2SBeamSearcher

This class implements the beam search decoding for Whisper neural nets made by OpenAI in https://cdn.openai.com/papers/whisper.pdf.

Parameters:
  • module (list containing the following:) –

    model (torch.nn.Module) – A Whisper model. It should have a decode() method.

    ctc_lin (torch.nn.Module, optional) – A linear output layer for CTC.

  • temperature (float) – Temperature to use during decoding. Default: 1.0.

  • language_token (int) – The token to use for language.

  • bos_token (int) – The token to use for beginning of sentence.

  • task_token (int) – The token to use for task.

  • timestamp_token (int) – The token to use for timestamp.

  • max_length (int) – The maximum decoding steps to perform. The Whisper model has a maximum length of 448.

  • **kwargs – Arguments to pass to S2SBeamSearcher

set_language_token(language_token)[source]

Set the language token to use for the decoder input.

set_bos_token(bos_token)[source]

Set the bos token to use for the decoder input.

set_task_token(task_token)[source]

Set the task token to use for the decoder input.

set_timestamp_token(timestamp_token)[source]

Set the timestamp token to use for the decoder input.

change_max_decoding_length(min_decode_steps, max_decode_steps)[source]

Set the minimum/maximum length the decoder can take.

set_decoder_input_tokens(decoder_input_tokens)[source]

decoder_input_tokens are the tokens used as input to the decoder. They are directly taken from the tokenizer.prefix_tokens attribute. decoder_input_tokens = [bos_token, language_token, task_token, timestamp_token]

reset_mem(batch_size, device)[source]

This method sets the first tokens to decoder_input_tokens during search.

permute_mem(memory, index)[source]

Permutes the memory.

set_n_out()[source]

Set the number of output tokens.

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented beamsearcher.

training: bool
class speechbrain.decoders.seq2seq.S2SHFTextBasedBeamSearcher(modules, vocab_size, **kwargs)[source]

Bases: S2STransformerBeamSearcher

This class implements the beam search decoding for the text-based HF seq2seq models, such as mBART or NLLB. It is NOT significantly different from S2STransformerBeamSearcher. This is why it inherits S2STransformerBeamSearcher. The main difference might arise when one wishes to use directly the lm_head of the text-based HF model rather than making a new projection layer (self.fc = None).

Parameters:
  • modules (list containing the following:) –

    model (torch.nn.Module) – A Transformer model.

    seq_lin (torch.nn.Module) – A linear output layer. Normally set to None for this use case.

  • vocab_size (int) – The dimension of the lm_head.

  • **kwargs – Arguments to pass to S2SBeamSearcher

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

Performs a step in the implemented beamsearcher.

training: bool
set_n_out()[source]

Set the number of output tokens.