"""Decoders and output normalization for Transducer sequence.
Author:
Abdelwahab HEBA 2020
Sung-Lin Yeh 2020
"""
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional
import torch
[docs]
@dataclass
class TransducerGreedySearcherStreamingContext(torch.nn.Module):
    """Simple wrapper for the hidden state of the transducer greedy searcher.

    Used by :meth:`~TransducerBeamSearcher.transducer_greedy_decode_streaming`,
    which reads ``hidden`` before decoding a chunk and overwrites it with the
    state returned by the decoder afterwards.
    """

    # NOTE(review): the ``__init__`` generated by ``@dataclass`` does not call
    # ``torch.nn.Module.__init__`` -- confirm that no Module machinery
    # (parameters, sub-modules, ``state_dict``) is expected from this object.

    # ``None`` (the default) asks the greedy decoder to start from a fresh
    # prediction-network state. After a streaming call it holds the
    # ``(out_PN, hidden)`` pair returned by ``transducer_greedy_decode``.
    hidden: Optional[Any] = None
    """Hidden state; typically a tensor or a tuple of tensors."""
[docs]
class TransducerBeamSearcher(torch.nn.Module):
    """
    This class implements the beam-search algorithm for the transducer model.

    Calling the module dispatches to greedy decoding when ``beam_size <= 1``
    and to beam-search decoding otherwise.

    Arguments
    ---------
    decode_network_lst : list
        List of prediction network (PN) layers.
    tjoint : transducer_joint module
        This module performs the joint between TN and PN.
    classifier_network : list
        List of output layers (applied after the joint between TN and PN).
        e.g.: (TN,PN) => joint => classifier_network_list [DNN block, Linear..] => chars prob
    blank_id : int
        The blank symbol/index.
    beam_size : int
        The width of the beam. Greedy search is used when beam_size <= 1.
        (default: 4)
    nbest : int
        Number of hypotheses to keep. (default: 5)
    lm_module : torch.nn.ModuleList
        Neural network modules for the LM. (default: None)
    lm_weight : float
        The weight of the LM when performing beam search (λ).
        log P(y|x) + λ log P_LM(y). (default: 0.0)
    state_beam : float
        The threshold coefficient in log space to decide if hyps in A (process_hyps)
        are likely to compete with hyps in B (beam_hyps); if not, end the while loop.
        Reference: https://arxiv.org/pdf/1911.01629.pdf
        (default: 2.3)
    expand_beam : float
        The threshold coefficient to limit the number of expanded hypotheses
        that are added in A (process_hyps).
        Reference: https://arxiv.org/pdf/1911.01629.pdf
        Reference: https://github.com/kaldi-asr/kaldi/blob/master/src/decoder/simple-decoder.cc (See PruneToks)
        (default: 2.3)

    Raises
    ------
    ValueError
        If ``lm_weight`` is positive but no ``lm_module`` is given.

    Example
    -------
    searcher = TransducerBeamSearcher(
        decode_network_lst=[hparams["emb"], hparams["dec"]],
        tjoint=hparams["Tjoint"],
        classifier_network=[hparams["transducer_lin"]],
        blank_id=0,
        beam_size=hparams["beam_size"],
        nbest=hparams["nbest"],
        lm_module=hparams["lm_model"],
        lm_weight=hparams["lm_weight"],
        state_beam=2.3,
        expand_beam=2.3,
    )
    >>> from speechbrain.nnet.transducer.transducer_joint import Transducer_joint
    >>> import speechbrain as sb
    >>> emb = sb.nnet.embedding.Embedding(
    ...     num_embeddings=35,
    ...     embedding_dim=3,
    ...     consider_as_one_hot=True,
    ...     blank_id=0
    ... )
    >>> dec = sb.nnet.RNN.GRU(
    ...     hidden_size=10, input_shape=(1, 40, 34), bidirectional=False
    ... )
    >>> lin = sb.nnet.linear.Linear(input_shape=(1, 40, 10), n_neurons=35)
    >>> joint_network= sb.nnet.linear.Linear(input_shape=(1, 1, 40, 35), n_neurons=35)
    >>> tjoint = Transducer_joint(joint_network, joint="sum")
    >>> searcher = TransducerBeamSearcher(
    ...     decode_network_lst=[emb, dec],
    ...     tjoint=tjoint,
    ...     classifier_network=[lin],
    ...     blank_id=0,
    ...     beam_size=1,
    ...     nbest=1,
    ...     lm_module=None,
    ...     lm_weight=0.0,
    ... )
    >>> enc = torch.rand([1, 20, 10])
    >>> hyps, _, _, _ = searcher(enc)
    """
def __init__(
self,
decode_network_lst,
tjoint,
classifier_network,
blank_id,
beam_size=4,
nbest=5,
lm_module=None,
lm_weight=0.0,
state_beam=2.3,
expand_beam=2.3,
):
super().__init__()
self.decode_network_lst = decode_network_lst
self.tjoint = tjoint
self.classifier_network = classifier_network
self.blank_id = blank_id
self.beam_size = beam_size
self.nbest = nbest
self.lm = lm_module
self.lm_weight = lm_weight
if lm_module is None and lm_weight > 0:
raise ValueError("Language model is not provided.")
self.state_beam = state_beam
self.expand_beam = expand_beam
self.softmax = torch.nn.LogSoftmax(dim=-1)
if self.beam_size <= 1:
self.searcher = self.transducer_greedy_decode
else:
self.searcher = self.transducer_beam_search_decode
[docs]
def forward(self, tn_output):
"""
Arguments
---------
tn_output : torch.Tensor
Output from transcription network with shape
[batch, time_len, hiddens].
Returns
-------
Topk hypotheses
"""
hyps = self.searcher(tn_output)
return hyps
[docs]
def transducer_greedy_decode(
self, tn_output, hidden_state=None, return_hidden=False
):
"""Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
1- for each time step in the Transcription Network (TN) output:
-> Update the ith utterance only if
the previous target != the new one (we save the hiddens and the target)
-> otherwise:
---> keep the previous target prediction from the decoder
Arguments
---------
tn_output : torch.Tensor
Output from transcription network with shape
[batch, time_len, hiddens].
hidden_state : (torch.Tensor, torch.Tensor)
Hidden state to initially feed the decode network with. This is
useful in conjunction with `return_hidden` to be able to perform
beam search in a streaming context, so that you can reuse the last
hidden state as an initial state across calls.
return_hidden : bool
Whether the return tuple should contain an extra 5th element with
the hidden state at of the last step. See `hidden_state`.
Returns
-------
Tuple of 4 or 5 elements (if `return_hidden`).
First element: List[List[int]]
List of decoded tokens
Second element: torch.Tensor
Outputs a logits tensor [B,T,1,Output_Dim]; padding
has not been removed.
Third element: None
nbest; irrelevant for greedy decode
Fourth element: None
nbest scores; irrelevant for greedy decode
Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor)
Tuple representing the hidden state required to call
`transducer_greedy_decode` where you left off in a streaming
context.
"""
hyp = {
"prediction": [[] for _ in range(tn_output.size(0))],
"logp_scores": [0.0 for _ in range(tn_output.size(0))],
}
# prepare BOS = Blank for the Prediction Network (PN)
input_PN = (
torch.ones(
(tn_output.size(0), 1),
device=tn_output.device,
dtype=torch.int32,
)
* self.blank_id
)
if hidden_state is None:
# First forward-pass on PN
out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)
else:
out_PN, hidden = hidden_state
# For each time step
for t_step in range(tn_output.size(1)):
# do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden]
log_probs = self._joint_forward_step(
tn_output[:, t_step, :].unsqueeze(1).unsqueeze(1),
out_PN.unsqueeze(1),
)
# Sort outputs at time
logp_targets, positions = torch.max(
log_probs.squeeze(1).squeeze(1), dim=1
)
# Batch hidden update
have_update_hyp = []
for i in range(positions.size(0)):
# Update hiddens only if
# 1- current prediction is non blank
if positions[i].item() != self.blank_id:
hyp["prediction"][i].append(positions[i].item())
hyp["logp_scores"][i] += logp_targets[i]
input_PN[i][0] = positions[i]
have_update_hyp.append(i)
if len(have_update_hyp) > 0:
# Select sentence to update
# And do a forward steps + generated hidden
(
selected_input_PN,
selected_hidden,
) = self._get_sentence_to_update(
have_update_hyp, input_PN, hidden
)
selected_out_PN, selected_hidden = self._forward_PN(
selected_input_PN, self.decode_network_lst, selected_hidden
)
# update hiddens and out_PN
out_PN[have_update_hyp] = selected_out_PN
hidden = self._update_hiddens(
have_update_hyp, selected_hidden, hidden
)
ret = (
hyp["prediction"],
torch.Tensor(hyp["logp_scores"]).exp().mean(),
None,
None,
)
if return_hidden:
# append the `(out_PN, hidden)` tuple to ret
ret += (
(
out_PN,
hidden,
),
)
return ret
[docs]
def transducer_greedy_decode_streaming(
self, x: torch.Tensor, context: TransducerGreedySearcherStreamingContext
):
"""Tiny wrapper for
:meth:`~TransducerBeamSearcher.transducer_greedy_decode` with an API
that makes it suitable to be passed as a `decoding_function` for
streaming.
Arguments
---------
x : torch.Tensor
Outputs of the prediction network (equivalent to `tn_output`)
context : TransducerGreedySearcherStreamingContext
Mutable streaming context object, which must be specified and reused
across calls when streaming.
You can obtain an initial context by initializing a default object.
Returns
-------
hyp : torch.Tensor
"""
(hyp, _scores, _, _, hidden) = self.transducer_greedy_decode(
x, context.hidden, return_hidden=True
)
context.hidden = hidden
return hyp
[docs]
def transducer_beam_search_decode(self, tn_output):
    """Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules:
        1- for each utterance:
            2- for each time steps in the Transcription Network (TN) output:
                -> Do forward on PN and Joint network
                -> Select topK <= beam
                -> Do a while loop extending the hyps until we reach blank
                    -> otherwise:
                    --> extend hyp by the new token

    Hypotheses are dicts with the keys ``"prediction"`` (token list starting
    with blank as BOS), ``"logp_score"``, ``"hidden_dec"`` and, when the LM
    is used, ``"hidden_lm"``. They are ranked with `get_transducer_key`.

    Arguments
    ---------
    tn_output : torch.Tensor
        Output from transcription network with shape
        [batch, time_len, hiddens].

    Returns
    -------
    Tuple of 4 elements.
    First element: List[List[int]]
        Best hypothesis of each utterance (leading blank removed).
    Second element: torch.Tensor
        Scalar: mean over the batch of ``exp`` of the best hypothesis'
        length-normalized score.
    Third element: List[List[List[int]]]
        Up to ``nbest`` hypotheses per utterance, best first.
    Fourth element: list
        Length-normalized scores matching the third element.
    """
    use_lm = self.lm_weight > 0
    nbest_batch = []
    nbest_batch_score = []

    for i_batch in range(tn_output.size(0)):
        # Reusable PN input buffer; blank plays the role of BOS.
        input_PN = (
            torch.ones((1, 1), device=tn_output.device, dtype=torch.int32)
            * self.blank_id
        )
        hyp = {
            "prediction": [self.blank_id],
            "logp_score": 0.0,
            "hidden_dec": None,
        }
        if use_lm:
            hyp["hidden_lm"] = None
        beam_hyps = [hyp]

        # For each time step
        for t_step in range(tn_output.size(1)):
            # A (process_hyps): hyps still to extend at this time step.
            # B (beam_hyps): hyps that emitted blank and move to the next one.
            process_hyps = beam_hyps
            beam_hyps = []

            while len(beam_hyps) < self.beam_size:
                a_best_hyp = max(process_hyps, key=get_transducer_key)

                # Break if best_hyp in A is worse by more than state_beam than best_hyp in B
                if len(beam_hyps) > 0:
                    b_best_hyp = max(beam_hyps, key=get_transducer_key)
                    a_best_prob = a_best_hyp["logp_score"]
                    b_best_prob = b_best_hyp["logp_score"]
                    if b_best_prob >= self.state_beam + a_best_prob:
                        break

                # Remove the best hyp from A *by identity*: ``list.remove``
                # compares dicts by value, which may end up evaluating ``==``
                # on hidden-state tensors of hyps that share the same
                # prediction and score.
                for idx, candidate in enumerate(process_hyps):
                    if candidate is a_best_hyp:
                        del process_hyps[idx]
                        break

                # forward PN
                input_PN[0, 0] = a_best_hyp["prediction"][-1]
                out_PN, hidden = self._forward_PN(
                    input_PN,
                    self.decode_network_lst,
                    a_best_hyp["hidden_dec"],
                )
                # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden]
                log_probs = self._joint_forward_step(
                    tn_output[i_batch, t_step, :]
                    .unsqueeze(0)
                    .unsqueeze(0)
                    .unsqueeze(0),
                    out_PN.unsqueeze(0),
                )

                if use_lm:
                    log_probs_lm, hidden_lm = self._lm_forward_step(
                        input_PN, a_best_hyp["hidden_lm"]
                    )

                # Sort outputs at time
                logp_targets, positions = torch.topk(
                    log_probs.view(-1), k=self.beam_size, dim=-1
                )
                tokens = positions.tolist()
                # Best non-blank log-probability, reference for expand_beam.
                best_logp = (
                    logp_targets[0]
                    if tokens[0] != self.blank_id
                    else logp_targets[1]
                )

                # Extend hyp by selection
                for j, token in enumerate(tokens):
                    topk_hyp = {
                        "prediction": a_best_hyp["prediction"][:],
                        "logp_score": a_best_hyp["logp_score"]
                        + logp_targets[j],
                        "hidden_dec": a_best_hyp["hidden_dec"],
                    }

                    if token == self.blank_id:
                        # Blank: the hyp is done for this time step and
                        # keeps the PN / LM states of its parent.
                        if use_lm:
                            topk_hyp["hidden_lm"] = a_best_hyp["hidden_lm"]
                        beam_hyps.append(topk_hyp)
                        continue

                    if logp_targets[j] >= best_logp - self.expand_beam:
                        topk_hyp["prediction"].append(token)
                        topk_hyp["hidden_dec"] = hidden
                        if use_lm:
                            topk_hyp["hidden_lm"] = hidden_lm
                            topk_hyp["logp_score"] += (
                                self.lm_weight * log_probs_lm[0, 0, token]
                            )
                        process_hyps.append(topk_hyp)

        # Keep the nbest hyps, ranked by length-normalized score.
        nbest_hyps = sorted(
            beam_hyps,
            key=get_transducer_key,
            reverse=True,
        )[: self.nbest]
        all_predictions = []
        all_scores = []
        for hyp in nbest_hyps:
            # Drop the leading blank (BOS) from the returned tokens; the
            # normalization length still counts it.
            all_predictions.append(hyp["prediction"][1:])
            all_scores.append(hyp["logp_score"] / len(hyp["prediction"]))
        nbest_batch.append(all_predictions)
        nbest_batch_score.append(all_scores)

    return (
        [nbest_utt[0] for nbest_utt in nbest_batch],
        torch.Tensor(
            [nbest_utt_score[0] for nbest_utt_score in nbest_batch_score]
        )
        .exp()
        .mean(),
        nbest_batch,
        nbest_batch_score,
    )
def _joint_forward_step(self, h_i, out_PN):
"""Join predictions (TN & PN)."""
with torch.no_grad():
# the output would be a tensor of [B,T,U, oneof[sum,concat](Hidden_TN,Hidden_PN)]
out = self.tjoint(
h_i,
out_PN,
)
# forward the output layers + activation + save logits
out = self._forward_after_joint(out, self.classifier_network)
log_probs = self.softmax(out)
return log_probs
def _lm_forward_step(self, inp_tokens, memory):
"""This method should implement one step of
forwarding operation for language model.
Arguments
---------
inp_tokens : torch.Tensor
The input tensor of the current timestep.
memory : No limit
The memory variables input for this timestep.
(e.g., RNN hidden states).
Return
------
log_probs : torch.Tensor
Log-probabilities of the current timestep output.
hs : No limit
The memory variables are generated in this timestep.
(e.g., RNN hidden states).
"""
with torch.no_grad():
logits, hs = self.lm(inp_tokens, hx=memory)
log_probs = self.softmax(logits)
return log_probs, hs
def _get_sentence_to_update(self, selected_sentences, output_PN, hidden):
"""Select and return the updated hiddens and output
from the Prediction Network.
Arguments
---------
selected_sentences : list
List of updated sentences (indexes).
output_PN: torch.Tensor
Output tensor from prediction network (PN).
hidden : torch.Tensor
Optional: None, hidden tensor to be used for
recurrent layers in the prediction network.
Returns
-------
selected_output_PN: torch.Tensor
Outputs a logits tensor [B_selected,U, hiddens].
hidden_update_hyp: torch.Tensor
Selected hiddens tensor.
"""
selected_output_PN = output_PN[selected_sentences, :]
# for LSTM hiddens (hn, hc)
if isinstance(hidden, tuple):
hidden0_update_hyp = hidden[0][:, selected_sentences, :]
hidden1_update_hyp = hidden[1][:, selected_sentences, :]
hidden_update_hyp = (hidden0_update_hyp, hidden1_update_hyp)
else:
hidden_update_hyp = hidden[:, selected_sentences, :]
return selected_output_PN, hidden_update_hyp
def _update_hiddens(self, selected_sentences, updated_hidden, hidden):
"""Update hidden tensor by a subset of hidden tensor (updated ones).
Arguments
---------
selected_sentences : list
List of index to be updated.
updated_hidden : torch.Tensor
Hidden tensor of the selected sentences for update.
hidden : torch.Tensor
Hidden tensor to be updated.
Returns
-------
torch.Tensor
Updated hidden tensor.
"""
if isinstance(hidden, tuple):
hidden[0][:, selected_sentences, :] = updated_hidden[0]
hidden[1][:, selected_sentences, :] = updated_hidden[1]
else:
hidden[:, selected_sentences, :] = updated_hidden
return hidden
def _forward_PN(self, out_PN, decode_network_lst, hidden=None):
"""Compute forward-pass through a list of prediction network (PN) layers.
Arguments
---------
out_PN : torch.Tensor
Input sequence from prediction network with shape
[batch, target_seq_lens].
decode_network_lst: list
List of prediction network (PN) layers.
hidden : torch.Tensor
Optional: None, hidden tensor to be used for
recurrent layers in the prediction network
Returns
-------
out_PN : torch.Tensor
Outputs a logits tensor [B,U, hiddens].
hidden : torch.Tensor
Hidden tensor to be used for the next step
by recurrent layers in prediction network.
"""
for layer in decode_network_lst:
if layer.__class__.__name__ in [
"RNN",
"LSTM",
"GRU",
"LiGRU",
"LiGRU_Layer",
]:
out_PN, hidden = layer(out_PN, hidden)
else:
out_PN = layer(out_PN)
return out_PN, hidden
def _forward_after_joint(self, out, classifier_network):
"""Compute forward-pass through a list of classifier neural network.
Arguments
---------
out : torch.Tensor
Output from joint network with shape
[batch, target_len, time_len, hiddens]
classifier_network : list
List of output layers (after performing joint between TN and PN)
exp: (TN,PN) => joint => classifier_network_list [DNN block, Linear..] => chars prob
Returns
-------
torch.Tensor
Outputs a logits tensor [B, U,T, Output_Dim];
"""
for layer in classifier_network:
out = layer(out)
return out
[docs]
def get_transducer_key(x):
    """Argument function to customize the sort order (in sorted & max).
    To be used as `key=partial(get_transducer_key)`.

    Arguments
    ---------
    x : dict
        One of the items under comparison; must provide the keys
        ``"logp_score"`` and ``"prediction"``.

    Returns
    -------
    float
        Normalized log-score: ``logp_score`` divided by the number of
        tokens in ``prediction``.
    """
    return x["logp_score"] / len(x["prediction"])