"""Decoders and output normalization for CTC.
Authors
* Mirco Ravanelli 2020
* Aku Rouhe 2020
* Sung-Lin Yeh 2020
"""
import torch
from itertools import groupby
from speechbrain.dataio.dataio import length_to_mask
class CTCPrefixScorer:
"""This class implements the CTC prefix scorer of Algorithm 2 in
reference: https://www.merl.com/publications/docs/TR2017-190.pdf.
Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
Arguments
---------
    x : torch.Tensor
        The encoder states (CTC log-posteriors), shape (batch_size, L, vocab_size).
enc_lens : torch.Tensor
The actual length of each enc_states sequence.
batch_size : int
The size of the batch.
beam_size : int
        The width of the beam.
blank_index : int
The index of the blank token.
eos_index : int
The index of the end-of-sequence (eos) token.
    ctc_window_size : int
        Compute the CTC scores over the time frames using windowing based on
        attention peaks. If 0, no windowing is applied.
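
    Example
    -------
    A minimal sketch of one scoring step followed by a beam permutation;
    the shapes, token indices, and beam picks below are illustrative only:

    >>> import torch
    >>> batch_size, beam_size, vocab_size, T = 1, 2, 4, 5
    >>> enc_states = torch.randn(batch_size, T, vocab_size).log_softmax(dim=-1)
    >>> enc_lens = torch.tensor([T])
    >>> scorer = CTCPrefixScorer(
    ...     enc_states, enc_lens, batch_size, beam_size, blank_index=0, eos_index=3
    ... )
    >>> g = torch.zeros(batch_size * beam_size, 0, dtype=torch.long)
    >>> scores, memory = scorer.forward_step(g, None)
    >>> scores.shape
    torch.Size([2, 4])
    >>> # Pretend the searcher picked entries 1 and 5 of the flattened
    >>> # (beam * vocab) scores, then resynchronize the CTC states:
    >>> r, psi = scorer.permute_mem(memory, torch.tensor([[1, 5]]))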
"""
def __init__(
self,
x,
enc_lens,
batch_size,
beam_size,
blank_index,
eos_index,
ctc_window_size=0,
):
self.blank_index = blank_index
self.eos_index = eos_index
self.max_enc_len = x.size(1)
self.batch_size = batch_size
self.beam_size = beam_size
self.vocab_size = x.size(-1)
self.device = x.device
self.minus_inf = -1e20
self.last_frame_index = enc_lens - 1
self.ctc_window_size = ctc_window_size
# mask frames > enc_lens
mask = 1 - length_to_mask(enc_lens)
mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1)
x.masked_fill_(mask, self.minus_inf)
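        # Keep the first-vocab-entry scores of padded frames at 0 (log-prob of
        # one) so the cumulative blank score computed in forward_step stays
        # flat past enc_lens; note this implicitly assumes blank_index == 0.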
x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0)
# dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors
xnb = x.transpose(0, 1)
xb = (
xnb[:, :, self.blank_index]
.unsqueeze(2)
.expand(-1, -1, self.vocab_size)
)
        # (2, L, batch_size, vocab_size)
self.x = torch.stack([xnb, xb])
# The first index of each sentence.
self.beam_offset = (
torch.arange(batch_size, device=self.device) * self.beam_size
)
        # The first index of each batch's candidates.
self.cand_offset = (
torch.arange(batch_size, device=self.device) * self.vocab_size
)
    def forward_step(self, g, state, candidates=None, attn=None):
"""This method if one step of forwarding operation
for the prefix ctc scorer.
Arguments
---------
g : torch.Tensor
The tensor of prefix label sequences, h = g + c.
state : tuple
Previous ctc states.
        candidates : torch.Tensor
            (batch_size * beam_size, ctc_beam_size). The top-k candidates for
            rescoring, where ctc_beam_size is set to 2 * beam_size. If given,
            only partial CTC scoring over the candidates is performed.
        attn : torch.Tensor
            The attention weights of the current step, used to locate the
            scoring window when ctc_window_size > 0.

        Returns
        -------
        scores : torch.Tensor
            The CTC scores of the current step, psi - psi_prev.
        state : tuple
            The updated CTC states (r, psi, scoring_table).
"""
prefix_length = g.size(1)
last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g)
self.num_candidates = (
self.vocab_size if candidates is None else candidates.size(-1)
)
if state is None:
# r_prev: (L, 2, batch_size * beam_size)
r_prev = torch.full(
(self.max_enc_len, 2, self.batch_size, self.beam_size),
self.minus_inf,
device=self.device,
)
# Accumulate blank posteriors at each step
r_prev[:, 1] = torch.cumsum(
self.x[0, :, :, self.blank_index], 0
).unsqueeze(2)
r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size)
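            # No previous prefix score to subtract at the first step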
psi_prev = 0.0
else:
r_prev, psi_prev = state
# for partial search
if candidates is not None:
scoring_table = torch.full(
(self.batch_size * self.beam_size, self.vocab_size),
-1,
dtype=torch.long,
device=self.device,
)
# Assign indices of candidates to their positions in the table
col_index = torch.arange(
self.batch_size * self.beam_size, device=self.device
).unsqueeze(1)
scoring_table[col_index, candidates] = torch.arange(
self.num_candidates, device=self.device
)
            # Select candidate indices for scoring
scoring_index = (
candidates
+ self.cand_offset.unsqueeze(1)
.repeat(1, self.beam_size)
.view(-1, 1)
).view(-1)
x_inflate = torch.index_select(
self.x.view(2, -1, self.batch_size * self.vocab_size),
2,
scoring_index,
).view(2, -1, self.batch_size * self.beam_size, self.num_candidates)
# for full search
else:
scoring_table = None
x_inflate = (
self.x.unsqueeze(3)
.repeat(1, 1, 1, self.beam_size, 1)
.view(
2, -1, self.batch_size * self.beam_size, self.num_candidates
)
)
# Prepare forward probs
r = torch.full(
(
self.max_enc_len,
2,
self.batch_size * self.beam_size,
self.num_candidates,
),
self.minus_inf,
device=self.device,
)
# (Alg.2-6)
if prefix_length == 0:
r[0, 0] = x_inflate[0, 0]
# (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
r_sum = torch.logsumexp(r_prev, 1)
phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates)
# (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
if candidates is not None:
for i in range(self.batch_size * self.beam_size):
pos = scoring_table[i, last_char[i]]
if pos != -1:
phi[:, i, pos] = r_prev[:, 1, i]
else:
for i in range(self.batch_size * self.beam_size):
phi[:, i, last_char[i]] = r_prev[:, 1, i]
# Start, end frames for scoring (|g| < |h|).
# Scoring based on attn peak if ctc_window_size > 0
if self.ctc_window_size == 0 or attn is None:
start = max(1, prefix_length)
end = self.max_enc_len
else:
_, attn_peak = torch.max(attn, dim=1)
max_frame = torch.max(attn_peak).item() + self.ctc_window_size
min_frame = torch.min(attn_peak).item() - self.ctc_window_size
start = max(max(1, prefix_length), int(min_frame))
end = min(self.max_enc_len, int(max_frame))
# Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
for t in range(start, end):
# (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c)
rnb_prev = r[t - 1, 0]
# (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
rb_prev = r[t - 1, 1]
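            # Pack both recursions so a single logsumexp computes them:
            # row 0 combines [r^nb_{t-1}, phi_{t-1}] for the nonblank update,
            # row 1 combines [r^nb_{t-1}, r^b_{t-1}] for the blank update.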
r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
2, 2, self.batch_size * self.beam_size, self.num_candidates
)
r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t]
        # Compute the prefix prob, psi
psi_init = r[start - 1, 0].unsqueeze(0)
# phi is prob at t-1 step, shift one frame and add it to the current prob p(c)
phix = torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0]
# (Alg.2-13): psi = psi + phi * p(c)
if candidates is not None:
psi = torch.full(
(self.batch_size * self.beam_size, self.vocab_size),
self.minus_inf,
device=self.device,
)
psi_ = torch.logsumexp(
torch.cat((phix[start:end], psi_init), dim=0), dim=0
)
# only assign prob to candidates
for i in range(self.batch_size * self.beam_size):
psi[i, candidates[i]] = psi_[i]
else:
psi = torch.logsumexp(
torch.cat((phix[start:end], psi_init), dim=0), dim=0
)
        # (Alg.2-3): if c = <eos>, psi = log(r_T^nb(g) + r_T^b(g)),
        # where T is the last valid frame of each utterance
for i in range(self.batch_size * self.beam_size):
psi[i, self.eos_index] = r_sum[
self.last_frame_index[i // self.beam_size], i
]
# Exclude blank probs for joint scoring
psi[:, self.blank_index] = self.minus_inf
return psi - psi_prev, (r, psi, scoring_table)
    def permute_mem(self, memory, index):
"""This method permutes the CTC model memory
to synchronize the memory index with the current output.
Arguments
---------
        memory : tuple
            The CTC memory (r, psi, scoring_table) to be permuted.
        index : torch.Tensor
            The index of the previous path.

        Returns
        -------
        tuple
            The permuted memory variables (r, psi).
"""
r, psi, scoring_table = memory
        # Map each top-k pick to its flat index in the previous step's
        # (batch * beam * vocab) scores.
best_index = (
index
+ (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size)
).view(-1)
# synchronize forward prob
psi = torch.index_select(psi.view(-1), dim=0, index=best_index)
psi = (
psi.view(-1, 1)
.repeat(1, self.vocab_size)
.view(self.batch_size * self.beam_size, self.vocab_size)
)
# synchronize ctc states
if scoring_table is not None:
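            # index encodes (beam, token) pairs flattened over the vocabulary:
            # recover the absolute beam row and the chosen token, then look up
            # the token's slot in the candidate scoring table.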
effective_index = (
index // self.vocab_size + self.beam_offset.view(-1, 1)
).view(-1)
selected_vocab = (index % self.vocab_size).view(-1)
score_index = scoring_table[effective_index, selected_vocab]
score_index[score_index == -1] = 0
best_index = score_index + effective_index * self.num_candidates
r = torch.index_select(
r.view(
-1, 2, self.batch_size * self.beam_size * self.num_candidates
),
dim=-1,
index=best_index,
)
r = r.view(-1, 2, self.batch_size * self.beam_size)
return r, psi
def filter_ctc_output(string_pred, blank_id=-1):
"""Apply CTC output merge and filter rules.
Removes the blank symbol and output repetitions.
Arguments
---------
string_pred : list
A list containing the output strings/ints predicted by the CTC system.
blank_id : int, string
The id of the blank.
Returns
-------
list
The output predicted by CTC without the blank symbol and
the repetitions.
Example
-------
>>> string_pred = ['a','a','blank','b','b','blank','c']
>>> string_out = filter_ctc_output(string_pred, blank_id='blank')
>>> print(string_out)
['a', 'b', 'c']
"""
    if isinstance(string_pred, list):
        # Merge repeated tokens (CTC collapse rule)
        string_out = [i[0] for i in groupby(string_pred)]
        # Remove the blank symbol
        string_out = [elem for elem in string_out if elem != blank_id]
    else:
        raise ValueError("filter_ctc_output can only filter python lists")
return string_out
def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
"""Greedy decode a batch of probabilities and apply CTC rules.
Arguments
---------
    probabilities : torch.Tensor
        Output probabilities (or log-probabilities) from the network, with
        shape [batch, time, probabilities].
    seq_lens : torch.Tensor
        Relative true sequence lengths (to deal with padded inputs): the
        longest sequence has length 1.0, others a value between zero and one;
        shape [batch].
blank_id : int, string
The blank symbol/index. Default: -1. If a negative number is given,
it is assumed to mean counting down from the maximum possible index,
so that -1 refers to the maximum possible index.
Returns
-------
list
Outputs as Python list of lists, with "ragged" dimensions; padding
has been removed.
Example
-------
>>> import torch
>>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]],
... [[0.2, 0.8], [0.9, 0.1]]])
>>> lens = torch.tensor([0.51, 1.0])
>>> blank_id = 0
>>> ctc_greedy_decode(probs, lens, blank_id)
[[1], [1]]
"""
if isinstance(blank_id, int) and blank_id < 0:
blank_id = probabilities.shape[-1] + blank_id
batch_max_len = probabilities.shape[1]
batch_outputs = []
for seq, seq_len in zip(probabilities, seq_lens):
actual_size = int(torch.round(seq_len * batch_max_len))
        _, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1)
out = filter_ctc_output(predictions.tolist(), blank_id=blank_id)
batch_outputs.append(out)
return batch_outputs