speechbrain.decoders.ctc module

Decoders and output normalization for CTC.

Authors
  • Mirco Ravanelli 2020

  • Aku Rouhe 2020

  • Sung-Lin Yeh 2020

  • Adel Moumen 2023, 2024

Summary

Classes:

CTCBaseSearcher

CTCBaseSearcher class to be inherited by other CTC beam searchers.

CTCBeam

This class handles the CTC beam information during decoding.

CTCBeamSearcher

CTC Beam Search is a Beam Search for CTC which does not keep track of the blank and non-blank probabilities.

CTCHypothesis

This class is a data handler over the generated hypotheses.

CTCPrefixBeamSearcher

CTC Prefix Beam Search is based on the paper First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs by Awni Y. Hannun et al.

CTCPrefixScore

This class implements the CTC prefix score of Algorithm 2 in reference: https://www.merl.com/publications/docs/TR2017-190.pdf.

LMCTCBeam

This class handles the LM scores during decoding.

TorchAudioCTCPrefixBeamSearcher

TorchAudio CTC Prefix Beam Search Decoder.

Functions:

ctc_greedy_decode

Greedy decode a batch of probabilities and apply CTC rules.

filter_ctc_output

Apply CTC output merge and filter rules.

Reference

class speechbrain.decoders.ctc.CTCPrefixScore(x, enc_lens, blank_index, eos_index, ctc_window_size=0)[source]

Bases: object

This class implements the CTC prefix score of Algorithm 2 in reference: https://www.merl.com/publications/docs/TR2017-190.pdf. Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py

Parameters:
  • x (torch.Tensor) – The encoder states.

  • enc_lens (torch.Tensor) – The actual length of each enc_states sequence.

  • batch_size (int) – The size of the batch.

  • beam_size (int) – The width of beam.

  • blank_index (int) – The index of the blank token.

  • eos_index (int) – The index of the end-of-sequence (eos) token.

  • ctc_window_size (int) – Compute the ctc scores over the time frames using windowing based on attention peaks. If 0, no windowing applied.

forward_step(inp_tokens, states, candidates=None, attn=None)[source]

This method is one step of the forward operation for the prefix CTC scorer.

Parameters:
  • inp_tokens (torch.Tensor) – The last chars of prefix label sequences g, where h = g + c.

  • states (tuple) – Previous ctc states.

  • candidates (torch.Tensor) – (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring. If given, performing partial ctc scoring.

  • attn (torch.Tensor) – (batch_size * beam_size, max_enc_len), The attention weights.

permute_mem(memory, index)[source]

This method permutes the CTC model memory to synchronize the memory index with the current output.

Parameters:
  • memory (No limit) – The memory variable to be permuted.

  • index (torch.Tensor) – The index of the previous path.

Return type:

The variable of the memory being permuted.

speechbrain.decoders.ctc.filter_ctc_output(string_pred, blank_id=-1)[source]

Apply CTC output merge and filter rules.

Removes the blank symbol and output repetitions.

Parameters:
  • string_pred (list) – A list containing the output strings/ints predicted by the CTC system.

  • blank_id (int, string) – The id of the blank.

Returns:

The output predicted by CTC without the blank symbol and the repetitions.

Return type:

list

Example

>>> string_pred = ['a','a','blank','b','b','blank','c']
>>> string_out = filter_ctc_output(string_pred, blank_id='blank')
>>> print(string_out)
['a', 'b', 'c']
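
The merge and filter rules can be illustrated without the library. The following is a minimal pure-Python sketch, not the SpeechBrain implementation; the helper name `merge_and_filter` is hypothetical. It first collapses runs of identical consecutive symbols and then drops the blank symbol.

```python
from itertools import groupby


def merge_and_filter(tokens, blank="blank"):
    """Collapse consecutive repeats, then drop blanks (illustrative helper)."""
    # groupby yields one key per run of identical consecutive items.
    collapsed = [key for key, _ in groupby(tokens)]
    return [tok for tok in collapsed if tok != blank]


print(merge_and_filter(["a", "a", "blank", "b", "b", "blank", "c"]))
# ['a', 'b', 'c']
print(merge_and_filter(["a", "blank", "a"]))
# ['a', 'a']
```

The second call shows why the order of the two rules matters: a blank between two identical symbols keeps them as separate outputs.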
speechbrain.decoders.ctc.ctc_greedy_decode(probabilities, seq_lens, blank_id=-1)[source]

Greedy decode a batch of probabilities and apply CTC rules.

Parameters:
  • probabilities (torch.tensor) – Output probabilities (or log-probabilities) from the network with shape [batch, lengths, probabilities]

  • seq_lens (torch.tensor) – Relative true sequence lengths (to deal with padded inputs). The longest sequence has length 1.0, the others a value between zero and one. Shape: [batch].

  • blank_id (int, string) – The blank symbol/index. Default: -1. If a negative number is given, it is assumed to mean counting down from the maximum possible index, so that -1 refers to the maximum possible index.

Returns:

Outputs as Python list of lists, with “ragged” dimensions; padding has been removed.

Return type:

list

Example

>>> import torch
>>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]],
...                       [[0.2, 0.8], [0.9, 0.1]]])
>>> lens = torch.tensor([0.51, 1.0])
>>> blank_id = 0
>>> ctc_greedy_decode(probs, lens, blank_id)
[[1], [1]]
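
To make the steps explicit, here is a hedged pure-Python sketch of greedy CTC decoding over nested lists. It is not the library code: the function name `greedy_decode` is hypothetical, and rounding the relative length to the nearest frame count is an assumption.

```python
from itertools import groupby


def greedy_decode(probabilities, seq_lens, blank_id=-1):
    """Greedy CTC decoding over nested lists (illustrative sketch)."""
    max_frames = len(probabilities[0])
    vocab_size = len(probabilities[0][0])
    if blank_id < 0:
        # Negative ids count down from the largest index, so -1 is the last one.
        blank_id = vocab_size + blank_id
    outputs = []
    for frames, rel_len in zip(probabilities, seq_lens):
        # Assumption: relative lengths are rounded to the nearest frame count.
        n_frames = int(round(rel_len * max_frames))
        # Pick the most likely token of every unpadded frame.
        best = [
            max(range(vocab_size), key=frame.__getitem__)
            for frame in frames[:n_frames]
        ]
        # Apply the CTC rules: merge repeats, then remove blanks.
        collapsed = [key for key, _ in groupby(best)]
        outputs.append([key for key in collapsed if key != blank_id])
    return outputs


probs = [[[0.3, 0.7], [0.0, 0.0]],
         [[0.2, 0.8], [0.9, 0.1]]]
print(greedy_decode(probs, [0.51, 1.0], blank_id=0))
# [[1], [1]]
```

The first sequence keeps only one frame (0.51 × 2 rounds to 1), so its padded second frame never influences the result.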
class speechbrain.decoders.ctc.CTCBeam(text: str, full_text: str, next_word: str, partial_word: str, last_token: str | None, last_token_index: int | None, text_frames: List[Tuple[int, int]], partial_frames: Tuple[int, int], p: float = -inf, p_b: float = -inf, p_nb: float = -inf, n_p_b: float = -inf, n_p_nb: float = -inf, score: float = -inf, score_ctc: float = -inf)[source]

Bases: object

This class handles the CTC beam information during decoding.

Parameters:
  • text (str) – The current text of the beam.

  • full_text (str) – The full text of the beam.

  • next_word (str) – The next word to be added to the beam.

  • partial_word (str) – The partial word being added to the beam.

  • last_token (str, optional) – The last token of the beam.

  • last_token_index (int, optional) – The index of the last token of the beam.

  • text_frames (List[Tuple[int, int]]) – The start and end frame of the text.

  • partial_frames (Tuple[int, int]) – The start and end frame of the partial word.

  • p (float) – The probability of the beam.

  • p_b (float) – The probability of the beam ending in a blank.

  • p_nb (float) – The probability of the beam not ending in a blank.

  • n_p_b (float) – The previous probability of the beam ending in a blank.

  • n_p_nb (float) – The previous probability of the beam not ending in a blank.

  • score (float) – The score of the beam (LM + CTC)

  • score_ctc (float) – The CTC score computed.

Example

>>> import math
>>> from speechbrain.decoders.ctc import CTCBeam
>>> beam = CTCBeam(
...     text="",
...     full_text="",
...     next_word="",
...     partial_word="",
...     last_token=None,
...     last_token_index=None,
...     text_frames=[(0, 0)],
...     partial_frames=(0, 0),
...     p=-math.inf,
...     p_b=-math.inf,
...     p_nb=-math.inf,
...     n_p_b=-math.inf,
...     n_p_nb=-math.inf,
...     score=-math.inf,
...     score_ctc=-math.inf,
... )
text: str
full_text: str
next_word: str
partial_word: str
last_token: str | None
last_token_index: int | None
text_frames: List[Tuple[int, int]]
partial_frames: Tuple[int, int]
p: float = -inf
p_b: float = -inf
p_nb: float = -inf
n_p_b: float = -inf
n_p_nb: float = -inf
score: float = -inf
score_ctc: float = -inf
classmethod from_lm_beam(lm_beam: LMCTCBeam) CTCBeam[source]

Create a CTCBeam from a LMCTCBeam

Parameters:

lm_beam (LMCTCBeam) – The LMCTCBeam to convert.

Returns:

The CTCBeam converted.

Return type:

CTCBeam

step() None[source]

Update the beam probabilities.
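
One plausible reading of the field names is that `n_p_b` and `n_p_nb` accumulate probabilities for the frame being processed and are rolled into `p_b` and `p_nb` once the frame is done. The sketch below illustrates that bookkeeping under this assumption only; `ToyBeam` and `log_add` are hypothetical names and the code is not taken from the library.

```python
import math
from dataclasses import dataclass


def log_add(a, b):
    """Stable log(exp(a) + exp(b)) that tolerates -inf inputs."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


@dataclass
class ToyBeam:
    """Stand-in holding only the probability fields (log domain)."""
    p_b: float = -math.inf
    p_nb: float = -math.inf
    n_p_b: float = -math.inf
    n_p_nb: float = -math.inf
    score_ctc: float = -math.inf

    def step(self):
        # Assumption: the accumulated values become the current ones...
        self.p_b, self.p_nb = self.n_p_b, self.n_p_nb
        # ...the accumulators are reset for the next frame...
        self.n_p_b = self.n_p_nb = -math.inf
        # ...and the CTC score sums the blank and non-blank endings.
        self.score_ctc = log_add(self.p_b, self.p_nb)


beam = ToyBeam(n_p_b=math.log(0.3), n_p_nb=math.log(0.2))
beam.step()
print(round(math.exp(beam.score_ctc), 6))
# 0.5
```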

class speechbrain.decoders.ctc.LMCTCBeam(text: str, full_text: str, next_word: str, partial_word: str, last_token: str | None, last_token_index: int | None, text_frames: List[Tuple[int, int]], partial_frames: Tuple[int, int], p: float = -inf, p_b: float = -inf, p_nb: float = -inf, n_p_b: float = -inf, n_p_nb: float = -inf, score: float = -inf, score_ctc: float = -inf, lm_score: float = -inf)[source]

Bases: CTCBeam

This class handles the LM scores during decoding.

Parameters:
  • lm_score (float) – The LM score of the beam.

  • **kwargs – See CTCBeam for the other arguments.

lm_score: float = -inf
class speechbrain.decoders.ctc.CTCHypothesis(text: str, last_lm_state: None, score: float, lm_score: float, text_frames: list | None = None)[source]

Bases: object

This class is a data handler over the generated hypotheses.

This class is the default output of the CTC beam searchers.

It can be re-used for other decoders if using the beam searchers in an online fashion.

Parameters:
  • text (str) – The text of the hypothesis.

  • last_lm_state (None) – The last LM state of the hypothesis.

  • score (float) – The score of the hypothesis.

  • lm_score (float) – The LM score of the hypothesis.

  • text_frames (List[Tuple[str, Tuple[int, int]]], optional) – The list of the text and the corresponding frames.

text: str
last_lm_state: None
score: float
lm_score: float
text_frames: list = None
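
When `text_frames` is available, its documented layout (a list of text and frame-range pairs) can be mapped to timestamps. The sketch below is illustrative: the helper name is hypothetical, and the frame shift of 0.02 s is an assumption that depends on the acoustic model's stride.

```python
def frames_to_seconds(text_frames, frame_shift_s=0.02):
    """Convert (text, (start_frame, end_frame)) pairs to seconds.

    frame_shift_s is a hypothetical value; use the real hop of your
    acoustic model's output frames.
    """
    return [
        (text, round(start * frame_shift_s, 3), round(end * frame_shift_s, 3))
        for text, (start, end) in text_frames
    ]


example_frames = [("hello", (0, 25)), ("world", (30, 60))]
for text, start_s, end_s in frames_to_seconds(example_frames):
    print(f"{text}: {start_s:.2f}s to {end_s:.2f}s")
```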
class speechbrain.decoders.ctc.CTCBaseSearcher(blank_index: int, vocab_list: List[str], space_token: str = ' ', kenlm_model_path: str | None = None, unigrams: None | List[str] = None, alpha: float = 0.5, beta: float = 1.5, unk_score_offset: float = -10.0, score_boundary: bool = True, beam_size: int = 100, beam_prune_logp: int = -10.0, token_prune_min_logp: int = -5.0, prune_history: bool = True, blank_skip_threshold: None | int = 1.0, topk: int = 1, spm_token: str = '▁')[source]

Bases: Module

CTCBaseSearcher class to be inherited by other CTC beam searchers.

This class provides the basic functionalities for CTC beam search decoding.

The space_token is required with a non-sentencepiece vocabulary list if your transcription is expecting to contain spaces.

Parameters:
  • blank_index (int) – The index of the blank token.

  • vocab_list (list) – The list of the vocabulary tokens.

  • space_token (str, optional) – The space token. (default: ' ')

  • kenlm_model_path (str, optional) – The path to the kenlm model. Use .bin for a faster loading. If None, no language model will be used. (default: None)

  • unigrams (list, optional) – The list of known word unigrams. (default: None)

  • alpha (float) – Weight for language model during shallow fusion. (default: 0.5)

  • beta (float) – Weight for length score adjustment during scoring. (default: 1.5)

  • unk_score_offset (float) – Amount of log score offset for unknown tokens. (default: -10.0)

  • score_boundary (bool) – Whether to have kenlm respect boundaries when scoring. (default: True)

  • beam_size (int, optional) – The width of the beam. (default: 100)

  • beam_prune_logp (float, optional) – The pruning threshold for the beam. (default: -10.0)

  • token_prune_min_logp (float, optional) – The pruning threshold for the tokens. (default: -5.0)

  • prune_history (bool, optional) – Whether to prune the history. (default: True) Note: when using topk > 1, this should be set to False as it is pruning a lot of beams.

  • blank_skip_threshold (float, optional) – Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding. Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk. (default: 1.0)

  • topk (int, optional) – The number of top hypotheses to return. (default: 1)

  • spm_token (str, optional) – The sentencepiece token. (default: “▁”)

Example

>>> from speechbrain.decoders.ctc import CTCBaseSearcher
>>> blank_index = 0
>>> vocab_list = ['blank', 'a', 'b', 'c', ' ']
>>> space_token = ' '
>>> kenlm_model_path = None
>>> unigrams = None
>>> beam_size = 100
>>> beam_prune_logp = -10.0
>>> token_prune_min_logp = -5.0
>>> prune_history = True
>>> blank_skip_threshold = 1.0
>>> topk = 1
>>> searcher = CTCBaseSearcher(
...     blank_index=blank_index,
...     vocab_list=vocab_list,
...     space_token=space_token,
...     kenlm_model_path=kenlm_model_path,
...     unigrams=unigrams,
...     beam_size=beam_size,
...     beam_prune_logp=beam_prune_logp,
...     token_prune_min_logp=token_prune_min_logp,
...     prune_history=prune_history,
...     blank_skip_threshold=blank_skip_threshold,
...     topk=topk,
... )
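
To give an intuition for `alpha` and `beta`, the following sketch shows one common form of shallow fusion: the LM log-probability is weighted by `alpha` and a per-word bonus `beta` offsets the LM's preference for short outputs. The exact formula used by the library may differ (for example in log-base conversion); the function name and the numbers are illustrative assumptions.

```python
def fused_score(ctc_logp, lm_logp, n_words, alpha=0.5, beta=1.5):
    """Combine acoustic and LM evidence (illustrative shallow fusion)."""
    return ctc_logp + alpha * lm_logp + beta * n_words


# Hypothetical scores for two competing two-word hypotheses.
score_a = fused_score(ctc_logp=-4.0, lm_logp=-12.0, n_words=2)
score_b = fused_score(ctc_logp=-4.5, lm_logp=-6.0, n_words=2)
print(score_a, score_b)
# -7.0 -4.5
```

Hypothesis A has the better acoustic score, but after fusion hypothesis B ranks first because the language model finds it far more likely.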
partial_decoding(log_probs: Tensor, beams: List[CTCBeam], cached_lm_scores: dict, cached_p_lm_scores: dict, processed_frames: int = 0)[source]

Perform a single step of decoding.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC output.

  • beams (list) – The list of the beams.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_p_lm_scores (dict) – The cached prefix language model scores.

  • processed_frames (int, default: 0) – The start frame of the current decoding step.

normalize_whitespace(text: str) str[source]

Efficiently normalize whitespace.

Parameters:

text (str) – The text to normalize.

Returns:

The normalized text.

Return type:

str
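
A common way to normalize whitespace efficiently in Python is to split on any whitespace and re-join with single spaces. The sketch below shows that idiom; whether the method uses exactly this form is an assumption.

```python
def normalize_whitespace_sketch(text):
    """Collapse whitespace runs and strip both ends (illustrative)."""
    # str.split() with no argument splits on any run of whitespace.
    return " ".join(text.split())


print(repr(normalize_whitespace_sketch("  hello   world \n")))
# 'hello world'
```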

merge_tokens(token_1: str, token_2: str) str[source]

Merge two tokens, and avoid empty ones.

Taken from: https://github.com/kensho-technologies/pyctcdecode

Parameters:
  • token_1 (str) – The first token.

  • token_2 (str) – The second token.

Returns:

The merged token.

Return type:

str
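
The behaviour described above can be sketched as follows, assuming the two tokens are joined by a single space when both are non-empty (as pyctcdecode does). The function name is hypothetical.

```python
def merge_tokens_sketch(token_1, token_2):
    """Join two tokens with a space, skipping empty ones (illustrative)."""
    if not token_2:
        return token_1
    if not token_1:
        return token_2
    return token_1 + " " + token_2


print(repr(merge_tokens_sketch("hello", "world")))
# 'hello world'
print(repr(merge_tokens_sketch("", "world")))
# 'world'
```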

merge_beams(beams: List[CTCBeam]) List[CTCBeam][source]

Merge beams with the same text.

Taken from: https://github.com/kensho-technologies/pyctcdecode

Parameters:

beams (list) – The list of the beams.

Returns:

The list of CTCBeam merged.

Return type:

list
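
Merging beams that share the same text amounts to summing their probabilities, which in the log domain is a log-sum-exp. The sketch below uses plain `(text, log_score)` tuples instead of CTCBeam objects and keys the merge on the text only; both choices are simplifying assumptions.

```python
import math


def log_sum_exp(values):
    """Stable log(sum(exp(v))) over a non-empty list."""
    hi = max(values)
    if hi == -math.inf:
        return -math.inf
    return hi + math.log(sum(math.exp(v - hi) for v in values))


def merge_by_text(beams):
    """Merge (text, log_score) pairs with equal text; best first."""
    grouped = {}
    for text, score in beams:
        grouped.setdefault(text, []).append(score)
    merged = [(text, log_sum_exp(scores)) for text, scores in grouped.items()]
    return sorted(merged, key=lambda item: item[1], reverse=True)


toy_beams = [("ab", math.log(0.2)), ("a", math.log(0.3)), ("ab", math.log(0.25))]
for text, score in merge_by_text(toy_beams):
    print(text, round(math.exp(score), 2))
# ab 0.45
# a 0.3
```

Neither "ab" beam is the best on its own, yet the merged "ab" beam outranks "a".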

sort_beams(beams: List[CTCBeam]) List[CTCBeam][source]

Sort beams by lm_score.

Parameters:

beams (list) – The list of CTCBeam.

Returns:

The list of CTCBeam sorted.

Return type:

list

finalize_decoding(beams: List[CTCBeam], cached_lm_scores: dict, cached_p_lm_scores: dict, force_next_word=False, is_end=False) List[CTCBeam][source]

Finalize the decoding process by adding and scoring the last partial word.

Parameters:
  • beams (list) – The list of CTCBeam.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_p_lm_scores (dict) – The cached prefix language model scores.

  • force_next_word (bool, default: False) – Whether to force the next word.

  • is_end (bool, default: False) – Whether the end of the sequence has been reached.

Returns:

The list of the CTCBeam.

Return type:

list

decode_beams(log_probs: Tensor, wav_lens: Tensor | None = None, lm_start_state: Any | None = None) List[List[CTCHypothesis]][source]

Decodes the input log probabilities of the CTC output.

It automatically converts the SpeechBrain’s relative length of the wav input to the absolute length.

Make sure that the input is in the log domain. The decoder will fail to decode logits or probabilities. The input should be the log probabilities of the CTC output.
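
As a reminder of what "log domain" means here, the sketch below turns one frame of raw logits into log-probabilities with a log-softmax and adds a simple sanity check. It is pure Python for illustration; with PyTorch tensors one would typically apply `torch.nn.functional.log_softmax` over the vocabulary dimension. The helper names are hypothetical.

```python
import math


def log_softmax(logits):
    """Log-softmax of one frame of logits (illustrative)."""
    hi = max(logits)
    log_norm = hi + math.log(sum(math.exp(v - hi) for v in logits))
    return [v - log_norm for v in logits]


def looks_like_log_probs(frame, tol=1e-6):
    """True if all values are <= 0 and exp(values) sums to one."""
    non_positive = all(v <= tol for v in frame)
    return non_positive and abs(sum(math.exp(v) for v in frame) - 1.0) < tol


frame_logits = [2.0, 0.5, -1.0]
frame_log_probs = log_softmax(frame_logits)
print(looks_like_log_probs(frame_log_probs))  # True
print(looks_like_log_probs(frame_logits))     # False
```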

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC output. The expected shape is [batch_size, seq_length, vocab_size].

  • wav_lens (torch.Tensor, optional (default: None)) – The SpeechBrain’s relative length of the wav input.

  • lm_start_state (Any, optional (default: None)) – The start state of the language model.

Returns:

The list of topk list of CTCHypothesis.

Return type:

list of list

__call__(log_probs: Tensor, wav_lens: Tensor | None = None, lm_start_state: Any | None = None) List[List[CTCHypothesis]][source]

Decodes the log probabilities of the CTC output.

It automatically converts the SpeechBrain’s relative length of the wav input to the absolute length.

Each tensor is moved to the CPU and converted to NumPy, as this is faster and consumes less memory.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC output. The expected shape is [batch_size, seq_length, vocab_size].

  • wav_lens (torch.Tensor, optional (default: None)) – The SpeechBrain’s relative length of the wav input.

  • lm_start_state (Any, optional (default: None)) – The start state of the language model.

Returns:

The list of topk list of CTCHypothesis.

Return type:

list of list

partial_decode_beams(log_probs: Tensor, cached_lm_scores: dict, cached_p_lm_scores: dict, beams: List[CTCBeam], processed_frames: int, force_next_word=False, is_end=False) List[CTCBeam][source]

Perform a single step of decoding.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC output.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_p_lm_scores (dict) – The cached prefix language model scores.

  • beams (list) – The list of the beams.

  • processed_frames (int) – The start frame of the current decoding step.

  • force_next_word (bool, optional (default: False)) – Whether to force the next word.

  • is_end (bool, optional (default: False)) – Whether the end of the sequence has been reached.

Returns:

The list of CTCBeam.

Return type:

list

decode_log_probs(log_probs: Tensor, wav_len: int, lm_start_state: Any | None = None) List[CTCHypothesis][source]

Decodes the log probabilities of the CTC output.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC output. The expected shape is [seq_length, vocab_size].

  • wav_len (int) – The length of the wav input.

  • lm_start_state (Any, optional (default: None)) – The start state of the language model.

Returns:

The topk list of CTCHypothesis.

Return type:

list

training: bool
class speechbrain.decoders.ctc.CTCBeamSearcher(**kwargs)[source]

Bases: CTCBaseSearcher

CTC Beam Search is a Beam Search for CTC which does not keep track of the blank and non-blank probabilities. Each new token probability is added to the general score, and beams that share the same text are merged together.

The implementation supports n-gram scoring on words and SentencePiece tokens. The input is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].

The main advantage of the CTCBeamSearcher over the CTCPrefixBeamSearcher is that it is relatively faster and obtains slightly better results. However, the implementation is based on the one from the PyCTCDecode toolkit, adapted for SpeechBrain’s needs, and does not follow a specific paper. We recommend using the CTCPrefixBeamSearcher if you want to cite the appropriate paper for the decoding method.

Several heuristics are implemented to speed up the decoding process:

  • pruning of the beam: the beams are pruned if their score is lower than the best beam score minus the beam_prune_logp

  • pruning of the tokens: the tokens are pruned if their score is lower than the token_prune_min_logp

  • pruning of the history: the beams are pruned if they are the same over max_ngram history

  • skipping of the blank: the frame is skipped if the blank probability is higher than the blank_skip_threshold
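
The token and beam pruning heuristics listed above can be sketched in a few lines. This is an illustration rather than the library code: beams are plain `(text, log_score)` tuples, the function names are hypothetical, and the cutoff is assumed to be the best score plus `beam_prune_logp` because the threshold is a negative log value.

```python
import math


def prune_tokens(frame_log_probs, token_prune_min_logp=-5.0):
    """Keep token indices above the threshold, plus the best token."""
    best = max(range(len(frame_log_probs)), key=frame_log_probs.__getitem__)
    kept = {i for i, lp in enumerate(frame_log_probs) if lp > token_prune_min_logp}
    # Always keep the argmax so that a frame never ends up empty.
    return sorted(kept | {best})


def prune_beams(beams, beam_prune_logp=-10.0, beam_size=100):
    """Drop beams far below the best one, then cap the beam width."""
    best_score = max(score for _, score in beams)
    # Assumption: beam_prune_logp is negative, so it is added to the best score.
    cutoff = best_score + beam_prune_logp
    kept = [beam for beam in beams if beam[1] >= cutoff]
    return sorted(kept, key=lambda beam: beam[1], reverse=True)[:beam_size]


frame = [math.log(0.9), math.log(0.0999), math.log(0.0001)]
print(prune_tokens(frame))
# [0, 1]
print(prune_beams([("a", -1.0), ("b", -5.0), ("c", -12.5)]))
# [('a', -1.0), ('b', -5.0)]
```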

Note: if the Acoustic Model is not trained, the Beam Search will take a lot of time. We recommend using Greedy Search during validation until the model is fully trained and ready to be evaluated on test sets.

Parameters:

**kwargs – see CTCBaseSearcher, arguments are directly passed.

Example

>>> import torch
>>> from speechbrain.decoders import CTCBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
...                   [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = CTCBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
>>> hyps = searcher(log_probs, lens)
get_lm_beams(beams: List[CTCBeam], cached_lm_scores: dict, cached_partial_token_scores: dict, is_eos=False) List[LMCTCBeam][source]

Score the beams with the language model if not None, and return the new beams.

This function is modified and adapted from https://github.com/kensho-technologies/pyctcdecode

Parameters:
  • beams (list) – The list of the beams.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_partial_token_scores (dict) – The cached partial token scores.

  • is_eos (bool (default: False)) – Whether the end of the sequence has been reached.

Returns:

new_beams – The list of the new beams.

Return type:

list

partial_decoding(log_probs: Tensor, wav_len: int, beams: List[CTCBeam], cached_lm_scores: dict, cached_p_lm_scores: dict, processed_frames: int = 0) List[CTCBeam][source]

Perform CTC Beam Search decoding.

If self.lm is not None, the language model scores are computed and added to the CTC scores.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC input. Shape: (seq_length, vocab_size)

  • wav_len (int) – The length of the input sequence.

  • beams (list) – The list of CTCBeam objects.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_p_lm_scores (dict) – The cached prefix language model scores.

  • processed_frames (int) – The start frame of the current decoding step. (default: 0)

Returns:

beams – The list of CTCBeam objects.

Return type:

list

training: bool
class speechbrain.decoders.ctc.CTCPrefixBeamSearcher(**kwargs)[source]

Bases: CTCBaseSearcher

CTC Prefix Beam Search is based on the paper First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs by Awni Y. Hannun et al. (https://arxiv.org/abs/1408.2873).

The implementation keeps track of the blank and non-blank probabilities. It also supports n-gram scoring on words and SentencePiece tokens. The input is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].
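
To illustrate what tracking blank and non-blank probabilities means, here is a compact textbook-style prefix beam search without a language model. It is a sketch under simplifying assumptions, not the SpeechBrain implementation: it works on plain probabilities instead of log-probabilities, uses nested lists, and the function name is hypothetical.

```python
from collections import defaultdict


def prefix_beam_search(probs, blank=0, beam_size=4):
    """Return (prefix, probability) pairs, best first (illustrative)."""
    # Each prefix stores [P(ends in blank), P(ends in non-blank)].
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for tok, p in enumerate(frame):
                if tok == blank:
                    # A blank keeps the prefix unchanged.
                    nxt[prefix][0] += (p_b + p_nb) * p
                elif prefix and tok == prefix[-1]:
                    # Repeated token: it merges with the previous one...
                    nxt[prefix][1] += p_nb * p
                    # ...unless a blank separated them.
                    nxt[prefix + (tok,)][1] += p_b * p
                else:
                    nxt[prefix + (tok,)][1] += (p_b + p_nb) * p
        ranked = sorted(nxt.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = {prefix: tuple(ps) for prefix, ps in ranked[:beam_size]}
    return [(list(prefix), p_b + p_nb) for prefix, (p_b, p_nb) in beams.items()]


toy_probs = [[0.2, 0.0, 0.8],
             [0.4, 0.0, 0.6]]
best_prefix, best_prob = prefix_beam_search(toy_probs, blank=2)[0]
print(best_prefix, round(best_prob, 2))
# [0] 0.52
```

On this toy input the most likely single alignment is blank, blank (0.48), so greedy decoding returns an empty sequence. Summing over the three alignments that collapse to token 0 gives 0.52, which is why the prefix search prefers it.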

Several heuristics are implemented to speed up the decoding process:

  • pruning of the beam: the beams are pruned if their score is lower than the best beam score minus the beam_prune_logp

  • pruning of the tokens: the tokens are pruned if their score is lower than the token_prune_min_logp

  • pruning of the history: the beams are pruned if they are the same over max_ngram history

  • skipping of the blank: the frame is skipped if the blank probability is higher than the blank_skip_threshold
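
The blank-skipping heuristic can be sketched as a filter over frames: a frame is skipped when its blank log-probability exceeds log(blank_skip_threshold). The code below is an illustration with nested lists and a hypothetical function name; it also shows that the default threshold of 1.0 skips nothing, since log(1.0) = 0 and log-probabilities never exceed 0.

```python
import math


def frames_to_decode(log_probs, blank_index, blank_skip_threshold=1.0):
    """Return indices of frames that are not skipped (illustrative)."""
    cutoff = math.log(blank_skip_threshold)
    return [
        t for t, frame in enumerate(log_probs)
        if not frame[blank_index] > cutoff
    ]


# Toy two-symbol frames: index 0 is the blank.
blank_probs = [0.99, 0.5, 0.97, 0.1]
toy_log_probs = [[math.log(p), math.log(1.0 - p)] for p in blank_probs]

print(frames_to_decode(toy_log_probs, blank_index=0, blank_skip_threshold=0.95))
# [1, 3]
print(frames_to_decode(toy_log_probs, blank_index=0))
# [0, 1, 2, 3]
```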

Note: The CTCPrefixBeamSearcher can be more unstable than the CTCBeamSearcher or the TorchAudioCTCPrefixBeamSearcher. Please use it with caution and check the results carefully.

Note: if the Acoustic Model is not trained, the Beam Search will take a lot of time. We recommend using Greedy Search during validation until the model is fully trained and ready to be evaluated on test sets.

Note: This implementation does not provide the time alignment of the hypothesis. If you need it, please use the CTCBeamSearcher.

Parameters:

**kwargs – see CTCBaseSearcher, arguments are directly passed.

Example

>>> import torch
>>> from speechbrain.decoders import CTCPrefixBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
...                   [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = CTCPrefixBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
>>> hyps = searcher(log_probs, lens)
get_lm_beams(beams: List[CTCBeam], cached_lm_scores: dict, cached_partial_token_scores: dict, is_eos=False) List[LMCTCBeam][source]

Score the beams with the language model if not None, and return the new beams.

This function is modified and adapted from https://github.com/kensho-technologies/pyctcdecode

Parameters:
  • beams (list) – The list of the beams.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_partial_token_scores (dict) – The cached partial token scores.

  • is_eos (bool (default: False)) – Whether the end of the sequence has been reached.

Returns:

new_beams – The list of the new beams.

Return type:

list

partial_decoding(log_probs: Tensor, wav_len: int, beams: List[CTCBeam], cached_lm_scores: dict, cached_p_lm_scores: dict, processed_frames: int = 0) List[CTCBeam][source]

Perform CTC Prefix Beam Search decoding.

If self.lm is not None, the language model scores are computed and added to the CTC scores.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the CTC input. Shape: (seq_length, vocab_size)

  • wav_len (int) – The length of the input sequence.

  • beams (list) – The list of CTCBeam objects.

  • cached_lm_scores (dict) – The cached language model scores.

  • cached_p_lm_scores (dict) – The cached prefix language model scores.

  • processed_frames (int) – The start frame of the current decoding step. (default: 0)

Returns:

beams – The list of CTCBeam objects.

Return type:

list

training: bool
class speechbrain.decoders.ctc.TorchAudioCTCPrefixBeamSearcher(tokens: list | str, lexicon: str | None = None, lm: str | None = None, lm_dict: str | None = None, topk: int = 1, beam_size: int = 50, beam_size_token: int | None = None, beam_threshold: float = 50, lm_weight: float = 2, word_score: float = 0, unk_score: float = -inf, sil_score: float = 0, log_add: bool = False, blank_index: str | int = 0, sil_index: str | int = 0, unk_word: str = '<unk>', using_cpu_decoder: bool = True, blank_skip_threshold: float = 1.0)[source]

Bases: object

TorchAudio CTC Prefix Beam Search Decoder.

This class is a wrapper around the CTC decoder from TorchAudio. It provides a simple interface where you can either use the CPU or CUDA CTC decoder.

The CPU decoder is slower but uses less memory. The CUDA decoder is faster but uses more memory. The CUDA decoder is also only available in the nightly version of torchaudio.

A lot of features are missing in the CUDA decoder, such as the ability to use a language model, constrained search, and more. If you want to use those features, you have to use the CPU decoder.

For more information about the CPU decoder, please refer to the documentation of TorchAudio: https://pytorch.org/audio/main/generated/torchaudio.models.decoder.ctc_decoder.html

For more information about the CUDA decoder, please refer to the documentation of TorchAudio: https://pytorch.org/audio/main/generated/torchaudio.models.decoder.cuda_ctc_decoder.html#torchaudio.models.decoder.cuda_ctc_decoder

If you want to use the language model, or the lexicon search, please make sure that your tokenizer/acoustic model uses the same tokens as the language model/lexicon. Otherwise, the decoding will fail.

The implementation is compatible with SentencePiece tokens.

Note: When using the CUDA CTC decoder, the blank_index has to be 0. Furthermore, using the CUDA CTC decoder requires the nightly version of torchaudio and a lot of VRAM (if you want to use a lot of beams). Overall, we recommend using the CTCBeamSearcher or CTCPrefixBeamSearcher in SpeechBrain if you want to use n-gram + beam search decoding. If you want constrained search, please use the CPU version of torchaudio, and if you want to speed up the decoding as much as possible, please use the CUDA version.

Parameters:
  • tokens (list or str) – The list of tokens or the path to the tokens file. If this is a path, then the file should contain one token per line.

  • lexicon (str, default: None) – Lexicon file containing the possible words and corresponding spellings. Each line consists of a word and its space separated spelling. If None, uses lexicon-free decoding. (default: None)

  • lm (str, optional) – A path containing KenLM language model or None if not using a language model. (default: None)

  • lm_dict (str, optional) – File consisting of the dictionary used for the LM, with a word per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur in the lexicon file. If None, dictionary for LM is constructed using the lexicon file. (default: None)

  • topk (int, optional) – Number of top CTCHypothesis to return. (default: 1)

  • beam_size (int, optional) – Numbers of hypotheses to hold after each decode step. (default: 50)

  • beam_size_token (int, optional) – Max number of tokens to consider at each decode step. If None, it is set to the total number of tokens. (default: None)

  • beam_threshold (float, optional) – Threshold for pruning hypothesis. (default: 50)

  • lm_weight (float, optional) – Weight of language model. (default: 2)

  • word_score (float, optional) – Word insertion score. (default: 0)

  • unk_score (float, optional) – Unknown word insertion score. (default: float(“-inf”))

  • sil_score (float, optional) – Silence insertion score. (default: 0)

  • log_add (bool, optional) – Whether to use logadd when merging hypotheses. (default: False)

  • blank_index (int or str, optional) – Index of the blank token. If tokens is a file path, then this should be a str. Otherwise, this should be an int. (default: 0)

  • sil_index (int or str, optional) – Index of the silence token. If tokens is a file path, then this should be a str. Otherwise, this should be an int. (default: 0)

  • unk_word (str, optional) – Unknown word token. (default: “<unk>”)

  • using_cpu_decoder (bool, optional) – Whether to use the CPU searcher. If False, then the CUDA decoder is used. (default: True)

  • blank_skip_threshold (float, optional) – Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding (default: 1.0). Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk.
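
The file formats described for tokens and lexicon can be produced with the standard library. The sketch below writes a tokens file (one token per line) and a lexicon file (a word followed by its space-separated spelling) to a temporary directory. The token set, the words, and the file names are made-up examples, and any further conventions expected by the underlying decoder (such as word-boundary markers) are not covered here.

```python
import tempfile
from pathlib import Path

# Hypothetical vocabulary and lexicon, for illustration only.
toy_tokens = ["-", "a", "b", "c"]
toy_lexicon = {"cab": ["c", "a", "b"], "abc": ["a", "b", "c"]}

workdir = Path(tempfile.mkdtemp())

# Tokens file: one token per line.
tokens_path = workdir / "tokens.txt"
tokens_path.write_text("\n".join(toy_tokens) + "\n", encoding="utf-8")

# Lexicon file: each line holds a word and its space-separated spelling.
lexicon_path = workdir / "lexicon.txt"
lexicon_lines = [f"{word} {' '.join(spelling)}" for word, spelling in toy_lexicon.items()]
lexicon_path.write_text("\n".join(lexicon_lines) + "\n", encoding="utf-8")

print(lexicon_path.read_text(encoding="utf-8"))
```

According to the parameter descriptions above, when tokens is given as a file path, blank_index and sil_index are then expected to be strings rather than integers.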

Example

>>> import torch
>>> from speechbrain.decoders import TorchAudioCTCPrefixBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
...                   [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = TorchAudioCTCPrefixBeamSearcher(tokens=vocab_list, blank_index=blank_index, sil_index=blank_index)
>>> hyps = searcher(log_probs, lens)
decode_beams(log_probs: Tensor, wav_len: Tensor | None = None) List[List[CTCHypothesis]][source]

Decode log_probs using TorchAudio CTC decoder.

If using_cpu_decoder=True then log_probs and wav_len are moved to CPU before decoding. When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps in the returned hypotheses are set to None.

Make sure that the input is in the log domain. The decoder will fail to decode logits or probabilities. The input should be the log probabilities of the CTC output.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the input audio. Shape: (batch_size, seq_length, vocab_size)

  • wav_len (torch.Tensor, default: None) – The speechbrain-style relative length. Shape: (batch_size,) If None, then the length of each audio is assumed to be seq_length.

Returns:

The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.

Return type:

list of list of CTCHypothesis
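
Since the result is nested (batch first, then topk), picking one transcription per utterance takes a small loop. The sketch below uses a hypothetical stand-in dataclass, `Hyp`, that mirrors only the `text` and `score` fields of CTCHypothesis, so it runs without the library. It selects the highest-scoring entry explicitly rather than assuming the inner list is sorted.

```python
from dataclasses import dataclass


@dataclass
class Hyp:
    """Hypothetical stand-in for CTCHypothesis (text and score only)."""
    text: str
    score: float


def best_texts(batch_hyps):
    """Return the highest-scoring text of each utterance."""
    return [
        max(hyps, key=lambda h: h.score).text if hyps else ""
        for hyps in batch_hyps
    ]


# Outer list: batch of two utterances. Inner lists: topk=2 hypotheses.
toy_output = [
    [Hyp("a cab", -3.2), Hyp("a cap", -4.7)],
    [Hyp("abc", -1.1), Hyp("ab", -2.0)],
]
print(best_texts(toy_output))
# ['a cab', 'abc']
```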

__call__(log_probs: Tensor, wav_len: Tensor | None = None) List[List[CTCHypothesis]][source]

Decode log_probs using TorchAudio CTC decoder.

If using_cpu_decoder=True then log_probs and wav_len are moved to CPU before decoding. When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps in the returned hypotheses are set to None.

Parameters:
  • log_probs (torch.Tensor) – The log probabilities of the input audio. Shape: (batch_size, seq_length, vocab_size)

  • wav_len (torch.Tensor, default: None) – The speechbrain-style relative length. Shape: (batch_size,) If None, then the length of each audio is assumed to be seq_length.

Returns:

The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.

Return type:

list of list of CTCHypothesis