"""Decoders and output normalization for CTC.
Authors
* Mirco Ravanelli 2020
* Aku Rouhe 2020
* Sung-Lin Yeh 2020
* Adel Moumen 2023, 2024
"""
import dataclasses
import heapq
import math
import warnings
from itertools import groupby
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from speechbrain.dataio.dataio import length_to_mask
from speechbrain.utils.logger import get_logger
logger = get_logger(__name__)
class CTCPrefixScore:
"""This class implements the CTC prefix score of Algorithm 2 in
reference: https://www.merl.com/publications/docs/TR2017-190.pdf.
Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
Arguments
---------
x : torch.Tensor
The encoder states.
enc_lens : torch.Tensor
The actual length of each enc_states sequence.
blank_index : int
The index of the blank token.
eos_index : int
The index of the end-of-sequence (eos) token.
ctc_window_size : int
Compute the CTC scores over the time frames using windowing based on attention peaks.
If 0, no windowing is applied.
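Example
-------
A minimal scoring sketch with random encoder states; the shapes, vocabulary
size, and token indices below are illustrative only:
>>> import torch
>>> x = torch.randn(1, 5, 4).log_softmax(dim=-1)  # (batch, time, vocab)
>>> enc_lens = torch.tensor([5])
>>> scorer = CTCPrefixScore(x, enc_lens, blank_index=0, eos_index=3)
>>> # score every vocabulary continuation of a one-token prefix
>>> psi, states = scorer.forward_step(torch.tensor([2]), None)
>>> psi.shape
torch.Size([1, 4])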
"""
def __init__(self, x, enc_lens, blank_index, eos_index, ctc_window_size=0):
self.blank_index = blank_index
self.eos_index = eos_index
self.batch_size = x.size(0)
self.max_enc_len = x.size(1)
self.vocab_size = x.size(-1)
self.device = x.device
self.minus_inf = -1e20
self.last_frame_index = enc_lens - 1
self.ctc_window_size = ctc_window_size
self.prefix_length = -1
# mask frames > enc_lens
mask = 1 - length_to_mask(enc_lens)
mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1)
x.masked_fill_(mask, self.minus_inf)
x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0)
# dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors
xnb = x.transpose(0, 1)
xb = (
xnb[:, :, self.blank_index]
.unsqueeze(2)
.expand(-1, -1, self.vocab_size)
)
# (2, L, batch_size * beam_size, vocab_size)
self.x = torch.stack([xnb, xb])
# indices of batch.
self.batch_index = torch.arange(self.batch_size, device=self.device)
@torch.no_grad()
def forward_step(self, inp_tokens, states, candidates=None, attn=None):
"""This method performs one step of the forwarding operation
for the prefix CTC scorer.
Arguments
---------
inp_tokens : torch.Tensor
The last chars of prefix label sequences g, where h = g + c.
states : tuple
Previous ctc states.
candidates : torch.Tensor
(batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring.
If given, performing partial ctc scoring.
attn : torch.Tensor
(batch_size * beam_size, max_enc_len), The attention weights.
Returns
-------
new_psi : torch.Tensor
The updated prefix scores for the candidates, i.e. psi - psi_prev.
(r, psi, scoring_table) : tuple
The updated CTC states.
"""
n_bh = inp_tokens.size(0)
beam_size = n_bh // self.batch_size
last_char = inp_tokens
self.prefix_length += 1
self.num_candidates = (
self.vocab_size if candidates is None else candidates.size(-1)
)
if states is None:
# r_prev: (L, 2, batch_size * beam_size)
r_prev = torch.full(
(self.max_enc_len, 2, self.batch_size, beam_size),
self.minus_inf,
device=self.device,
)
# Accumulate blank posteriors at each step
r_prev[:, 1] = torch.cumsum(
self.x[0, :, :, self.blank_index], 0
).unsqueeze(2)
r_prev = r_prev.view(-1, 2, n_bh)
psi_prev = torch.full(
(n_bh, self.vocab_size), 0.0, device=self.device
)
else:
r_prev, psi_prev = states
# for partial search
if candidates is not None:
# The first index of each candidate.
cand_offset = self.batch_index * self.vocab_size
scoring_table = torch.full(
(n_bh, self.vocab_size),
-1,
dtype=torch.long,
device=self.device,
)
# Assign indices of candidates to their positions in the table
col_index = torch.arange(n_bh, device=self.device).unsqueeze(1)
scoring_table[col_index, candidates] = torch.arange(
self.num_candidates, device=self.device
)
# Select candidates indices for scoring
scoring_index = (
candidates
+ cand_offset.unsqueeze(1).repeat(1, beam_size).view(-1, 1)
).view(-1)
x_inflate = torch.index_select(
self.x.view(2, -1, self.batch_size * self.vocab_size),
2,
scoring_index,
).view(2, -1, n_bh, self.num_candidates)
# for full search
else:
scoring_table = None
# Inflate x to (2, -1, batch_size * beam_size, num_candidates)
# It is used to compute forward probs in a batched way
x_inflate = (
self.x.unsqueeze(3)
.repeat(1, 1, 1, beam_size, 1)
.view(2, -1, n_bh, self.num_candidates)
)
# Prepare forward probs
r = torch.full(
(self.max_enc_len, 2, n_bh, self.num_candidates),
self.minus_inf,
device=self.device,
)
# (Alg.2-6)
if self.prefix_length == 0:
r[0, 0] = x_inflate[0, 0]
# (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
r_sum = torch.logsumexp(r_prev, 1)
phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates)
# (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
if candidates is not None:
for i in range(n_bh):
pos = scoring_table[i, last_char[i]]
if pos != -1:
phi[:, i, pos] = r_prev[:, 1, i]
else:
for i in range(n_bh):
phi[:, i, last_char[i]] = r_prev[:, 1, i]
# Start, end frames for scoring (|g| < |h|).
# Scoring based on attn peak if ctc_window_size > 0
if self.ctc_window_size == 0 or attn is None:
start = max(1, self.prefix_length)
end = self.max_enc_len
else:
_, attn_peak = torch.max(attn, dim=1)
max_frame = torch.max(attn_peak).item() + self.ctc_window_size
min_frame = torch.min(attn_peak).item() - self.ctc_window_size
start = max(max(1, self.prefix_length), int(min_frame))
end = min(self.max_enc_len, int(max_frame))
# Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
for t in range(start, end):
# (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c)
rnb_prev = r[t - 1, 0]
# (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
rb_prev = r[t - 1, 1]
r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
2, 2, n_bh, self.num_candidates
)
r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t]
# Compute the prefix prob, psi
psi_init = r[start - 1, 0].unsqueeze(0)
# phi is prob at t-1 step, shift one frame and add it to the current prob p(c)
phix = torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0]
# (Alg.2-13): psi = psi + phi * p(c)
if candidates is not None:
psi = torch.full(
(n_bh, self.vocab_size), self.minus_inf, device=self.device
)
psi_ = torch.logsumexp(
torch.cat((phix[start:end], psi_init), dim=0), dim=0
)
# only assign prob to candidates
for i in range(n_bh):
psi[i, candidates[i]] = psi_[i]
else:
psi = torch.logsumexp(
torch.cat((phix[start:end], psi_init), dim=0), dim=0
)
# (Alg.2-3): if c = <eos>, psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames
for i in range(n_bh):
psi[i, self.eos_index] = r_sum[
self.last_frame_index[i // beam_size], i
]
if self.eos_index != self.blank_index:
# Exclude blank probs for joint scoring
psi[:, self.blank_index] = self.minus_inf
return psi - psi_prev, (r, psi, scoring_table)
def permute_mem(self, memory, index):
"""This method permutes the CTC model memory
to synchronize the memory index with the current output.
Arguments
---------
memory : tuple
The memory (r, psi, scoring_table) to be permuted.
index : torch.Tensor
The index of the previous path.
Returns
-------
tuple
The permuted memory variables (r, psi).
"""
r, psi, scoring_table = memory
beam_size = index.size(1)
n_bh = self.batch_size * beam_size
# The first index of each batch.
beam_offset = self.batch_index * beam_size
# The index of top-K vocab came from in (t-1) timesteps at batch * beam * vocab dimension.
cand_index = (
index + beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size
).view(n_bh)
# synchronize forward prob
psi = torch.index_select(psi.view(-1), dim=0, index=cand_index)
psi = (
psi.view(-1, 1)
.repeat(1, self.vocab_size)
.view(n_bh, self.vocab_size)
)
# The index of top-K vocab came from in (t-1) timesteps at batch * beam dimension.
hyp_index = (
torch.div(index, self.vocab_size, rounding_mode="floor")
+ beam_offset.unsqueeze(1).expand_as(index)
).view(n_bh)
# synchronize ctc states
if scoring_table is not None:
selected_vocab = (index % self.vocab_size).view(-1)
score_index = scoring_table[hyp_index, selected_vocab]
score_index[score_index == -1] = 0
cand_index = score_index + hyp_index * self.num_candidates
r = torch.index_select(
r.view(-1, 2, n_bh * self.num_candidates), dim=-1, index=cand_index
)
r = r.view(-1, 2, n_bh)
return r, psi
def filter_ctc_output(string_pred, blank_id=-1):
"""Apply CTC output merge and filter rules.
Removes the blank symbol and output repetitions.
Arguments
---------
string_pred : list
A list containing the output strings/ints predicted by the CTC system.
blank_id : int, string
The id of the blank.
Returns
-------
list
The output predicted by CTC without the blank symbol and
the repetitions.
Example
-------
>>> string_pred = ['a','a','blank','b','b','blank','c']
>>> string_out = filter_ctc_output(string_pred, blank_id='blank')
>>> print(string_out)
['a', 'b', 'c']
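
The same rules apply to integer token ids, e.g. with blank id 0
(the values below are illustrative):
>>> filter_ctc_output([0, 1, 1, 0, 2, 2, 3], blank_id=0)
[1, 2, 3]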
"""
if isinstance(string_pred, list):
# Filter the repetitions
string_out = [i[0] for i in groupby(string_pred)]
# Filter the blank symbol
string_out = list(filter(lambda elem: elem != blank_id, string_out))
else:
raise ValueError("filter_ctc_out can only filter python lists")
return string_out
def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
"""Greedy decode a batch of probabilities and apply CTC rules.
Arguments
---------
probabilities : torch.Tensor
Output probabilities (or log-probabilities) from the network with shape
[batch, lengths, probabilities].
seq_lens : torch.Tensor
Relative true sequence lengths (to deal with padded inputs);
the longest sequence has length 1.0, others a value between zero and one,
shape [batch].
blank_id : int, string
The blank symbol/index. Default: -1. If a negative number is given,
it is assumed to mean counting down from the maximum possible index,
so that -1 refers to the maximum possible index.
Returns
-------
list
Outputs as Python list of lists, with "ragged" dimensions; padding
has been removed.
Example
-------
>>> import torch
>>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]],
... [[0.2, 0.8], [0.9, 0.1]]])
>>> lens = torch.tensor([0.51, 1.0])
>>> blank_id = 0
>>> ctc_greedy_decode(probs, lens, blank_id)
[[1], [1]]
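
Repetitions are collapsed and blanks removed, while tokens repeated
across a blank are kept (the values below are illustrative):
>>> probs = torch.tensor([[[0.1, 0.9], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]]])
>>> lens = torch.tensor([1.0])
>>> ctc_greedy_decode(probs, lens, blank_id=0)
[[1, 1]]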
"""
if isinstance(blank_id, int) and blank_id < 0:
blank_id = probabilities.shape[-1] + blank_id
batch_max_len = probabilities.shape[1]
batch_outputs = []
for seq, seq_len in zip(probabilities, seq_lens):
actual_size = int(torch.round(seq_len * batch_max_len))
scores, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1)
out = filter_ctc_output(predictions.tolist(), blank_id=blank_id)
batch_outputs.append(out)
return batch_outputs
@dataclasses.dataclass
class CTCBeam:
"""This class handles the CTC beam information during decoding.
Arguments
---------
text : str
The current text of the beam.
full_text : str
The full text of the beam.
next_word : str
The next word to be added to the beam.
partial_word : str
The partial word being added to the beam.
last_token : str, optional
The last token of the beam.
last_token_index : int, optional
The index of the last token of the beam.
text_frames : List[Tuple[int, int]]
The start and end frame of the text.
partial_frames : Tuple[int, int]
The start and end frame of the partial word.
p : float
The probability of the beam.
p_b : float
The probability of the beam ending in a blank.
p_nb : float
The probability of the beam not ending in a blank.
n_p_b : float
The previous probability of the beam ending in a blank.
n_p_nb : float
The previous probability of the beam not ending in a blank.
score : float
The score of the beam (LM + CTC)
score_ctc : float
The CTC score computed.
Example
-------
>>> beam = CTCBeam(
... text="",
... full_text="",
... next_word="",
... partial_word="",
... last_token=None,
... last_token_index=None,
... text_frames=[(0, 0)],
... partial_frames=(0, 0),
... p=-math.inf,
... p_b=-math.inf,
... p_nb=-math.inf,
... n_p_b=-math.inf,
... n_p_nb=-math.inf,
... score=-math.inf,
... score_ctc=-math.inf,
... )
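
Calling step() folds the accumulated blank/non-blank probabilities into
the beam score; the log-probabilities below are illustrative:
>>> beam.n_p_b, beam.n_p_nb = math.log(0.3), math.log(0.2)
>>> beam.step()
>>> round(math.exp(beam.score_ctc), 2)
0.5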
"""
text: str
full_text: str
next_word: str
partial_word: str
last_token: Optional[str]
last_token_index: Optional[int]
text_frames: List[Tuple[int, int]]
partial_frames: Tuple[int, int]
p: float = -math.inf
p_b: float = -math.inf
p_nb: float = -math.inf
n_p_b: float = -math.inf
n_p_nb: float = -math.inf
score: float = -math.inf
score_ctc: float = -math.inf
@classmethod
def from_lm_beam(cls, lm_beam: "LMCTCBeam") -> "CTCBeam":
"""Create a CTCBeam from an LMCTCBeam.
Arguments
---------
lm_beam : LMCTCBeam
The LMCTCBeam to convert.
Returns
-------
CTCBeam
The CTCBeam converted.
"""
return CTCBeam(
text=lm_beam.text,
full_text=lm_beam.full_text,
next_word=lm_beam.next_word,
partial_word=lm_beam.partial_word,
last_token=lm_beam.last_token,
last_token_index=lm_beam.last_token_index,
text_frames=lm_beam.text_frames,
partial_frames=lm_beam.partial_frames,
p=lm_beam.p,
p_b=lm_beam.p_b,
p_nb=lm_beam.p_nb,
n_p_b=lm_beam.n_p_b,
n_p_nb=lm_beam.n_p_nb,
score=lm_beam.score,
score_ctc=lm_beam.score_ctc,
)
def step(self) -> None:
"""Update the beam probabilities."""
self.p_b, self.p_nb = self.n_p_b, self.n_p_nb
self.n_p_b = self.n_p_nb = -math.inf
self.score_ctc = np.logaddexp(self.p_b, self.p_nb)
self.score = self.score_ctc
@dataclasses.dataclass
class LMCTCBeam(CTCBeam):
"""This class handles the LM scores during decoding.
Arguments
---------
lm_score : float
The LM score of the beam.
**kwargs
See CTCBeam for the other arguments.
"""
lm_score: float = -math.inf
@dataclasses.dataclass
class CTCHypothesis:
"""This class is a data handler over the generated hypotheses.
It is the default output of the CTC beam searchers and can be reused
by other decoders when the beam searchers are run in an online fashion.
Arguments
---------
text : str
The text of the hypothesis.
last_lm_state : None
The last LM state of the hypothesis.
score : float
The score of the hypothesis.
lm_score : float
The LM score of the hypothesis.
text_frames : List[Tuple[str, Tuple[int, int]]], optional
The list of the text and the corresponding frames.
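Example
-------
A minimal construction sketch; the field values are illustrative:
>>> hyp = CTCHypothesis(
...     text="hello world",
...     last_lm_state=None,
...     score=-1.0,
...     lm_score=-1.0,
... )
>>> hyp.text
'hello world'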
"""
text: str
last_lm_state: None
score: float
lm_score: float
text_frames: Optional[list] = None
class CTCBaseSearcher(torch.nn.Module):
"""CTCBaseSearcher class to be inherited by other
CTC beam searchers.
This class provides the basic functionalities for
CTC beam search decoding.
The space_token is required with a non-SentencePiece vocabulary list
if your transcriptions are expected to contain spaces.
Arguments
---------
blank_index : int
The index of the blank token.
vocab_list : list
The list of the vocabulary tokens.
space_token : str, optional
The token that represents a space / word boundary. (default: " ")
kenlm_model_path : str, optional
The path to the kenlm model. Use .bin for a faster loading.
If None, no language model will be used. (default: None)
unigrams : list, optional
The list of known word unigrams. (default: None)
alpha : float
Weight for language model during shallow fusion. (default: 0.5)
beta : float
Weight for length score adjustment during scoring. (default: 1.5)
unk_score_offset : float
Amount of log score offset for unknown tokens. (default: -10.0)
score_boundary : bool
Whether to have kenlm respect boundaries when scoring. (default: True)
beam_size : int, optional
The width of the beam. (default: 100)
beam_prune_logp : float, optional
The pruning threshold for the beam. (default: -10.0)
token_prune_min_logp : float, optional
The pruning threshold for the tokens. (default: -5.0)
prune_history : bool, optional
Whether to prune the history. (default: True)
Note: when using topk > 1, this should be set to False as
it is pruning a lot of beams.
blank_skip_threshold : float, optional
Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk. (default: 1.0)
topk : int, optional
The number of top hypotheses to return. (default: 1)
spm_token: str, optional
The sentencepiece token. (default: "▁")
Example
-------
>>> blank_index = 0
>>> vocab_list = ['blank', 'a', 'b', 'c', ' ']
>>> space_token = ' '
>>> kenlm_model_path = None
>>> unigrams = None
>>> beam_size = 100
>>> beam_prune_logp = -10.0
>>> token_prune_min_logp = -5.0
>>> prune_history = True
>>> blank_skip_threshold = 1.0
>>> topk = 1
>>> searcher = CTCBaseSearcher(
... blank_index=blank_index,
... vocab_list=vocab_list,
... space_token=space_token,
... kenlm_model_path=kenlm_model_path,
... unigrams=unigrams,
... beam_size=beam_size,
... beam_prune_logp=beam_prune_logp,
... token_prune_min_logp=token_prune_min_logp,
... prune_history=prune_history,
... blank_skip_threshold=blank_skip_threshold,
... topk=topk,
... )
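
With a SentencePiece vocabulary, no space_token is needed: the searcher
detects the `▁` marker automatically. Shallow fusion with a KenLM model
only needs a model path (the vocabulary and path below are illustrative):
>>> spm_vocab = ['<blank>', '▁the', '▁a', 'n', 'd']
>>> spm_searcher = CTCBaseSearcher(blank_index=0, vocab_list=spm_vocab)
>>> spm_searcher.is_spm
True
>>> lm_searcher = CTCBaseSearcher(
...     blank_index=blank_index,
...     vocab_list=vocab_list,
...     kenlm_model_path="path/to/lm.bin",
...     alpha=0.5,
...     beta=1.5,
... ) # doctest: +SKIP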
"""
def __init__(
self,
blank_index: int,
vocab_list: List[str],
space_token: str = " ",
kenlm_model_path: Union[None, str] = None,
unigrams: Union[None, List[str]] = None,
alpha: float = 0.5,
beta: float = 1.5,
unk_score_offset: float = -10.0,
score_boundary: bool = True,
beam_size: int = 100,
beam_prune_logp: float = -10.0,
token_prune_min_logp: float = -5.0,
prune_history: bool = True,
blank_skip_threshold: float = 1.0,
topk: int = 1,
spm_token: str = "▁",
):
super().__init__()
self.blank_index = blank_index
self.vocab_list = vocab_list
self.space_token = space_token
self.kenlm_model_path = kenlm_model_path
self.unigrams = unigrams
self.alpha = alpha
self.beta = beta
self.unk_score_offset = unk_score_offset
self.score_boundary = score_boundary
self.beam_size = beam_size
self.beam_prune_logp = beam_prune_logp
self.token_prune_min_logp = token_prune_min_logp
self.prune_history = prune_history
self.blank_skip_threshold = math.log(blank_skip_threshold)
self.topk = topk
self.spm_token = spm_token
# check if the vocab is coming from SentencePiece
self.is_spm = any(
[str(s).startswith(self.spm_token) for s in vocab_list]
)
# fetch the index of space_token
if not self.is_spm:
try:
self.space_index = vocab_list.index(space_token)
logger.info(f"Found `space_token` at index {self.space_index}.")
except ValueError:
logger.warning(
f"space_token `{space_token}` not found in the vocabulary. "
"Using value -1 as `space_index`. "
"Note: If your transcription is not expected to contain spaces, "
"you can ignore this warning."
)
self.space_index = -1
self.kenlm_model = None
if kenlm_model_path is not None:
try:
import kenlm # type: ignore
from speechbrain.decoders.language_model import (
LanguageModel,
load_unigram_set_from_arpa,
)
except ImportError:
raise ImportError(
"kenlm python bindings are not installed. To install it use: "
"pip install https://github.com/kpu/kenlm/archive/master.zip"
)
self.kenlm_model = kenlm.Model(kenlm_model_path)
if kenlm_model_path is not None and kenlm_model_path.endswith(".arpa"):
logger.info(
"Using arpa instead of binary LM file, decoder instantiation might be slow."
)
if unigrams is None and kenlm_model_path is not None:
if kenlm_model_path.endswith(".arpa"):
unigrams = load_unigram_set_from_arpa(kenlm_model_path)
else:
logger.warning(
"Unigrams not provided and cannot be automatically determined from LM file (only "
"arpa format). Decoding accuracy might be reduced."
)
if self.kenlm_model is not None:
self.lm = LanguageModel(
kenlm_model=self.kenlm_model,
unigrams=unigrams,
alpha=self.alpha,
beta=self.beta,
unk_score_offset=self.unk_score_offset,
score_boundary=self.score_boundary,
)
else:
self.lm = None
def partial_decoding(
self,
log_probs: torch.Tensor,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_p_lm_scores: dict,
processed_frames: int = 0,
):
"""Perform a single step of decoding.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
beams : list
The list of the beams.
cached_lm_scores : dict
The cached language model scores.
cached_p_lm_scores : dict
The cached prefix language model scores.
processed_frames : int, default: 0
The start frame of the current decoding step.
"""
raise NotImplementedError
def normalize_whitespace(self, text: str) -> str:
"""Efficiently normalize whitespace.
Arguments
---------
text : str
The text to normalize.
Returns
-------
str
The normalized text.
"""
return " ".join(text.split())
def merge_tokens(self, token_1: str, token_2: str) -> str:
"""Merge two tokens, and avoid empty ones.
Taken from: https://github.com/kensho-technologies/pyctcdecode
Arguments
---------
token_1 : str
The first token.
token_2 : str
The second token.
Returns
-------
str
The merged token.
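Example
-------
A small sketch; the vocabulary used to build the searcher is illustrative:
>>> searcher = CTCBaseSearcher(blank_index=0, vocab_list=['-', 'a', 'b', ' '])
>>> searcher.merge_tokens("hello", "world")
'hello world'
>>> searcher.merge_tokens("hello", "")
'hello'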
"""
if len(token_2) == 0:
text = token_1
elif len(token_1) == 0:
text = token_2
else:
text = token_1 + " " + token_2
return text
def merge_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
"""Merge beams with the same text.
Taken from: https://github.com/kensho-technologies/pyctcdecode
Arguments
---------
beams : list
The list of the beams.
Returns
-------
list
The list of CTCBeam merged.
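Example
-------
Two beams carrying the same text are merged and their scores are combined
with logaddexp; the beams and vocabulary below are illustrative:
>>> searcher = CTCBaseSearcher(blank_index=0, vocab_list=['-', 'a', 'b', ' '])
>>> beams = [
...     CTCBeam(
...         text="a", full_text="", next_word="", partial_word="",
...         last_token="a", last_token_index=1, text_frames=[],
...         partial_frames=(-1, -1), score=s,
...     )
...     for s in (math.log(0.2), math.log(0.3))
... ]
>>> merged = searcher.merge_beams(beams)
>>> len(merged)
1
>>> round(math.exp(merged[0].score), 2)
0.5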
"""
beam_dict = {}
for beam in beams:
new_text = self.merge_tokens(beam.text, beam.next_word)
hash_idx = (new_text, beam.partial_word, beam.last_token)
if hash_idx not in beam_dict:
beam_dict[hash_idx] = beam
else:
# We've already seen this text - we want to combine the scores
beam_dict[hash_idx] = dataclasses.replace(
beam,
score=np.logaddexp(beam_dict[hash_idx].score, beam.score),
)
return list(beam_dict.values())
def sort_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]:
"""Sort beams by lm_score.
Arguments
---------
beams : list
The list of CTCBeam.
Returns
-------
list
The list of CTCBeam sorted.
"""
return heapq.nlargest(self.beam_size, beams, key=lambda x: x.lm_score)
def _prune_history(
self, beams: List[CTCBeam], lm_order: int
) -> List[CTCBeam]:
"""Filter out beams that are the same over max_ngram history.
Since n-gram language models have a finite history when scoring a new token, we can use that
fact to prune beams that only differ early on (more than n tokens in the past) and keep only the
higher scoring ones. Note that this helps speed up the decoding process but comes at the cost of
some amount of beam diversity. If more than the top beam is used in the output it should
potentially be disabled.
Taken from: https://github.com/kensho-technologies/pyctcdecode
Arguments
---------
beams : list
The list of the beams.
lm_order : int
The order of the language model.
Returns
-------
list
The list of CTCBeam.
"""
# let's keep at least 1 word of history
min_n_history = max(1, lm_order - 1)
seen_hashes = set()
filtered_beams = []
# for each beam after this, check if we need to add it
for lm_beam in beams:
# hash based on history that can still affect lm scoring going forward
hash_idx = (
tuple(lm_beam.text.split()[-min_n_history:]),
lm_beam.partial_word,
lm_beam.last_token,
)
if hash_idx not in seen_hashes:
filtered_beams.append(CTCBeam.from_lm_beam(lm_beam))
seen_hashes.add(hash_idx)
return filtered_beams
def finalize_decoding(
self,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_p_lm_scores: dict,
force_next_word=False,
is_end=False,
) -> List[CTCBeam]:
"""Finalize the decoding process by adding and scoring the last partial word.
Arguments
---------
beams : list
The list of CTCBeam.
cached_lm_scores : dict
The cached language model scores.
cached_p_lm_scores : dict
The cached prefix language model scores.
force_next_word : bool, default: False
Whether to force the next word.
is_end : bool, default: False
Whether the end of the sequence has been reached.
Returns
-------
list
The list of the CTCBeam.
"""
if force_next_word or is_end:
new_beams = []
for beam in beams:
new_token_times = (
beam.text_frames
if beam.partial_word == ""
else beam.text_frames + [beam.partial_frames]
)
new_beams.append(
CTCBeam(
text=beam.text,
full_text=beam.full_text,
next_word=beam.partial_word,
partial_word="",
last_token=None,
last_token_index=None,
text_frames=new_token_times,
partial_frames=(-1, -1),
score=beam.score,
)
)
new_beams = self.merge_beams(new_beams)
else:
new_beams = list(beams)
scored_beams = self.get_lm_beams(
new_beams, cached_lm_scores, cached_p_lm_scores
)
# remove beam outliers
max_score = max([b.lm_score for b in scored_beams])
scored_beams = [
b
for b in scored_beams
if b.lm_score >= max_score + self.beam_prune_logp
]
sorted_beams = self.sort_beams(scored_beams)
return sorted_beams
def decode_beams(
self,
log_probs: torch.Tensor,
wav_lens: Optional[torch.Tensor] = None,
lm_start_state: Any = None,
) -> List[List[CTCHypothesis]]:
"""Decodes the input log probabilities of the CTC output.
It automatically converts the SpeechBrain relative length of the wav input
to the absolute length.
Make sure that the input is in the log domain. The decoder will fail to decode
logits or probabilities. The input should be the log probabilities of the CTC output.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
The expected shape is [batch_size, seq_length, vocab_size].
wav_lens : torch.Tensor, optional (default: None)
The SpeechBrain's relative length of the wav input.
lm_start_state : Any, optional (default: None)
The start state of the language model.
Returns
-------
list of list
The list of topk list of CTCHypothesis.
"""
# check that the last dimension of log_probs is equal to the vocab size
if log_probs.size(2) != len(self.vocab_list):
warnings.warn(
f"Vocab size mismatch: log_probs vocab dim is {log_probs.size(2)} "
f"while vocab_list is {len(self.vocab_list)}. "
"During decoding, the log_probs vocab dim will be truncated to match vocab_list."
)
# compute wav_lens and cast to numpy as it is faster
if wav_lens is not None:
wav_lens = log_probs.size(1) * wav_lens
wav_lens = wav_lens.cpu().numpy().astype(int)
else:
wav_lens = [log_probs.size(1)] * log_probs.size(0)
log_probs = log_probs.cpu().numpy()
hyps = [
self.decode_log_probs(log_prob, wav_len, lm_start_state)
for log_prob, wav_len in zip(log_probs, wav_lens)
]
return hyps
def __call__(
self,
log_probs: torch.Tensor,
wav_lens: Optional[torch.Tensor] = None,
lm_start_state: Any = None,
) -> List[List[CTCHypothesis]]:
"""Decodes the log probabilities of the CTC output.
It automatically converts the SpeechBrain relative length of the wav input
to the absolute length.
Each tensor is moved to CPU and converted to numpy, as this is faster and consumes less memory.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
The expected shape is [batch_size, seq_length, vocab_size].
wav_lens : torch.Tensor, optional (default: None)
The SpeechBrain's relative length of the wav input.
lm_start_state : Any, optional (default: None)
The start state of the language model.
Returns
-------
list of list
The list of topk list of CTCHypothesis.
"""
return self.decode_beams(log_probs, wav_lens, lm_start_state)
def partial_decode_beams(
self,
log_probs: torch.Tensor,
cached_lm_scores: dict,
cached_p_lm_scores: dict,
beams: List[CTCBeam],
processed_frames: int,
force_next_word=False,
is_end=False,
) -> List[CTCBeam]:
"""Perform a single step of decoding.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
cached_lm_scores : dict
The cached language model scores.
cached_p_lm_scores : dict
The cached prefix language model scores.
beams : list
The list of the beams.
processed_frames : int
The start frame of the current decoding step.
force_next_word : bool, optional (default: False)
Whether to force the next word.
is_end : bool, optional (default: False)
Whether the end of the sequence has been reached.
Returns
-------
list
The list of CTCBeam.
"""
beams = self.partial_decoding(
log_probs,
beams,
cached_lm_scores,
cached_p_lm_scores,
processed_frames=processed_frames,
)
trimmed_beams = self.finalize_decoding(
beams,
cached_lm_scores,
cached_p_lm_scores,
force_next_word=force_next_word,
is_end=is_end,
)
return trimmed_beams
def decode_log_probs(
self,
log_probs: torch.Tensor,
wav_len: int,
lm_start_state: Optional[Any] = None,
) -> List[CTCHypothesis]:
"""Decodes the log probabilities of the CTC output.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC output.
The expected shape is [seq_length, vocab_size].
wav_len : int
The length of the wav input.
lm_start_state : Any, optional (default: None)
The start state of the language model.
Returns
-------
list
The topk list of CTCHypothesis.
"""
# prepare caching/state for language model
language_model = self.lm
if language_model is None:
cached_lm_scores = {}
else:
if lm_start_state is None:
start_state = language_model.get_start_state()
else:
start_state = lm_start_state
cached_lm_scores = {("", False): (0.0, start_state)}
cached_p_lm_scores: Dict[str, float] = {}
beams = [
CTCBeam(
text="",
full_text="",
next_word="",
partial_word="",
last_token=None,
last_token_index=None,
text_frames=[],
partial_frames=(-1, -1),
score=0.0,
score_ctc=0.0,
p_b=0.0,
)
]
# loop over the frames and perform the decoding
beams = self.partial_decoding(
log_probs, wav_len, beams, cached_lm_scores, cached_p_lm_scores
)
# finalize decoding by adding and scoring the last partial word
trimmed_beams = self.finalize_decoding(
beams,
cached_lm_scores,
cached_p_lm_scores,
force_next_word=True,
is_end=True,
)
# transform the beams into hypotheses and select the topk
output_beams = [
CTCHypothesis(
text=self.normalize_whitespace(lm_beam.text),
last_lm_state=(
cached_lm_scores[(lm_beam.text, True)][-1]
if (lm_beam.text, True) in cached_lm_scores
else None
),
text_frames=list(
zip(lm_beam.text.split(), lm_beam.text_frames)
),
score=lm_beam.score,
lm_score=lm_beam.lm_score,
)
for lm_beam in trimmed_beams
][: self.topk]
return output_beams
class CTCBeamSearcher(CTCBaseSearcher):
"""CTC Beam Search is a Beam Search for CTC which does not keep track of
the blank and non-blank probabilities. Each new token probability is
added to the overall score, and beams that share the same text are
merged together.
The implementation supports n-gram scoring on words and SentencePiece tokens. The input
is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].
The main advantage of this CTCBeamSearcher over the CTCPrefixBeamSearcher is that it is
relatively faster and obtains slightly better results. However, the implementation is
based on the one from the PyCTCDecode toolkit, adapted to SpeechBrain's needs, and does
not follow a specific paper. We recommend using the CTCPrefixBeamSearcher if you want
to cite the appropriate paper for the decoding method.
Several heuristics are implemented to speed up the decoding process:
- pruning of the beam : the beams are pruned if their score is lower than
the best beam score minus the beam_prune_logp
- pruning of the tokens : the tokens are pruned if their score is lower than
the token_prune_min_logp
- pruning of the history : the beams are pruned if they are the same over
max_ngram history
- skipping of the blank : the frame is skipped if the blank probability is
higher than the blank_skip_threshold
Note: if the Acoustic Model is not trained, the Beam Search will
take a lot of time. We recommend using Greedy Search during validation
until the model is fully trained and ready to be evaluated on test sets.
Arguments
---------
See CTCBaseSearcher; arguments are passed directly.
Example
-------
>>> import torch
>>> from speechbrain.decoders import CTCBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
... [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = CTCBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
>>> hyps = searcher(log_probs, lens)
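
The result is a list over the batch, each entry holding the topk
CTCHypothesis. When keeping more than one hypothesis, history pruning
should be disabled; the topk value below is illustrative:
>>> isinstance(hyps[0][0].text, str)
True
>>> searcher_topk = CTCBeamSearcher(
...     blank_index=blank_index,
...     vocab_list=vocab_list,
...     topk=2,
...     prune_history=False,
... )
>>> hyps_topk = searcher_topk(log_probs, lens)
>>> len(hyps_topk[0]) <= 2
True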
"""
def get_lm_beams(
self,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_partial_token_scores: dict,
is_eos=False,
) -> List[LMCTCBeam]:
"""Score the beams with the language model if not None, and
return the new beams.
This function is modified and adapted from
https://github.com/kensho-technologies/pyctcdecode
Arguments
---------
beams : list
The list of the beams.
cached_lm_scores : dict
The cached language model scores.
cached_partial_token_scores : dict
The cached partial token scores.
is_eos : bool (default: False)
Whether the end of the sequence has been reached.
Returns
-------
new_beams : list
The list of the new beams.
"""
if self.lm is None:
# no lm is used, lm_score is equal to score and we can return the beams
new_beams = []
for beam in beams:
new_text = self.merge_tokens(beam.text, beam.next_word)
new_beams.append(
LMCTCBeam(
text=new_text,
full_text=beam.full_text,
next_word="",
partial_word=beam.partial_word,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
score=beam.score,
lm_score=beam.score,
)
)
return new_beams
else:
# lm is used, we need to compute the lm_score
# first we compute the lm_score of the next word
# we check if the next word is in the cache
# if not, we compute the score and add it to the cache
new_beams = []
for beam in beams:
# fast token merge
new_text = self.merge_tokens(beam.text, beam.next_word)
cache_key = (new_text, is_eos)
if cache_key not in cached_lm_scores:
prev_raw_lm_score, start_state = cached_lm_scores[
(beam.text, False)
]
score, end_state = self.lm.score(
start_state, beam.next_word, is_last_word=is_eos
)
raw_lm_score = prev_raw_lm_score + score
cached_lm_scores[cache_key] = (raw_lm_score, end_state)
lm_score, _ = cached_lm_scores[cache_key]
# we score the partial word
word_part = beam.partial_word
if len(word_part) > 0:
if word_part not in cached_partial_token_scores:
cached_partial_token_scores[word_part] = (
self.lm.score_partial_token(word_part)
)
lm_score += cached_partial_token_scores[word_part]
new_beams.append(
LMCTCBeam(
text=new_text,
full_text=beam.full_text,
next_word="",
partial_word=word_part,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
score=beam.score,
lm_score=beam.score + lm_score,
)
)
return new_beams
def partial_decoding(
self,
log_probs: torch.Tensor,
wav_len: int,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_p_lm_scores: dict,
processed_frames: int = 0,
) -> List[CTCBeam]:
"""Perform CTC Beam Search decoding.
If self.lm is not None, the language model scores are computed and added to the CTC scores.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC input.
Shape: (seq_length, vocab_size)
wav_len : int
The length of the input sequence.
beams : list
The list of CTCBeam objects.
cached_lm_scores : dict
The cached language model scores.
cached_p_lm_scores : dict
The cached prefix language model scores.
processed_frames : int
The start frame of the current decoding step. (default: 0)
Returns
-------
beams : list
The list of CTCBeam objects.
"""
# select only the valid frames i.e. the frames that are not padded
log_probs = log_probs[:wav_len]
for frame_index, logit_col in enumerate(
log_probs, start=processed_frames
):
# skip the frame if the blank probability is higher than the threshold
if logit_col[self.blank_index] > self.blank_skip_threshold:
continue
# get the tokens with the highest probability
max_index = logit_col.argmax()
tokens_index_list = set(
np.where(logit_col > self.token_prune_min_logp)[0]
) | {max_index}
new_beams = []
# select tokens that are in the vocab
# this is useful if the logit vocab_size is larger than the vocab_list
tokens_index_list = tokens_index_list & set(
range(len(self.vocab_list))
)
for token_index in tokens_index_list:
p_token = logit_col[token_index]
token = self.vocab_list[token_index]
for beam in beams:
if (
token_index == self.blank_index
or beam.last_token == token
):
if token_index == self.blank_index:
new_end_frame = beam.partial_frames[0]
else:
new_end_frame = frame_index + 1
new_part_frames = (
beam.partial_frames
if token_index == self.blank_index
else (beam.partial_frames[0], new_end_frame)
)
# if blank or repeated token, we only change the score
new_beams.append(
CTCBeam(
text=beam.text,
full_text=beam.full_text,
next_word=beam.next_word,
partial_word=beam.partial_word,
last_token=token,
last_token_index=token_index,
text_frames=beam.text_frames,
partial_frames=new_part_frames,
score=beam.score + p_token,
)
)
elif self.is_spm and token[:1] == self.spm_token:
# remove the spm token at the beginning of the token
clean_token = token[1:]
new_frame_list = (
beam.text_frames
if beam.partial_word == ""
else beam.text_frames + [beam.partial_frames]
)
# If the beginning of the token is the spm_token
# then it means that we are extending the beam with a new word.
# We need to change the new_word with the partial_word
# and reset the partial_word with the new token
new_beams.append(
CTCBeam(
text=beam.text,
full_text=beam.full_text,
next_word=beam.partial_word,
partial_word=clean_token,
last_token=token,
last_token_index=token_index,
text_frames=new_frame_list,
partial_frames=(frame_index, frame_index + 1),
score=beam.score + p_token,
)
)
elif not self.is_spm and token_index == self.space_index:
new_frame_list = (
beam.text_frames
if beam.partial_word == ""
else beam.text_frames + [beam.partial_frames]
)
# same as before but in the case of a non spm vocab
new_beams.append(
CTCBeam(
text=beam.text,
full_text=beam.full_text,
next_word=beam.partial_word,
partial_word="",
last_token=token,
last_token_index=token_index,
text_frames=new_frame_list,
partial_frames=(-1, -1),
score=beam.score + p_token,
)
)
else:
new_part_frames = (
(frame_index, frame_index + 1)
if beam.partial_frames[0] < 0
else (beam.partial_frames[0], frame_index + 1)
)
# last case, we are extending the partial_word with a new token
new_beams.append(
CTCBeam(
text=beam.text,
full_text=beam.full_text,
next_word=beam.next_word,
partial_word=beam.partial_word + token,
last_token=token,
last_token_index=token_index,
text_frames=beam.text_frames,
partial_frames=new_part_frames,
score=beam.score + p_token,
)
)
# we merge the beams with the same text
new_beams = self.merge_beams(new_beams)
# kenlm scoring
scored_beams = self.get_lm_beams(
new_beams, cached_lm_scores, cached_p_lm_scores
)
# remove beam outliers
max_score = max([b.lm_score for b in scored_beams])
scored_beams = [
b
for b in scored_beams
if b.lm_score >= max_score + self.beam_prune_logp
]
trimmed_beams = self.sort_beams(scored_beams)
if self.prune_history:
lm_order = 1 if self.lm is None else self.lm.order
beams = self._prune_history(trimmed_beams, lm_order=lm_order)
else:
beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams]
return beams
class CTCPrefixBeamSearcher(CTCBaseSearcher):
"""CTC Prefix Beam Search is based on the paper
`First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs`
by Awni Y. Hannun et al. (https://arxiv.org/abs/1408.2873).
The implementation keeps track of the blank and non-blank probabilities.
It also supports n-gram scoring on words and SentencePiece tokens. The input
is expected to be a log-probabilities tensor of shape [batch, time, vocab_size].
Several heuristics are implemented to speed up the decoding process:
- pruning of the beam : the beams are pruned if their score is lower than
the best beam score minus the beam_prune_logp
- pruning of the tokens : the tokens are pruned if their score is lower than
the token_prune_min_logp
- pruning of the history : the beams are pruned if they are the same over
max_ngram history
- skipping of the blank : the frame is skipped if the blank probability is
higher than the blank_skip_threshold
Note: The CTCPrefixBeamSearcher can be more unstable than the CTCBeamSearcher
or the TorchAudioCTCPrefixBeamSearcher. Please use it with caution
and check the results carefully.
Note: if the Acoustic Model is not trained, the Beam Search will
take a lot of time. We recommend using Greedy Search during validation
until the model is fully trained and ready to be evaluated on test sets.
Note: This implementation does not provide the time alignment of the
hypothesis. If you need it, please use the CTCBeamSearcher.
Arguments
---------
See CTCBaseSearcher; arguments are passed directly.
Example
-------
>>> import torch
>>> from speechbrain.decoders import CTCPrefixBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
... [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = CTCPrefixBeamSearcher(blank_index=blank_index, vocab_list=vocab_list)
>>> hyps = searcher(log_probs, lens)
"""
def get_lm_beams(
self,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_partial_token_scores: dict,
is_eos=False,
) -> List[LMCTCBeam]:
"""Score the beams with the language model if not None, and
return the new beams.
This function is modified and adapted from
https://github.com/kensho-technologies/pyctcdecode
Arguments
---------
beams : list
The list of the beams.
cached_lm_scores : dict
The cached language model scores.
cached_partial_token_scores : dict
The cached partial token scores.
is_eos : bool (default: False)
Whether the end of the sequence has been reached.
Returns
-------
new_beams : list
The list of the new beams.
"""
if self.lm is None:
# no lm is used, lm_score is equal to score and we can return the beams
# we have to keep track of the probabilities as well
new_beams = []
for beam in beams:
new_text = self.merge_tokens(beam.full_text, beam.next_word)
new_beams.append(
LMCTCBeam(
text=beam.text,
full_text=new_text,
next_word="",
partial_word=beam.partial_word,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
p=beam.p,
p_b=beam.p_b,
p_nb=beam.p_nb,
n_p_b=beam.n_p_b,
n_p_nb=beam.n_p_nb,
score=beam.score,
score_ctc=beam.score_ctc,
lm_score=beam.score,
)
)
return new_beams
else:
# lm is used, we need to compute the lm_score
# first we compute the lm_score of the next word
# we check if the next word is in the cache
# if not, we compute the score and add it to the cache
new_beams = []
for beam in beams:
# fast token merge
new_text = self.merge_tokens(beam.full_text, beam.next_word)
cache_key = (new_text, is_eos)
if cache_key not in cached_lm_scores:
prev_raw_lm_score, start_state = cached_lm_scores[
(beam.full_text, False)
]
score, end_state = self.lm.score(
start_state, beam.next_word, is_last_word=is_eos
)
raw_lm_score = prev_raw_lm_score + score
cached_lm_scores[cache_key] = (raw_lm_score, end_state)
lm_score, _ = cached_lm_scores[cache_key]
word_part = beam.partial_word
# we score the partial word
if len(word_part) > 0:
if word_part not in cached_partial_token_scores:
cached_partial_token_scores[word_part] = (
self.lm.score_partial_token(word_part)
)
lm_score += cached_partial_token_scores[word_part]
new_beams.append(
LMCTCBeam(
text=beam.text,
full_text=new_text,
next_word="",
partial_word=beam.partial_word,
last_token=beam.last_token,
last_token_index=beam.last_token_index,
text_frames=beam.text_frames,
partial_frames=beam.partial_frames,
p=beam.p,
p_b=beam.p_b,
p_nb=beam.p_nb,
n_p_b=beam.n_p_b,
n_p_nb=beam.n_p_nb,
score=beam.score,
score_ctc=beam.score_ctc,
lm_score=beam.score + lm_score,
)
)
return new_beams
def _get_new_beam(
self,
frame_index: int,
new_prefix: str,
new_token: str,
new_token_index: int,
beams: List[CTCBeam],
p: float,
previous_beam: CTCBeam,
) -> CTCBeam:
"""Create a new beam and add it to the list of beams.
Arguments
---------
frame_index : int
The index of the current frame.
new_prefix : str
The new prefix.
new_token : str
The new token.
new_token_index : int
The index of the new token.
beams : list
The list of beams.
p : float
The probability of the new token.
previous_beam : CTCBeam
The previous beam.
Returns
-------
new_beam : CTCBeam
The new beam.
"""
for beam in beams:
if beam.text == new_prefix:
if p and p > beam.p:
beam.p = p
return beam
if not self.is_spm and new_token_index == self.space_index:
new_frame_list = (
previous_beam.text_frames
if previous_beam.partial_word == ""
else previous_beam.text_frames + [previous_beam.partial_frames]
)
# if we extend the beam with a space, we need to reset the partial word
# and move it to the next word
new_beam = CTCBeam(
text=new_prefix,
full_text=previous_beam.full_text,
next_word=previous_beam.partial_word,
partial_word="",
last_token=new_token,
last_token_index=new_token_index,
text_frames=new_frame_list,
partial_frames=(-1, -1),
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
elif self.is_spm and new_token[:1] == self.spm_token:
# remove the spm token at the beginning of the token
clean_token = new_token[1:]
new_frame_list = (
previous_beam.text_frames
if previous_beam.partial_word == ""
else previous_beam.text_frames + [previous_beam.partial_frames]
)
# If the beginning of the token is the spm_token
# then it means that we are extending the beam with a new word.
# We need to change the new_word with the partial_word
# and reset the partial_word with the new token
new_prefix = previous_beam.text + " " + clean_token
new_beam = CTCBeam(
text=new_prefix,
full_text=previous_beam.full_text,
next_word=previous_beam.partial_word,
partial_word=clean_token,
last_token=new_token,
last_token_index=new_token_index,
text_frames=new_frame_list,
partial_frames=(frame_index, frame_index + 1),
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
elif new_token_index == previous_beam.last_token_index:
new_end_frame = frame_index + 1
new_part_frames = (
previous_beam.partial_frames
if new_token_index == self.blank_index
else (previous_beam.partial_frames[0], new_end_frame)
)
# if repeated token, we only change the score
new_beam = CTCBeam(
text=new_prefix,
full_text=previous_beam.full_text,
next_word="",
partial_word=previous_beam.partial_word,
last_token=new_token,
last_token_index=new_token_index,
text_frames=previous_beam.text_frames,
partial_frames=new_part_frames,
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
else:
new_part_frames = (
(frame_index, frame_index + 1)
if previous_beam.partial_frames[0] < 0
else (previous_beam.partial_frames[0], frame_index + 1)
)
# last case, we are extending the partial_word with a new token
new_beam = CTCBeam(
text=new_prefix,
full_text=previous_beam.full_text,
next_word="",
partial_word=previous_beam.partial_word + new_token,
last_token=new_token,
last_token_index=new_token_index,
text_frames=previous_beam.text_frames,
partial_frames=new_part_frames,
score=-math.inf,
score_ctc=-math.inf,
p_b=-math.inf,
)
beams.append(new_beam)
if previous_beam:
new_beam.p = previous_beam.p
return new_beam
def partial_decoding(
self,
log_probs: torch.Tensor,
wav_len: int,
beams: List[CTCBeam],
cached_lm_scores: dict,
cached_p_lm_scores: dict,
processed_frames: int = 0,
) -> List[CTCBeam]:
"""Perform CTC Prefix Beam Search decoding.
If self.lm is not None, the language model scores are computed and added to the CTC scores.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the CTC input.
Shape: (seq_length, vocab_size)
wav_len : int
The length of the input sequence.
beams : list
The list of CTCBeam objects.
cached_lm_scores : dict
The cached language model scores.
cached_p_lm_scores : dict
The cached prefix language model scores.
processed_frames : int
The start frame of the current decoding step. (default: 0)
Returns
-------
beams : list
The list of CTCBeam objects.
"""
# select only the valid frames, i.e., the frames that are not padded
log_probs = log_probs[:wav_len]
for frame_index, logit_col in enumerate(
log_probs, start=processed_frames
):
# skip the frame if the blank probability is higher than the threshold
if logit_col[self.blank_index] > self.blank_skip_threshold:
continue
# get the tokens with the highest probability
max_index = logit_col.argmax()
tokens_index_list = set(
np.where(logit_col > self.token_prune_min_logp)[0]
) | {max_index}
curr_beams = beams.copy()
# select tokens that are in the vocab
# this is useful if the logit vocab_size is larger than the vocab_list
tokens_index_list = tokens_index_list & set(
range(len(self.vocab_list))
)
for token_index in tokens_index_list:
p_token = logit_col[token_index]
token = self.vocab_list[token_index]
for beam in curr_beams:
p_b, p_nb = beam.p_b, beam.p_nb
# blank case
if token_index == self.blank_index:
beam.n_p_b = np.logaddexp(
beam.n_p_b, beam.score_ctc + p_token
)
continue
if token == beam.last_token:
beam.n_p_nb = np.logaddexp(beam.n_p_nb, p_nb + p_token)
new_text = beam.text + token
new_beam = self._get_new_beam(
frame_index,
new_text,
token,
token_index,
beams,
p=p_token,
previous_beam=beam,
)
n_p_nb = new_beam.n_p_nb
if token_index == beam.last_token_index and p_b > -math.inf:
n_p_nb = np.logaddexp(n_p_nb, p_b + p_token)
elif token_index != beam.last_token_index:
n_p_nb = np.logaddexp(n_p_nb, beam.score_ctc + p_token)
new_beam.n_p_nb = n_p_nb
# update the CTC probabilities
for beam in beams:
beam.step()
# kenLM scores
scored_beams = self.get_lm_beams(
beams, cached_lm_scores, cached_p_lm_scores
)
# remove beams outliers
max_score = max([b.lm_score for b in scored_beams])
scored_beams = [
b
for b in scored_beams
if b.lm_score >= max_score + self.beam_prune_logp
]
trimmed_beams = self.sort_beams(scored_beams)
if self.prune_history:
lm_order = 1 if self.lm is None else self.lm.order
beams = self._prune_history(trimmed_beams, lm_order=lm_order)
else:
beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams]
return beams
class TorchAudioCTCPrefixBeamSearcher:
"""TorchAudio CTC Prefix Beam Search Decoder.
This class is a wrapper around the CTC decoder from TorchAudio. It provides a simple interface
where you can either use the CPU or CUDA CTC decoder.
The CPU decoder is slower but uses less memory. The CUDA decoder is faster but uses more memory.
The CUDA decoder is also only available in the nightly version of torchaudio.
A lot of features are missing in the CUDA decoder, such as the ability to use a language model,
lexicon-constrained search, and more. If you want to use those features, you have to use the CPU decoder.
For more information about the CPU decoder, please refer to the documentation of TorchAudio:
https://pytorch.org/audio/main/generated/torchaudio.models.decoder.ctc_decoder.html
For more information about the CUDA decoder, please refer to the documentation of TorchAudio:
https://pytorch.org/audio/main/generated/torchaudio.models.decoder.cuda_ctc_decoder.html#torchaudio.models.decoder.cuda_ctc_decoder
If you want to use the language model, or the lexicon search, please make sure that your
tokenizer/acoustic model uses the same tokens as the language model/lexicon. Otherwise, the decoding will fail.
The implementation is compatible with SentencePiece Tokens.
Note: When using the CUDA CTC decoder, the blank_index has to be 0. Furthermore, using the CUDA CTC decoder
requires the nightly version of torchaudio and a lot of VRAM (if you want to use a large number of beams).
Overall, we recommend using the CTCBeamSearcher or CTCPrefixBeamSearcher in SpeechBrain if you want to use
n-gram + beam search decoding. If you want lexicon-constrained search, please use the CPU version of torchaudio,
and if you want to speed up decoding as much as possible, please use the CUDA version.
Arguments
---------
tokens : list or str
The list of tokens or the path to the tokens file.
If this is a path, then the file should contain one token per line.
lexicon : str, default: None
Lexicon file containing the possible words and corresponding spellings. Each line consists of a word and its space separated spelling.
If None, uses lexicon-free decoding. (default: None)
lm : str, optional
A path containing KenLM language model or None if not using a language model. (default: None)
lm_dict : str, optional
File consisting of the dictionary used for the LM, with a word per line sorted by LM index.
If decoding with a lexicon, entries in lm_dict must also occur in the lexicon file.
If None, dictionary for LM is constructed using the lexicon file. (default: None)
topk : int, optional
Number of top CTCHypothesis to return. (default: 1)
beam_size : int, optional
Numbers of hypotheses to hold after each decode step. (default: 50)
beam_size_token : int, optional
Max number of tokens to consider at each decode step. If None, it is set to the total number of tokens. (default: None)
beam_threshold : float, optional
Threshold for pruning hypothesis. (default: 50)
lm_weight : float, optional
Weight of language model. (default: 2)
word_score : float, optional
Word insertion score. (default: 0)
unk_score : float, optional
Unknown word insertion score. (default: float("-inf"))
sil_score : float, optional
Silence insertion score. (default: 0)
log_add : bool, optional
Whether to use logadd when merging hypotheses. (default: False)
blank_index : int or str, optional
Index of the blank token. If tokens is a file path, then this should be a str. Otherwise, this should be an int. (default: 0)
sil_index : int or str, optional
Index of the silence token. If tokens is a file path, then this should be a str. Otherwise, this should be an int. (default: 0)
unk_word : str, optional
Unknown word token. (default: "<unk>")
using_cpu_decoder : bool, optional
Whether to use the CPU searcher. If False, then the CUDA decoder is used. (default: True)
blank_skip_threshold : float, optional
Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding (default: 1.0).
Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk.
Example
-------
>>> import torch
>>> from speechbrain.decoders import TorchAudioCTCPrefixBeamSearcher
>>> probs = torch.tensor([[[0.2, 0.0, 0.8],
... [0.4, 0.0, 0.6]]])
>>> log_probs = torch.log(probs)
>>> lens = torch.tensor([1.0])
>>> blank_index = 2
>>> vocab_list = ['a', 'b', '-']
>>> searcher = TorchAudioCTCPrefixBeamSearcher(tokens=vocab_list, blank_index=blank_index, sil_index=blank_index) # doctest: +SKIP
>>> hyps = searcher(log_probs, lens) # doctest: +SKIP
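
As with the other searchers, the result is a batch-major list of topk
CTCHypothesis (sketch only, requires torchaudio to be installed):
>>> best = hyps[0][0] # doctest: +SKIP
>>> best.text # doctest: +SKIP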
"""
def __init__(
self,
tokens: Union[list, str],
lexicon: Optional[str] = None,
lm: Optional[str] = None,
lm_dict: Optional[str] = None,
topk: int = 1,
beam_size: int = 50,
beam_size_token: Optional[int] = None,
beam_threshold: float = 50,
lm_weight: float = 2,
word_score: float = 0,
unk_score: float = float("-inf"),
sil_score: float = 0,
log_add: bool = False,
blank_index: Union[str, int] = 0,
sil_index: Union[str, int] = 0,
unk_word: str = "<unk>",
using_cpu_decoder: bool = True,
blank_skip_threshold: float = 1.0,
):
self.lexicon = lexicon
self.tokens = tokens
self.lm = lm
self.lm_dict = lm_dict
self.topk = topk
self.beam_size = beam_size
self.beam_size_token = beam_size_token
self.beam_threshold = beam_threshold
self.lm_weight = lm_weight
self.word_score = word_score
self.unk_score = unk_score
self.sil_score = sil_score
self.log_add = log_add
self.blank_index = blank_index
self.sil_index = sil_index
self.unk_word = unk_word
self.using_cpu_decoder = using_cpu_decoder
self.blank_skip_threshold = blank_skip_threshold
if self.using_cpu_decoder:
try:
from torchaudio.models.decoder import ctc_decoder
except ImportError:
raise ImportError(
"ctc_decoder not found. Please install torchaudio and flashlight to use this decoder."
)
# if tokens is a file path, blank_index/sil_index already hold the token strings,
# while if tokens is a list, they hold indices and the tokens are looked up
if isinstance(self.tokens, str):
blank_token = self.blank_index
sil_token = self.sil_index
else:
blank_token = self.tokens[self.blank_index]
sil_token = self.tokens[self.sil_index]
self._ctc_decoder = ctc_decoder(
lexicon=self.lexicon,
tokens=self.tokens,
lm=self.lm,
lm_dict=self.lm_dict,
nbest=self.topk,
beam_size=self.beam_size,
beam_size_token=self.beam_size_token,
beam_threshold=self.beam_threshold,
lm_weight=self.lm_weight,
word_score=self.word_score,
unk_score=self.unk_score,
sil_score=self.sil_score,
log_add=self.log_add,
blank_token=blank_token,
sil_token=sil_token,
unk_word=self.unk_word,
)
else:
try:
from torchaudio.models.decoder import cuda_ctc_decoder
except ImportError:
raise ImportError(
"cuda_ctc_decoder not found. Please install the latest version of torchaudio to use this decoder."
)
assert (
self.blank_index == 0
), "Index of blank token has to be 0 when using CUDA CTC decoder."
self._ctc_decoder = cuda_ctc_decoder(
tokens=self.tokens,
nbest=self.topk,
beam_size=self.beam_size,
blank_skip_threshold=self.blank_skip_threshold,
)
def decode_beams(
self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
) -> List[List[CTCHypothesis]]:
"""Decode log_probs using TorchAudio CTC decoder.
If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU before decoding.
When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps
in the returned hypotheses are set to None.
Make sure that the input is in the log domain. The decoder will fail to decode
logits or probabilities. The input should be the log probabilities of the CTC output.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the input audio.
Shape: (batch_size, seq_length, vocab_size)
wav_len : torch.Tensor, default: None
The speechbrain-style relative length. Shape: (batch_size,)
If None, then the length of each audio is assumed to be seq_length.
Returns
-------
list of list of CTCHypothesis
The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.
"""
if wav_len is not None:
wav_len = log_probs.size(1) * wav_len
else:
wav_len = torch.tensor(
[log_probs.size(1)] * log_probs.size(0),
device=log_probs.device,
dtype=torch.int32,
)
if wav_len.dtype != torch.int32:
wav_len = wav_len.to(torch.int32)
if log_probs.dtype != torch.float32:
raise ValueError("log_probs must be float32.")
# When using CPU decoder, we need to move the log_probs and wav_len to CPU
if self.using_cpu_decoder and log_probs.is_cuda:
log_probs = log_probs.cpu()
if self.using_cpu_decoder and wav_len.is_cuda:
wav_len = wav_len.cpu()
if not log_probs.is_contiguous():
raise RuntimeError("log_probs must be contiguous.")
results = self._ctc_decoder(log_probs, wav_len)
tokens_preds = []
words_preds = []
scores_preds = []
timesteps_preds = []
# over batch dim
for i in range(len(results)):
if self.using_cpu_decoder:
preds = [
results[i][j].tokens.tolist()
for j in range(len(results[i]))
]
preds = [
[self.tokens[token] for token in tokens] for tokens in preds
]
tokens_preds.append(preds)
timesteps = [
results[i][j].timesteps.tolist()
for j in range(len(results[i]))
]
timesteps_preds.append(timesteps)
else:
# no timesteps is available for CUDA CTC decoder
timesteps = [None for _ in range(len(results[i]))]
timesteps_preds.append(timesteps)
preds = [results[i][j].tokens for j in range(len(results[i]))]
preds = [
[self.tokens[token] for token in tokens] for tokens in preds
]
tokens_preds.append(preds)
words = [results[i][j].words for j in range(len(results[i]))]
words_preds.append(words)
scores = [results[i][j].score for j in range(len(results[i]))]
scores_preds.append(scores)
hyps = []
for (
batch_index,
(batch_text, batch_score, batch_timesteps),
) in enumerate(zip(tokens_preds, scores_preds, timesteps_preds)):
hyps.append([])
for text, score, timestep in zip(
batch_text, batch_score, batch_timesteps
):
hyps[batch_index].append(
CTCHypothesis(
text="".join(text),
last_lm_state=None,
score=score,
lm_score=score,
text_frames=timestep,
)
)
return hyps
def __call__(
self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
) -> List[List[CTCHypothesis]]:
"""Decode log_probs using TorchAudio CTC decoder.
If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU before decoding.
When using CUDA CTC decoder, the timestep information is not available. Therefore, the timesteps
in the returned hypotheses are set to None.
Arguments
---------
log_probs : torch.Tensor
The log probabilities of the input audio.
Shape: (batch_size, seq_length, vocab_size)
wav_len : torch.Tensor, default: None
The speechbrain-style relative length. Shape: (batch_size,)
If None, then the length of each audio is assumed to be seq_length.
Returns
-------
list of list of CTCHypothesis
The decoded hypotheses. The outer list is over the batch dimension, and the inner list is over the topk dimension.
"""
return self.decode_beams(log_probs, wav_len)