speechbrain.decoders.seq2seq module

Decoding methods for seq2seq autoregressive models.

Authors
  • Ju-Chieh Chou 2020

  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Sung-Lin Yeh 2020

Summary

Classes:

S2SBaseSearcher

Base class to be inherited by other decoding approaches for seq2seq models.

S2SBeamSearcher

This class implements the beam-search algorithm for the seq2seq model.

S2SGreedySearcher

This class implements the general forward pass of the greedy decoding approach.

S2SRNNBeamSearchLM

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.

S2SRNNBeamSearchTransformerLM

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with a Transformer LM.

S2SRNNBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).

S2SRNNGreedySearcher

This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).

S2STransformerBeamSearch

This class implements the beam search decoding for Transformer.

Functions:

batch_filter_seq2seq_output

Calls filter_seq2seq_output on each of the batch_size predictions.

filter_seq2seq_output

Filters the output up to the first eos (exclusive).

inflate_tensor

This function inflates the tensor by repeating each slice along dim the given number of times.

mask_by_condition

This function replaces the elements of the tensor with fill_value wherever cond is False.

Reference

class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)

Bases: Module

Base class to be inherited by other decoding approaches for seq2seq models.

Parameters
  • bos_index (int) – The index of the beginning-of-sequence (bos) token.

  • eos_index (int) – The index of the end-of-sequence (eos) token.

  • min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.

  • max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.

Returns

  • predictions – Outputs as Python list of lists, with “ragged” dimensions; padding has been removed.

  • scores – The sum of log probabilities (and possibly additional heuristic scores) for each prediction.
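The min_decode_ratio and max_decode_ratio parameters scale the encoder length into a range of allowed decoding steps. The following is a minimal sketch, assuming the step counts are obtained by truncating T × ratio, where T is the number of encoder frames; decode_step_bounds is a hypothetical helper, not part of SpeechBrain.

```python
from typing import Tuple


def decode_step_bounds(
    enc_len: int, min_decode_ratio: float, max_decode_ratio: float
) -> Tuple[int, int]:
    """Hypothetical helper: map decode ratios to (min_steps, max_steps).

    Assumption: the bounds are computed by truncating enc_len * ratio to an int.
    """
    min_steps = int(enc_len * min_decode_ratio)
    max_steps = int(enc_len * max_decode_ratio)
    return min_steps, max_steps


# With 200 encoder frames, ratios of 0.25 and 0.5 allow 50 to 100 decoding steps.
print(decode_step_bounds(200, 0.25, 0.5))  # (50, 100)
```

With min_decode_ratio=0 and max_decode_ratio=1, as in the examples on this page, decoding may stop immediately or run for as many steps as there are encoder frames.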

forward(enc_states, wav_len)

This method should implement the forward algorithm of the decoding method.

Parameters
  • enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended to).

  • wav_len (torch.Tensor) – The SpeechBrain-style relative lengths (each sequence's length divided by the length of the longest sequence in the batch).
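Because wav_len is relative, a searcher has to convert it to absolute lengths before it can pass enc_lens to forward_step. A minimal sketch in plain Python, assuming the conversion multiplies by the padded number of encoder frames and rounds to the nearest frame; relative_to_absolute is a hypothetical name.

```python
from typing import List, Sequence


def relative_to_absolute(wav_len: Sequence[float], max_frames: int) -> List[int]:
    """Hypothetical helper: convert relative lengths in [0, 1] to frame counts.

    Assumption: lengths are rounded to the nearest frame (Python's round()
    breaks .5 ties toward the even integer).
    """
    return [int(round(rel * max_frames)) for rel in wav_len]


# A batch padded to 6 encoder frames: the first item is full length,
# the second one covers half of the frames.
print(relative_to_absolute([1.0, 0.5], 6))  # [6, 3]
```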

forward_step(inp_tokens, memory, enc_states, enc_lens)

This method should implement one forward step of the autoregressive model.

Parameters
  • inp_tokens (torch.Tensor) – The input tensor of the current timestep.

  • memory (No limit) – The memory variables input for this timestep (e.g., RNN hidden states).

  • enc_states (torch.Tensor) – The encoder states to be attended to.

  • enc_lens (torch.Tensor) – The actual length of each enc_states sequence.

Returns

  • log_probs (torch.Tensor) – Log-probabilities of the current timestep output.

  • memory (No limit) – The memory variables generated in this timestep (e.g., RNN hidden states).

  • attn (torch.Tensor) – The attention weights, used for penalties such as the coverage penalty.

reset_mem(batch_size, device)

This method should implement the resetting of memory variables for the seq2seq model, e.g., initializing a zero vector as the initial hidden state.

Parameters
  • batch_size (int) – The size of the batch.

  • device (torch.device) – The device on which to place the initial variables.

Returns

memory – The initial memory variable.

Return type

No limit

lm_forward_step(inp_tokens, memory)

This method should implement one forward step of the language model.

Parameters
  • inp_tokens (torch.Tensor) – The input tensor of the current timestep.

  • memory (No limit) – The memory variables input for this timestep (e.g., RNN hidden states).

Returns

  • log_probs (torch.Tensor) – Log-probabilities of the current timestep output.

  • memory (No limit) – The memory variables generated in this timestep. (e.g., RNN hidden states).

reset_lm_mem(batch_size, device)

This method should implement the resetting of memory variables in the language model, e.g., initializing a zero vector as the initial hidden state.

Parameters
  • batch_size (int) – The size of the batch.

  • device (torch.device) – The device on which to place the initial variables.

Returns

memory – The initial memory variable.

Return type

No limit

training: bool

class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)

Bases: S2SBaseSearcher

This class implements the general forward pass of the greedy decoding approach. See also S2SBaseSearcher().

forward(enc_states, wav_len)

This method performs a greedy search.

Parameters
  • enc_states (torch.Tensor) – The precomputed encoder states to be used when decoding (e.g., the encoded speech representation to be attended to).

  • wav_len (torch.Tensor) – The SpeechBrain-style relative lengths (each sequence's length divided by the length of the longest sequence in the batch).
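S2SGreedySearcher drives the decoding loop, while subclasses supply reset_mem and forward_step. The sketch below mirrors that split in plain Python so that it runs without PyTorch. ToyGreedySearcher, its hard-coded probability table, and the choice to add the eos log-probability to the score are illustrative assumptions, not SpeechBrain code; min_decode_ratio handling is omitted.

```python
import math
from typing import List, Tuple


class ToyGreedySearcher:
    """Hypothetical stand-in that mimics the reset_mem/forward_step hooks."""

    def __init__(self, bos_index: int, eos_index: int, max_steps: int):
        self.bos_index = bos_index
        self.eos_index = eos_index
        self.max_steps = max_steps

    def reset_mem(self) -> int:
        # The "memory" here is just a step counter; a real model keeps hidden states.
        return 0

    def forward_step(self, inp_token: int, memory: int) -> Tuple[List[float], int]:
        # Toy model over 5 tokens: prefers 1, 2, 3 in turn, then eos (index 4).
        target = [1, 2, 3, 4][min(memory, 3)]
        probs = [0.7 if tok == target else 0.075 for tok in range(5)]
        return [math.log(p) for p in probs], memory + 1

    def forward(self) -> Tuple[List[int], float]:
        memory = self.reset_mem()
        inp = self.bos_index
        tokens: List[int] = []
        score = 0.0
        for _ in range(self.max_steps):
            log_probs, memory = self.forward_step(inp, memory)
            # Greedy choice: keep only the single most likely token at each step.
            best = max(range(len(log_probs)), key=lambda i: log_probs[i])
            score += log_probs[best]
            if best == self.eos_index:
                break
            tokens.append(best)
            inp = best  # feed the prediction back as the next input
        return tokens, score


hyp, score = ToyGreedySearcher(bos_index=4, eos_index=4, max_steps=10).forward()
print(hyp)  # [1, 2, 3]
```

Keeping the loop in the base class and the model-specific work in the hooks is what lets S2SRNNGreedySearcher reuse the same forward pass with an attentional RNN decoder.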

training: bool

class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, **kwargs)

Bases: S2SGreedySearcher

This class implements the greedy decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher() and S2SGreedySearcher().

Parameters
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

  • linear (torch.nn.Module) – A linear output layer.

  • **kwargs – Arguments passed directly to S2SBaseSearcher.

Example

>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> searcher = S2SRNNGreedySearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=4,
...     eos_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)

reset_mem(batch_size, device)

For greedy search, keeps the hidden state (hs) and context vector (c) as memory.

forward_step(inp_tokens, memory, enc_states, enc_lens)

Performs a step in the implemented greedy searcher.

training: bool

class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, topk=1, return_log_probs=False, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, length_rewarding=0, coverage_penalty=0.0, lm_weight=0.0, lm_modules=None, ctc_weight=0.0, blank_index=0, ctc_score_mode='full', ctc_window_size=0, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)

Bases: S2SBaseSearcher

This class implements the beam-search algorithm for the seq2seq model. See also S2SBaseSearcher().

Parameters
  • bos_index (int) – The index of the beginning-of-sequence token.

  • eos_index (int) – The index of the end-of-sequence token.

  • min_decode_ratio (float) – The ratio of minimum decoding steps to the length of encoder states.

  • max_decode_ratio (float) – The ratio of maximum decoding steps to the length of encoder states.

  • beam_size (int) – The beam width.

  • topk (int) – The number of hypotheses to return. (default: 1)

  • return_log_probs (bool) – Whether to return log-probabilities. (default: False)

  • using_eos_threshold (bool) – Whether to use the eos threshold. (default: True)

  • eos_threshold (float) – The threshold coefficient for the eos token. (default: 1.5) See Section 3.1.2 of https://arxiv.org/abs/1904.02619

  • length_normalization (bool) – Whether to divide the scores by the hypothesis length. (default: True)

  • length_rewarding (float) – The coefficient of the length reward (γ), as in log P(y|x) + λ log P_LM(y) + γ*len(y). (default: 0.0)

  • coverage_penalty (float) – The coefficient of the coverage penalty (η), as in log P(y|x) + λ log P_LM(y) + γ*len(y) + η*coverage(x,y). (default: 0.0) References: https://arxiv.org/pdf/1612.02695.pdf, https://arxiv.org/pdf/1808.10792.pdf

  • lm_weight (float) – The weight of the LM when performing beam search (λ), as in log P(y|x) + λ log P_LM(y). (default: 0.0)

  • ctc_weight (float) – The weight of the CTC probabilities when performing beam search (λ), as in (1-λ) log P(y|x) + λ log P_CTC(y|x). (default: 0.0)

  • blank_index (int) – The index of the blank token. (default: 0)

  • ctc_score_mode (str) – Whether CTC prefix scoring is applied to “partial” tokens or to the “full” token set. (default: “full”)

  • ctc_window_size (int) – The window size used to compute the CTC scores over the time frames, with windowing based on attention peaks. If 0, no windowing is applied. (default: 0)

  • using_max_attn_shift (bool) – Whether to use the max_attn_shift constraint. (default: False)

  • max_attn_shift (int) – Beam search blocks beams whose attention shifts by more than max_attn_shift. (default: 60) Reference: https://arxiv.org/abs/1904.02619

  • minus_inf (float) – The value used as minus infinity to block some paths of the search. (default: -1e20)
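The parameter descriptions above combine several terms into one hypothesis score. The following minimal sketch writes those formulas out in plain Python, together with an eos-threshold check. All function names are hypothetical. Applying length normalization to the full combined score is an assumption. The eos rule (eos is kept only when its log-probability exceeds eos_threshold times the best log-probability at that step) is an assumed reading of the cited reference, not a copy of the SpeechBrain implementation.

```python
from typing import Sequence


def combined_score(
    log_p_s2s: float,
    log_p_lm: float = 0.0,
    length: int = 1,
    coverage: float = 0.0,
    lm_weight: float = 0.0,
    length_rewarding: float = 0.0,
    coverage_penalty: float = 0.0,
    length_normalization: bool = False,  # the searcher's own default is True
) -> float:
    """log P(y|x) + λ log P_LM(y) + γ*len(y) + η*coverage(x,y), optionally / len(y)."""
    score = (
        log_p_s2s
        + lm_weight * log_p_lm
        + length_rewarding * length
        + coverage_penalty * coverage
    )
    if length_normalization:
        # Assumption: normalization is applied to the full combined score.
        score /= length
    return score


def interpolate_ctc(log_p_s2s: float, log_p_ctc: float, ctc_weight: float) -> float:
    """(1-λ) log P(y|x) + λ log P_CTC(y|x), with λ = ctc_weight."""
    return (1.0 - ctc_weight) * log_p_s2s + ctc_weight * log_p_ctc


def eos_allowed(log_probs: Sequence[float], eos_index: int, eos_threshold: float = 1.5) -> bool:
    """Assumed rule: keep eos only if log P(eos) > eos_threshold * max log P."""
    # Log-probabilities are <= 0, so a threshold > 1 places the bar below the best score.
    return log_probs[eos_index] > eos_threshold * max(log_probs)


print(combined_score(-4.0, log_p_lm=-6.0, length=4, lm_weight=0.5))  # -7.0
```

With eos_threshold=1.5 and a best log-probability of -0.5, eos is accepted only if its own log-probability is above -0.75. A threshold of this kind discourages hypotheses from ending early on a low-confidence eos.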

forward(enc_states, wav_len)

Applies beam search and returns the predicted tokens.
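The loop below is a minimal single-utterance beam search in plain Python. It only shows the expand, prune, and finish cycle that forward() performs. It is not SpeechBrain's batched implementation: step_fn is a hypothetical callable that returns next-token log-probabilities, hypotheses that emit eos leave the beam (so the beam may shrink), and none of the LM, CTC, coverage, or length terms are applied.

```python
import math
from typing import Callable, List, Sequence, Tuple

Hyp = Tuple[List[int], float]  # (tokens, cumulative log-probability)


def beam_search(
    step_fn: Callable[[int, int], Sequence[float]],
    bos_index: int,
    eos_index: int,
    beam_size: int,
    max_steps: int,
) -> List[Hyp]:
    beams: List[Hyp] = [([], 0.0)]
    finished: List[Hyp] = []
    for _ in range(max_steps):
        # Expand every active hypothesis with every token in the vocabulary.
        candidates: List[Hyp] = []
        for tokens, score in beams:
            prev = tokens[-1] if tokens else bos_index
            for tok, lp in enumerate(step_fn(prev, len(tokens))):
                candidates.append((tokens + [tok], score + lp))
        # Prune to the beam_size best candidates.
        candidates.sort(key=lambda hyp: hyp[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == eos_index:
                finished.append((tokens[:-1], score))  # drop eos, as the filters do
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses still open when max_steps is reached
    return sorted(finished, key=lambda hyp: hyp[1], reverse=True)


def toy_step(prev: int, t: int) -> List[float]:
    # Hypothetical model over tokens {0, 1, 2}, where 2 is eos.
    table = [[0.6, 0.3, 0.1], [0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
    return [math.log(p) for p in table[min(t, 2)]]


results = beam_search(toy_step, bos_index=2, eos_index=2, beam_size=2, max_steps=5)
print(results[0][0])  # [0, 0]
```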

ctc_forward_step(x)

Applies a CTC step during beam search.

permute_mem(memory, index)

This method permutes the seq2seq model memory to synchronize the memory index with the current output.

Parameters
  • memory (No limit) – The memory variable to be permuted.

  • index (torch.Tensor) – The index of the previous path.

Returns

The permuted memory variable.
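In a batched beam search, the top candidates are usually picked from a flattened (beam × vocabulary) score vector, so each pick encodes both the beam it came from and the emitted token. The memory rows are then reordered so that each one follows its predecessor. The following plain-Python sketch shows that bookkeeping for a single utterance. split_flat_index and permute_mem_list are hypothetical names; with PyTorch tensors, torch.index_select would play the role of the list comprehension.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_flat_index(flat_indices: Sequence[int], vocab_size: int) -> Tuple[List[int], List[int]]:
    """Map flat (beam * vocab) indices to predecessor beams and emitted tokens."""
    predecessors = [i // vocab_size for i in flat_indices]
    tokens = [i % vocab_size for i in flat_indices]
    return predecessors, tokens


def permute_mem_list(memory: Sequence[T], index: Sequence[int]) -> List[T]:
    """Reorder per-beam memory so that row k holds the state of beam index[k]."""
    return [memory[i] for i in index]


# Two beams, vocabulary of 5: flat index 7 is beam 1 / token 2, flat index 2 is beam 0 / token 2.
preds, toks = split_flat_index([7, 2], vocab_size=5)
print(preds, toks)  # [1, 0] [2, 2]
print(permute_mem_list(["h_beam0", "h_beam1"], preds))  # ['h_beam1', 'h_beam0']
```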

permute_lm_mem(memory, index)

This method permutes the language model memory to synchronize the memory index with the current output.

Parameters
  • memory (No limit) – The memory variable to be permuted.

  • index (torch.Tensor) – The index of the previous path.

Returns

The permuted memory variable.

training: bool

class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, ctc_linear=None, temperature=1.0, **kwargs)

Bases: S2SBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). See also S2SBaseSearcher(), S2SBeamSearcher().

Parameters
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

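  • linear (torch.nn.Module) – A linear output layer.

  • ctc_linear (torch.nn.Module) – A linear output layer for the CTC scores, applied to the encoder states. (default: None)

  • temperature (float) – Temperature factor applied to the softmax. It changes the probability distribution, making it softer when T>1 and sharper when T<1. (default: 1.0)

  • **kwargs – Arguments passed directly to S2SBeamSearcher.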
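The following sketch shows how a temperature T reshapes the output distribution. It assumes the common convention of dividing the logits by T before a log-softmax; where exactly S2SRNNBeamSearcher applies the factor is not shown here. log_softmax_with_temperature is a hypothetical plain-Python helper.

```python
import math
from typing import List, Sequence


def log_softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_norm for x in scaled]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    top = math.exp(max(log_softmax_with_temperature(logits, t)))
    print(f"T={t}: p(top token)={top:.3f}")  # the top probability shrinks as T grows
```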

Example

>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> ctc_lin = sb.nnet.linear.Linear(n_neurons=5, input_size=7)
>>> searcher = S2SRNNBeamSearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     ctc_linear=ctc_lin,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)

reset_mem(batch_size, device)

Resets the memory during beam search.

forward_step(inp_tokens, memory, enc_states, enc_lens)

Performs a step in the implemented beam searcher.

permute_mem(memory, index)

Permutes the memory during beam search.

training: bool

class speechbrain.decoders.seq2seq.S2SRNNBeamSearchLM(embedding, decoder, linear, language_model, temperature_lm=1.0, **kwargs)

Bases: S2SRNNBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM. See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().

Parameters
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

  • linear (torch.nn.Module) – A linear output layer.

  • language_model (torch.nn.Module) – A language model.

  • temperature_lm (float) – Temperature factor applied to the softmax. It changes the probability distribution, making it softer when T>1 and sharper when T<1.

  • **kwargs – Arguments to pass to S2SBeamSearcher.

Example

>>> from speechbrain.lobes.models.RNNLM import RNNLM
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> lm = RNNLM(output_neurons=5, return_hidden=True)
>>> searcher = S2SRNNBeamSearchLM(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     language_model=lm,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     lm_weight=0.5,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)

lm_forward_step(inp_tokens, memory)

Applies a step of the LM during beam search.

permute_lm_mem(memory, index)

Permutes the LM memory to synchronize it with the current index during beam search. The order of the beams is shuffled by score at every timestep to allow batched beam search. For further details, refer to speechbrain/decoders/seq2seq.py.

reset_lm_mem(batch_size, device)

Resets the LM memory during beam search.

training: bool

class speechbrain.decoders.seq2seq.S2SRNNBeamSearchTransformerLM(embedding, decoder, linear, language_model, temperature_lm=1.0, **kwargs)

Bases: S2SRNNBeamSearcher

This class implements the beam search decoding for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with a Transformer LM. See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().

Parameters
  • embedding (torch.nn.Module) – An embedding layer.

  • decoder (torch.nn.Module) – Attentional RNN decoder.

  • linear (torch.nn.Module) – A linear output layer.

  • language_model (torch.nn.Module) – A language model.

  • temperature_lm (float) – Temperature factor applied to the softmax. It changes the probability distribution, making it softer when T>1 and sharper when T<1.

  • **kwargs – Arguments to pass to S2SBeamSearcher.

Example

>>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> lm = TransformerLM(5, 512, 8, 1, 0, 1024, activation=torch.nn.GELU)
>>> searcher = S2SRNNBeamSearchTransformerLM(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     language_model=lm,
...     bos_index=4,
...     eos_index=4,
...     blank_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     lm_weight=0.5,
... )
>>> enc = torch.rand([2, 6, 7])
>>> wav_len = torch.rand([2])
>>> hyps, scores = searcher(enc, wav_len)

lm_forward_step(inp_tokens, memory)

Performs a step in the LM during beam search.

permute_lm_mem(memory, index)

Permutes the LM memory during beam search.

reset_lm_mem(batch_size, device)

Resets the LM memory during beam search.

training: bool

speechbrain.decoders.seq2seq.inflate_tensor(tensor, times, dim)

This function inflates the tensor by repeating each slice along dim the given number of times, with each copy placed directly after the original.

Parameters
  • tensor (torch.Tensor) – The tensor to be inflated.

  • times (int) – The number of times each slice is repeated.

  • dim (int) – The dimension along which to inflate.

Returns

The inflated tensor.

Return type

torch.Tensor

Example

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> new_tensor = inflate_tensor(tensor, 2, dim=0)
>>> new_tensor
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.]])
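Inflating along dim=0 repeats each row consecutively. In beam search, this is typically used to give every beam its own copy of an utterance's encoder states (batch × beam_size rows, grouped per utterance). The following plain-Python sketch covers the dim=0 case on nested lists. inflate_rows is a hypothetical name; for tensors, torch.repeat_interleave(tensor, times, dim=dim) produces the same layout.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def inflate_rows(rows: Sequence[T], times: int) -> List[T]:
    """Repeat each row `times` times consecutively (the dim=0 case)."""
    return [row for row in rows for _ in range(times)]


# Same input as the example above: each row appears twice, in place.
print(inflate_rows([[1, 2, 3], [4, 5, 6]], 2))
# [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]

# Beam-search use: 2 utterances x beam_size 3 -> 6 rows, grouped per utterance.
print(inflate_rows(["utt0", "utt1"], 3))
```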

speechbrain.decoders.seq2seq.mask_by_condition(tensor, cond, fill_value)

This function replaces the elements of the tensor with fill_value wherever cond is False.

Parameters
  • tensor (torch.Tensor) – The tensor to be masked.

  • cond (torch.BoolTensor) – Must have the same size as tensor. Each element indicates whether to keep the corresponding value of tensor.

  • fill_value (float) – The value to fill in the masked elements.

Returns

The masked tensor.

Return type

torch.Tensor

Example

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
>>> mask_by_condition(tensor, cond, 0)
tensor([[1., 2., 0.],
        [4., 0., 0.]])

class speechbrain.decoders.seq2seq.S2STransformerBeamSearch(modules, temperature=1.0, temperature_lm=1.0, **kwargs)

Bases: S2SBeamSearcher

This class implements the beam search decoding for Transformer. See also S2SBaseSearcher(), S2SBeamSearcher().

Parameters
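  • modules (list of torch.nn.Module) – The modules used for decoding, including the model and a linear output layer.

  • temperature (float) – Temperature factor applied to the softmax of the decoder outputs. (default: 1.0)

  • temperature_lm (float) – Temperature factor applied to the softmax of the language model outputs. (default: 1.0)

  • **kwargs – Arguments to pass to S2SBeamSearcher.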

Example

>>> # see recipes/LibriSpeech/ASR_transformer/experiment.py

reset_mem(batch_size, device)

Resets the memory during beam search.

reset_lm_mem(batch_size, device)

Resets the LM memory during beam search.

permute_mem(memory, index)

Permutes the memory.

permute_lm_mem(memory, index)

Permutes the memory of the language model.

forward_step(inp_tokens, memory, enc_states, enc_lens)

Performs a step in the implemented beam searcher.

lm_forward_step(inp_tokens, memory)

Performs a step in the implemented LM module.

training: bool

speechbrain.decoders.seq2seq.batch_filter_seq2seq_output(prediction, eos_id=-1)

Calls filter_seq2seq_output on each of the batch_size predictions.

Parameters
  • prediction (list of torch.Tensor) – A list containing the output ints predicted by the seq2seq system.

  • eos_id (int, string) – The ID of the eos token.

Returns

The outputs predicted by the seq2seq model, each truncated before its first eos.

Return type

list

Example

>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
>>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
>>> predictions
[[1, 2, 3], [2, 3]]

speechbrain.decoders.seq2seq.filter_seq2seq_output(string_pred, eos_id=-1)

Filters the output up to the first eos (exclusive).

Parameters
  • string_pred (list) – A list containing the output strings/ints predicted by the seq2seq system.

  • eos_id (int, string) – The ID of the eos token.

Returns

The output predicted by the seq2seq model, truncated before the first eos.

Return type

list

Example

>>> string_pred = ['a','b','c','d','eos','e']
>>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
>>> string_out
['a', 'b', 'c', 'd']