Source code for speechbrain.decoders.scorer

"""
Token scorer abstraction and specifications.

Authors:
 * Adel Moumen 2022, 2023
 * Sung-Lin Yeh 2021
"""

import torch
import numpy as np
import speechbrain as sb
from speechbrain.decoders.ctc import CTCPrefixScore


[docs] class BaseScorerInterface: """A scorer abstraction to be inherited by other scoring approaches for beam search. A scorer is a module that scores tokens in the vocabulary based on the current timestep input and the previous scorer states. It can be used to score on the full vocabulary set (i.e., full scorers) or a pruned set of tokens (i.e., partial scorers) to prevent computation overhead. In the latter case, the partial scorers will be called after the full scorers. They will only score the top-k candidates (i.e., the pruned set of tokens) extracted from the full scorers. The top-k candidates are extracted based on the beam size and the scorer_beam_scale, such that the number of candidates is int(beam_size * scorer_beam_scale). This can be very useful when the full scorers are computationally expensive (e.g., the KenLM scorer). Inherit this class to implement your own scorer compatible with speechbrain.decoders.seq2seq.S2SBeamSearcher(). See: - speechbrain.decoders.scorer.CTCScorer - speechbrain.decoders.scorer.RNNLMScorer - speechbrain.decoders.scorer.TransformerLMScorer - speechbrain.decoders.scorer.KenLMScorer - speechbrain.decoders.scorer.CoverageScorer - speechbrain.decoders.scorer.LengthScorer """
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the information of the current timestep. A score is a tensor of shape (batch_size x beam_size, vocab_size). It is the log probability of the next token given the current timestep input and the previous scorer states. It can be used to score on pruned top-k candidates to prevent computation overhead, or on the full vocabulary set when candidates is None. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on the full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. Returns ------- torch.Tensor (batch_size x beam_size, vocab_size), Scores for the next tokens. memory : No limit The memory variables input for this timestep. """ raise NotImplementedError
[docs] def permute_mem(self, memory, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched beam search. Arguments --------- memory : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ return None
[docs] def reset_mem(self, x, enc_lens): """This method should implement the resetting of memory variables for the scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ return None
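To make the contract above concrete, here is a minimal sketch of a custom scorer that satisfies BaseScorerInterface. It is not part of SpeechBrain: the class name UniformScorer and its zero-score behaviour are illustrative assumptions only, and it is not wired into ScorerBuilder.

import torch
from speechbrain.decoders.scorer import BaseScorerInterface

class UniformScorer(BaseScorerInterface):
    """Hypothetical scorer returning a zero log-prob bonus for every token."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def score(self, inp_tokens, memory, candidates, attn):
        # Zero scores leave the beam search log probs unchanged.
        scores = torch.zeros(
            inp_tokens.size(0), self.vocab_size, device=inp_tokens.device
        )
        return scores, None  # stateless: nothing to carry over

    def permute_mem(self, memory, index):
        return None  # no state to permute

    def reset_mem(self, x, enc_lens):
        return None  # no state to reset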
[docs] class CTCScorer(BaseScorerInterface): """A wrapper of CTCPrefixScore based on the BaseScorerInterface. This scorer is used to provide the CTC label-synchronous scores of the next input tokens. The implementation is based on https://www.merl.com/publications/docs/TR2017-190.pdf. See: - speechbrain.decoders.ctc.CTCPrefixScore Arguments --------- ctc_fc : torch.nn.Module An output linear layer for CTC. blank_index : int The index of the blank token. eos_index : int The index of the end-of-sequence (eos) token. ctc_window_size : int Compute the CTC scores over the time frames using windowing based on attention peaks. If 0, no windowing is applied. (default: 0) Example ------- >>> import torch >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR >>> from speechbrain.decoders import S2STransformerBeamSearcher, CTCScorer, ScorerBuilder >>> batch_size=8 >>> n_channels=6 >>> input_size=40 >>> d_model=128 >>> tgt_vocab=140 >>> src = torch.rand([batch_size, n_channels, input_size]) >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels]) >>> net = TransformerASR( ... tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU ... ) >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> eos_index = 2 >>> ctc_scorer = CTCScorer( ... ctc_fc=ctc_lin, ... blank_index=0, ... eos_index=eos_index, ... ) >>> scorer = ScorerBuilder( ... full_scorers=[ctc_scorer], ... weights={'ctc': 1.0} ... ) >>> searcher = S2STransformerBeamSearcher( ... modules=[net, lin], ... bos_index=1, ... eos_index=eos_index, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... using_eos_threshold=False, ... beam_size=7, ... temperature=1.15, ... scorer=scorer ... ) >>> enc, dec = net.forward(src, tgt) >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size)) """ def __init__( self, ctc_fc, blank_index, eos_index, ctc_window_size=0, ): self.ctc_fc = ctc_fc self.blank_index = blank_index self.eos_index = eos_index self.ctc_window_size = ctc_window_size self.softmax = sb.nnet.activations.Softmax(apply_log=True)
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the CTC scores computed over the time frames. See: - speechbrain.decoders.scorer.CTCPrefixScore Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ scores, memory = self.ctc_score.forward_step( inp_tokens, memory, candidates, attn ) return scores, memory
[docs] def permute_mem(self, memory, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched CTC beam search. Arguments --------- memory : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ r, psi = self.ctc_score.permute_mem(memory, index) return r, psi
[docs] def reset_mem(self, x, enc_lens): """This method implements the resetting of memory variables for the CTC scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ logits = self.ctc_fc(x) x = self.softmax(logits) self.ctc_score = CTCPrefixScore( x, enc_lens, self.blank_index, self.eos_index, self.ctc_window_size ) return None
[docs] class RNNLMScorer(BaseScorerInterface): """A wrapper of RNNLM based on BaseScorerInterface. The RNNLMScorer is used to provide the RNNLM scores of the next input tokens based on the current timestep input and the previous scorer states. Arguments --------- language_model : torch.nn.Module A RNN-based language model. temperature : float Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1. (default: 1.0) Example ------- >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.RNNLM import RNNLM >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, ScorerBuilder >>> input_size=17 >>> vocab_size=11 >>> emb = torch.nn.Embedding( ... embedding_dim=input_size, ... num_embeddings=vocab_size, ... ) >>> d_model=7 >>> dec = AttentionalRNNDecoder( ... rnn_type="gru", ... attn_type="content", ... hidden_size=3, ... attn_dim=3, ... num_layers=1, ... enc_dim=d_model, ... input_size=input_size, ... ) >>> n_channels=3 >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size) >>> lm_weight = 0.4 >>> lm_model = RNNLM( ... embedding_dim=d_model, ... output_neurons=vocab_size, ... dropout=0.0, ... rnn_neurons=128, ... dnn_neurons=64, ... return_hidden=True, ... ) >>> rnnlm_scorer = RNNLMScorer( ... language_model=lm_model, ... temperature=1.25, ... ) >>> scorer = ScorerBuilder( ... full_scorers=[rnnlm_scorer], ... weights={'rnnlm': lm_weight} ... ) >>> beam_size=5 >>> searcher = S2SRNNBeamSearcher( ... embedding=emb, ... decoder=dec, ... linear=seq_lin, ... bos_index=1, ... eos_index=2, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... topk=2, ... using_eos_threshold=False, ... beam_size=beam_size, ... temperature=1.25, ... scorer=scorer ... ) >>> batch_size=2 >>> enc = torch.rand([batch_size, n_channels, d_model]) >>> wav_len = torch.ones([batch_size]) >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__(self, language_model, temperature=1.0): self.lm = language_model self.lm.eval() self.temperature = temperature self.softmax = sb.nnet.activations.Softmax(apply_log=True)
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the RNNLM scores computed over the previous tokens. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ with torch.no_grad(): logits, hs = self.lm(inp_tokens, hx=memory) log_probs = self.softmax(logits / self.temperature) return log_probs, hs
[docs] def permute_mem(self, memory, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched beam search. Arguments --------- memory : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ if isinstance(memory, tuple): memory_0 = torch.index_select(memory[0], dim=1, index=index) memory_1 = torch.index_select(memory[1], dim=1, index=index) memory = (memory_0, memory_1) else: memory = torch.index_select(memory, dim=1, index=index) return memory
[docs] def reset_mem(self, x, enc_lens): """This method implements the resetting of memory variables for the RNNLM scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ return None
[docs] class TransformerLMScorer(BaseScorerInterface): """A wrapper of TransformerLM based on BaseScorerInterface. The TransformerLMScorer is used to provide the TransformerLM scores of the next input tokens based on the current timestep input and the previous scorer states. Arguments --------- language_model : torch.nn.Module A Transformer-based language model. temperature : float Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1. (default: 1.0) Example ------- >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM >>> from speechbrain.decoders import S2STransformerBeamSearcher, TransformerLMScorer, CTCScorer, ScorerBuilder >>> input_size=17 >>> vocab_size=11 >>> d_model=128 >>> net = TransformerASR( ... tgt_vocab=vocab_size, ... input_size=input_size, ... d_model=d_model, ... nhead=8, ... num_encoder_layers=1, ... num_decoder_layers=1, ... d_ffn=256, ... activation=torch.nn.GELU ... ) >>> lm_model = TransformerLM( ... vocab=vocab_size, ... d_model=d_model, ... nhead=8, ... num_encoder_layers=1, ... num_decoder_layers=0, ... d_ffn=256, ... activation=torch.nn.GELU, ... ) >>> n_channels=6 >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size) >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size) >>> eos_index = 2 >>> ctc_scorer = CTCScorer( ... ctc_fc=ctc_lin, ... blank_index=0, ... eos_index=eos_index, ... ) >>> transformerlm_scorer = TransformerLMScorer( ... language_model=lm_model, ... temperature=1.15, ... ) >>> ctc_weight_decode=0.4 >>> lm_weight=0.6 >>> scorer = ScorerBuilder( ... full_scorers=[transformerlm_scorer, ctc_scorer], ... weights={'transformerlm': lm_weight, 'ctc': ctc_weight_decode} ... ) >>> beam_size=5 >>> searcher = S2STransformerBeamSearcher( ... modules=[net, seq_lin], ... bos_index=1, ... eos_index=eos_index, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... using_eos_threshold=False, ... beam_size=beam_size, ... temperature=1.15, ... scorer=scorer ... ) >>> batch_size=2 >>> wav_len = torch.ones([batch_size]) >>> src = torch.rand([batch_size, n_channels, input_size]) >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels]) >>> enc, dec = net.forward(src, tgt) >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__(self, language_model, temperature=1.0): self.lm = language_model self.lm.eval() self.temperature = temperature self.softmax = sb.nnet.activations.Softmax(apply_log=True)
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the TransformerLM scores computed over the previous tokens. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ with torch.no_grad(): if memory is None: memory = torch.empty( inp_tokens.size(0), 0, device=inp_tokens.device ) # Append the predicted token of the previous step to existing memory. memory = torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1) if not next(self.lm.parameters()).is_cuda: self.lm.to(inp_tokens.device) logits = self.lm(memory) log_probs = self.softmax(logits / self.temperature) return log_probs[:, -1, :], memory
[docs] def permute_mem(self, memory, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched beam search. Arguments --------- memory : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ memory = torch.index_select(memory, dim=0, index=index) return memory
[docs] def reset_mem(self, x, enc_lens): """This method implements the resetting of memory variables for the TransformerLM scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ return None
[docs] class KenLMScorer(BaseScorerInterface): """KenLM N-gram scorer. This scorer is based on KenLM, which is a fast and efficient N-gram language model toolkit. It is used to provide the n-gram scores of the next input tokens. This scorer is dependent on the KenLM package. It can be installed with the following command: > pip install https://github.com/kpu/kenlm/archive/master.zip Note: The KenLM scorer is computationally expensive. It is recommended to use it as a partial scorer to score on the top-k candidates instead of the full vocabulary set. Arguments --------- lm_path : str The path of ngram model. vocab_size: int The total number of tokens. token_list : list The tokens set. # Example # ------- # >>> from speechbrain.nnet.linear import Linear # >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder # >>> from speechbrain.decoders import S2SRNNBeamSearcher, KenLMScorer, ScorerBuilder # >>> input_size=17 # >>> vocab_size=11 # >>> lm_path='path/to/kenlm_model.arpa' # or .bin # >>> token_list=['<pad>', '<bos>', '<eos>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] # >>> emb = torch.nn.Embedding( # ... embedding_dim=input_size, # ... num_embeddings=vocab_size, # ... ) # >>> d_model=7 # >>> dec = AttentionalRNNDecoder( # ... rnn_type="gru", # ... attn_type="content", # ... hidden_size=3, # ... attn_dim=3, # ... num_layers=1, # ... enc_dim=d_model, # ... input_size=input_size, # ... ) # >>> n_channels=3 # >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size) # >>> kenlm_weight = 0.4 # >>> kenlm_model = KenLMScorer( # ... lm_path=lm_path, # ... vocab_size=vocab_size, # ... token_list=token_list, # ... ) # >>> scorer = ScorerBuilder( # ... full_scorers=[kenlm_model], # ... weights={'kenlm': kenlm_weight} # ... ) # >>> beam_size=5 # >>> searcher = S2SRNNBeamSearcher( # ... embedding=emb, # ... decoder=dec, # ... linear=seq_lin, # ... bos_index=1, # ... eos_index=2, # ... min_decode_ratio=0.0, # ... max_decode_ratio=1.0, # ... topk=2, # ... using_eos_threshold=False, # ... beam_size=beam_size, # ... temperature=1.25, # ... scorer=scorer # ... ) # >>> batch_size=2 # >>> enc = torch.rand([batch_size, n_channels, d_model]) # >>> wav_len = torch.ones([batch_size]) # >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__(self, lm_path, vocab_size, token_list): try: import kenlm self.kenlm = kenlm except ImportError: MSG = """Couldn't import KenLM It is an optional dependency; it is not installed with SpeechBrain by default. Install it with: > pip install https://github.com/kpu/kenlm/archive/master.zip """ raise ImportError(MSG) self.lm = self.kenlm.Model(lm_path) self.vocab_size = vocab_size self.full_candidates = np.arange(self.vocab_size) self.minus_inf = -1e20 if len(token_list) != vocab_size: MSG = "The size of the token_list and vocab_size are not matched." raise ValueError(MSG) self.id2char = token_list
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the n-gram scores. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on the full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ n_bh = inp_tokens.size(0) scale = 1.0 / np.log10(np.e) if memory is None: state = self.kenlm.State() state = np.array([state] * n_bh) scoring_table = np.ones(n_bh) else: state, scoring_table = memory # Perform full scorer mode, not recommended if candidates is None: candidates = [self.full_candidates] * n_bh # Store new states and scores scores = np.ones((n_bh, self.vocab_size)) * self.minus_inf new_memory = np.zeros((n_bh, self.vocab_size), dtype=object) new_scoring_table = np.ones((n_bh, self.vocab_size)) * -1 # Scoring for i in range(n_bh): if scoring_table[i] == -1: continue parent_state = state[i] for token_id in candidates[i]: char = self.id2char[token_id.item()] out_state = self.kenlm.State() score = scale * self.lm.BaseScore(parent_state, char, out_state) scores[i, token_id] = score new_memory[i, token_id] = out_state new_scoring_table[i, token_id] = 1 scores = torch.from_numpy(scores).float().to(inp_tokens.device) return scores, (new_memory, new_scoring_table)
[docs] def permute_mem(self, memory, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched beam search. Arguments --------- memory : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ state, scoring_table = memory index = index.cpu().numpy() # The first index of each sentence. beam_size = index.shape[1] beam_offset = self.batch_index * beam_size hyp_index = ( index + np.broadcast_to(np.expand_dims(beam_offset, 1), index.shape) * self.vocab_size ) hyp_index = hyp_index.reshape(-1) # Update states state = state.reshape(-1) state = state[hyp_index] scoring_table = scoring_table.reshape(-1) scoring_table = scoring_table[hyp_index] return state, scoring_table
[docs] def reset_mem(self, x, enc_lens): """This method implements the resetting of memory variables for the KenLM scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ state = self.kenlm.State() self.lm.NullContextWrite(state) self.batch_index = np.arange(x.size(0)) return None
[docs] class CoverageScorer(BaseScorerInterface): """A coverage penalty scorer to prevent looping of hyps, where ```coverage``` is the cumulative attention probability vector. Reference: https://arxiv.org/pdf/1612.02695.pdf, https://arxiv.org/pdf/1808.10792.pdf Arguments --------- vocab_size: int The total number of tokens. threshold: float The penalty increases when the coverage of a frame is more than given threshold. (default: 0.5) Example ------- >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.RNNLM import RNNLM >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, CoverageScorer, ScorerBuilder >>> input_size=17 >>> vocab_size=11 >>> emb = torch.nn.Embedding( ... num_embeddings=vocab_size, ... embedding_dim=input_size ... ) >>> d_model=7 >>> dec = AttentionalRNNDecoder( ... rnn_type="gru", ... attn_type="content", ... hidden_size=3, ... attn_dim=3, ... num_layers=1, ... enc_dim=d_model, ... input_size=input_size, ... ) >>> n_channels=3 >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size) >>> lm_weight = 0.4 >>> coverage_penalty = 1.0 >>> lm_model = RNNLM( ... embedding_dim=d_model, ... output_neurons=vocab_size, ... dropout=0.0, ... rnn_neurons=128, ... dnn_neurons=64, ... return_hidden=True, ... ) >>> rnnlm_scorer = RNNLMScorer( ... language_model=lm_model, ... temperature=1.25, ... ) >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size) >>> scorer = ScorerBuilder( ... full_scorers=[rnnlm_scorer, coverage_scorer], ... weights={'rnnlm': lm_weight, 'coverage': coverage_penalty} ... ) >>> beam_size=5 >>> searcher = S2SRNNBeamSearcher( ... embedding=emb, ... decoder=dec, ... linear=seq_lin, ... bos_index=1, ... eos_index=2, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... topk=2, ... using_eos_threshold=False, ... beam_size=beam_size, ... temperature=1.25, ... scorer=scorer ... ) >>> batch_size=2 >>> enc = torch.rand([batch_size, n_channels, d_model]) >>> wav_len = torch.ones([batch_size]) >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__(self, vocab_size, threshold=0.5): self.vocab_size = vocab_size self.threshold = threshold # Use time_step to normalize the coverage over steps self.time_step = 0
[docs] def score(self, inp_tokens, coverage, candidates, attn): """This method scores the new beams based on the Coverage scorer. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. coverage : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ n_bh = attn.size(0) self.time_step += 1 if coverage is None: coverage = torch.zeros_like(attn, device=attn.device) # Current coverage if len(attn.size()) > 2: # the attn of transformer is [batch_size x beam_size, current_step, source_len] coverage = torch.sum(attn, dim=1) else: coverage = coverage + attn # Compute coverage penalty and add it to scores penalty = torch.max( coverage, coverage.clone().fill_(self.threshold) ).sum(-1) penalty = penalty - coverage.size(-1) * self.threshold penalty = penalty.view(n_bh).unsqueeze(1).expand(-1, self.vocab_size) return -1 * penalty / self.time_step, coverage
[docs] def permute_mem(self, coverage, index): """This method permutes the scorer memory to synchronize the memory index with the current output and perform batched beam search. Arguments --------- coverage : No limit The memory variables input for this timestep. index : torch.Tensor (batch_size, beam_size). The index of the previous path. """ # Update coverage coverage = torch.index_select(coverage, dim=0, index=index) return coverage
[docs] def reset_mem(self, x, enc_lens): """This method implements the resetting of memory variables for the Coverage scorer. Arguments --------- x : torch.Tensor The precomputed encoder states to be used when decoding. (ex. the encoded speech representation to be attended). enc_lens : torch.Tensor The speechbrain-style relative length. """ self.time_step = 0 return None
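As a small standalone check of the penalty computed in CoverageScorer.score above, the following snippet reproduces the same arithmetic on a made-up coverage vector with threshold 0.5; the numbers are illustrative only.

import torch

threshold = 0.5
# Made-up cumulative attention over three encoder frames for one beam.
coverage = torch.tensor([[0.2, 0.9, 1.3]])
penalty = torch.max(coverage, coverage.clone().fill_(threshold)).sum(-1)
penalty = penalty - coverage.size(-1) * threshold
# max(0.2, 0.5) + max(0.9, 0.5) + max(1.3, 0.5) = 2.7, minus 3 * 0.5 = 1.2
print(penalty)  # tensor([1.2000]); score() then returns -penalty / time_step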
[docs] class LengthScorer(BaseScorerInterface): """A length rewarding scorer. The LengthScorer is used to provide the length rewarding scores. It is used to prevent the beam search from favoring short hypotheses. Note: length_normalization is not compatible with this scorer. Make sure to set it to False when using LengthScorer. Arguments --------- vocab_size: int The total number of tokens. Example ------- >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.RNNLM import RNNLM >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder >>> from speechbrain.decoders import S2SRNNBeamSearcher, RNNLMScorer, CoverageScorer, ScorerBuilder >>> input_size=17 >>> vocab_size=11 >>> emb = torch.nn.Embedding( ... num_embeddings=vocab_size, ... embedding_dim=input_size ... ) >>> d_model=7 >>> dec = AttentionalRNNDecoder( ... rnn_type="gru", ... attn_type="content", ... hidden_size=3, ... attn_dim=3, ... num_layers=1, ... enc_dim=d_model, ... input_size=input_size, ... ) >>> n_channels=3 >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size) >>> lm_weight = 0.4 >>> length_weight = 1.0 >>> lm_model = RNNLM( ... embedding_dim=d_model, ... output_neurons=vocab_size, ... dropout=0.0, ... rnn_neurons=128, ... dnn_neurons=64, ... return_hidden=True, ... ) >>> rnnlm_scorer = RNNLMScorer( ... language_model=lm_model, ... temperature=1.25, ... ) >>> length_scorer = LengthScorer(vocab_size=vocab_size) >>> scorer = ScorerBuilder( ... full_scorers=[rnnlm_scorer, length_scorer], ... weights={'rnnlm': lm_weight, 'length': length_weight} ... ) >>> beam_size=5 >>> searcher = S2SRNNBeamSearcher( ... embedding=emb, ... decoder=dec, ... linear=seq_lin, ... bos_index=1, ... eos_index=2, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... topk=2, ... using_eos_threshold=False, ... beam_size=beam_size, ... temperature=1.25, ... length_normalization=False, ... scorer=scorer ... ) >>> batch_size=2 >>> enc = torch.rand([batch_size, n_channels, d_model]) >>> wav_len = torch.ones([batch_size]) >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__(self, vocab_size): self.vocab_size = vocab_size
[docs] def score(self, inp_tokens, memory, candidates, attn): """This method scores the new beams based on the Length scorer. Arguments --------- inp_tokens : torch.Tensor The input tensor of the current timestep. memory : No limit The scorer states for this timestep. candidates : torch.Tensor (batch_size x beam_size, scorer_beam_size). The top-k candidates to be scored after the full scorers. If None, scorers will score on full vocabulary set. attn : torch.Tensor The attention weight to be used in CoverageScorer or CTCScorer. """ return ( torch.tensor( [1.0], device=inp_tokens.device, dtype=inp_tokens.dtype ).expand(inp_tokens.size(0), self.vocab_size), None, )
[docs] class ScorerBuilder: """ Builds scorer instance for beamsearch. The ScorerBuilder class is responsible for building a scorer instance for beam search. It takes weights for full and partial scorers, as well as instances of full and partial scorer classes. It combines the scorers based on the weights specified and provides methods for scoring tokens, permuting scorer memory, and resetting scorer memory. This is the class to be used for building scorer instances for beam search. See speechbrain.decoders.seq2seq.S2SBeamSearcher() Arguments --------- weights : dict Weights of full/partial scorers specified. full_scorers : list Scorers that score on full vocabulary set. partial_scorers : list Scorers that score on pruned tokens to prevent computation overhead. Partial scoring is performed after full scorers. scorer_beam_scale : float The scale decides the number of pruned tokens for partial scorers: int(beam_size * scorer_beam_scale). Example ------- >>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM >>> from speechbrain.decoders import S2STransformerBeamSearcher, TransformerLMScorer, CoverageScorer, CTCScorer, ScorerBuilder >>> input_size=17 >>> vocab_size=11 >>> d_model=128 >>> net = TransformerASR( ... tgt_vocab=vocab_size, ... input_size=input_size, ... d_model=d_model, ... nhead=8, ... num_encoder_layers=1, ... num_decoder_layers=1, ... d_ffn=256, ... activation=torch.nn.GELU ... ) >>> lm_model = TransformerLM( ... vocab=vocab_size, ... d_model=d_model, ... nhead=8, ... num_encoder_layers=1, ... num_decoder_layers=0, ... d_ffn=256, ... activation=torch.nn.GELU, ... ) >>> n_channels=6 >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size) >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size) >>> eos_index = 2 >>> ctc_scorer = CTCScorer( ... ctc_fc=ctc_lin, ... blank_index=0, ... eos_index=eos_index, ... ) >>> transformerlm_scorer = TransformerLMScorer( ... language_model=lm_model, ... temperature=1.15, ... ) >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size) >>> ctc_weight_decode=0.4 >>> lm_weight=0.6 >>> coverage_penalty = 1.0 >>> scorer = ScorerBuilder( ... full_scorers=[transformerlm_scorer, coverage_scorer], ... partial_scorers=[ctc_scorer], ... weights={'transformerlm': lm_weight, 'ctc': ctc_weight_decode, 'coverage': coverage_penalty} ... ) >>> beam_size=5 >>> searcher = S2STransformerBeamSearcher( ... modules=[net, seq_lin], ... bos_index=1, ... eos_index=eos_index, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... using_eos_threshold=False, ... beam_size=beam_size, ... topk=3, ... temperature=1.15, ... scorer=scorer ... ) >>> batch_size=2 >>> wav_len = torch.ones([batch_size]) >>> src = torch.rand([batch_size, n_channels, input_size]) >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels]) >>> enc, dec = net.forward(src, tgt) >>> hyps, _, _, _ = searcher(enc, wav_len) """ def __init__( self, weights=dict(), full_scorers=list(), partial_scorers=list(), scorer_beam_scale=2, ): assert len(weights) == len(full_scorers) + len( partial_scorers ), "Weights and scorers are not matched." 
self.scorer_beam_scale = scorer_beam_scale all_scorer_names = [ k.lower().split("scorer")[0] for k in globals().keys() if k.endswith("Scorer") ] full_scorer_names = [ impl.__class__.__name__.lower().split("scorer")[0] for impl in full_scorers ] partial_scorer_names = [ impl.__class__.__name__.lower().split("scorer")[0] for impl in partial_scorers ] # Have a default 0.0 weight for scorer not specified init_weights = {k: 0.0 for k in all_scorer_names} self.weights = {**init_weights, **weights} self.full_scorers = dict(zip(full_scorer_names, full_scorers)) self.partial_scorers = dict(zip(partial_scorer_names, partial_scorers)) # Check if scorers are valid self._validate_scorer(all_scorer_names)
[docs] def score(self, inp_tokens, memory, attn, log_probs, beam_size): """This method scores tokens in vocabulary based on defined full scorers and partial scorers. Scores will be added to the log probs for beamsearch. Arguments --------- inp_tokens : torch.Tensor See BaseScorerInterface(). memory : dict[str, scorer memory] The states of scorers for this timestep. attn : torch.Tensor See BaseScorerInterface(). log_probs : torch.Tensor (batch_size x beam_size, vocab_size). The log probs at this timestep. beam_size : int The beam size. Returns --------- log_probs : torch.Tensor (batch_size x beam_size, vocab_size). Log probs updated by scorers. new_memory : dict[str, scorer memory] The updated states of scorers. """ new_memory = dict() # score full candidates for k, impl in self.full_scorers.items(): if k == "ctc": # block blank token if CTC is used log_probs[:, impl.blank_index] = impl.ctc_score.minus_inf score, new_memory[k] = impl.score(inp_tokens, memory[k], None, attn) log_probs += score * self.weights[k] # select candidates from the results of full scorers for partial scorers _, candidates = log_probs.topk( int(beam_size * self.scorer_beam_scale), dim=-1 ) # score pruned tokens candidates for k, impl in self.partial_scorers.items(): score, new_memory[k] = impl.score( inp_tokens, memory[k], candidates, attn ) log_probs += score * self.weights[k] return log_probs, new_memory
[docs] def permute_scorer_mem(self, memory, index, candidates): """Update memory variables of scorers to synchronize the memory index with the current output and perform batched beam search. Arguments --------- memory : dict[str, scorer memory] The states of scorers for this timestep. index : torch.Tensor (batch_size x beam_size). The index of the previous path. candidates : torch.Tensor (batch_size, beam_size). The index of the topk candidates. """ for k, impl in self.full_scorers.items(): # ctc scorer should always be scored by candidates if k == "ctc" or k == "kenlm": memory[k] = impl.permute_mem(memory[k], candidates) continue memory[k] = impl.permute_mem(memory[k], index) for k, impl in self.partial_scorers.items(): memory[k] = impl.permute_mem(memory[k], candidates) return memory
[docs] def reset_scorer_mem(self, x, enc_lens): """Reset memory variables for scorers. Arguments --------- x : torch.Tensor See BaseScorerInterface(). enc_lens : torch.Tensor See BaseScorerInterface(). """ memory = dict() for k, impl in {**self.full_scorers, **self.partial_scorers}.items(): memory[k] = impl.reset_mem(x, enc_lens) return memory
def _validate_scorer(self, scorer_names): """These error messages indicate scorers are not properly set. Arguments --------- scorer_names : list Prefix of scorers defined in speechbrain.decoders.scorer. """ if len(self.weights) > len(scorer_names): raise ValueError( "The keys of weights should be named in {}".format(scorer_names) ) if not 0.0 <= self.weights["ctc"] <= 1.0: raise ValueError("ctc_weight should be between 0.0 and 1.0") if self.weights["ctc"] == 1.0: if "ctc" not in self.full_scorers.keys(): raise ValueError( "CTC scorer should be a full scorer when its weight is 1.0" ) if self.weights["coverage"] > 0.0: raise ValueError( "Pure CTC scorer doesn't have attention weights for coverage scorer" )
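The weighting and pruning performed by ScorerBuilder.score can also be illustrated in isolation. The tensors and the 0.6 weight below are made up; only the int(beam_size * scorer_beam_scale) pruning rule comes from the code above.

import torch

batch_size, beam_size, vocab_size, scorer_beam_scale = 2, 5, 100, 2
n_bh = batch_size * beam_size

# Random stand-ins for the decoder log probs and one full scorer's output.
log_probs = torch.randn(n_bh, vocab_size).log_softmax(dim=-1)
full_scorer_out = torch.randn(n_bh, vocab_size).log_softmax(dim=-1)
log_probs = log_probs + 0.6 * full_scorer_out  # weighted, as with self.weights[k]

# Partial scorers only receive these pruned candidates.
_, candidates = log_probs.topk(int(beam_size * scorer_beam_scale), dim=-1)
print(candidates.shape)  # torch.Size([10, 10])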
[docs] class BaseRescorerInterface(BaseScorerInterface): """A scorer abstraction intended for inheritance by other scoring approaches used in beam search. In this approach, a neural network is employed to assign scores to potential text transcripts. The beam search decoding process produces a collection of the top K hypotheses. These candidates are subsequently sent to a language model (LM) for ranking. The ranking is carried out by the LM, which assigns a score to each candidate. The score is computed as follows: score = beam_search_score + lm_weight * rescorer_score See: - speechbrain.decoders.scorer.RNNLMRescorer - speechbrain.decoders.scorer.TransformerLMRescorer - speechbrain.decoders.scorer.HuggingFaceLMRescorer """
[docs] def normalize_text(self, text): """This method should implement the normalization of the text before scoring. Arguments --------- text : list of str The text to be normalized. """ return text
[docs] def preprocess_func(self, hyps): """This method should implement the preprocessing of the hypotheses before scoring. Arguments --------- hyps : list of str The hypotheses to be preprocessed. """ raise NotImplementedError
[docs] def rescore_hyps(self, hyps): """This method should implement the rescoring of the hypotheses. Arguments --------- hyps : list of str The hypotheses to be rescored. """ raise NotImplementedError
[docs] def to_device(self, device=None): """This method should implement the moving of the scorer to a device. If device is None, the scorer should be moved to the default device provided in the constructor. Arguments --------- device : str The device to move the scorer to. """ raise NotImplementedError
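For reference, below is a minimal sketch of the methods a rescorer is expected to expose. It is not part of SpeechBrain: the class name CharLengthRescorer and its length-based scoring are purely illustrative, and it is not wired into RescorerBuilder.

import torch
from speechbrain.decoders.scorer import BaseRescorerInterface

class CharLengthRescorer(BaseRescorerInterface):
    """Hypothetical rescorer that simply favors shorter hypotheses."""

    def preprocess_func(self, hyps):
        # Flatten the (batch, topk) nested lists of strings.
        return [text for batch in hyps for text in batch]

    def rescore_hyps(self, hyps):
        texts = self.preprocess_func(hyps)
        # One log-domain score per hypothesis: shorter strings score higher.
        return torch.tensor([-float(len(t)) for t in texts])

    def to_device(self, device=None):
        pass  # no parameters to move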
[docs] class RNNLMRescorer(BaseRescorerInterface): """A wrapper of RNNLM based on the BaseRescorerInterface. Arguments --------- language_model : torch.nn.Module An RNN-based language model. tokenizer : SentencePieceProcessor A SentencePiece tokenizer. device : str The device to move the scorer to. temperature : float Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1. (default: 1.0) bos_index : int The index of the beginning-of-sequence (bos) token. eos_index : int The index of the end-of-sequence (eos) token. pad_index : int The index of the padding token. NOTE ---- This class is intended to be used with a pretrained RNNLM model. Please see: https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech By default, this model uses a SentencePiece tokenizer. Example ------- >>> import torch >>> from sentencepiece import SentencePieceProcessor >>> from speechbrain.lobes.models.RNNLM import RNNLM >>> from speechbrain.utils.parameter_transfer import Pretrainer >>> source = "speechbrain/asr-crdnn-rnnlm-librispeech" >>> lm_model_path = source + "/lm.ckpt" >>> tokenizer_path = source + "/tokenizer.ckpt" >>> # define your tokenizer and RNNLM from the HF hub >>> tokenizer = SentencePieceProcessor() >>> lm_model = RNNLM( ... output_neurons = 1000, ... embedding_dim = 128, ... activation = torch.nn.LeakyReLU, ... dropout = 0.0, ... rnn_layers = 2, ... rnn_neurons = 2048, ... dnn_blocks = 1, ... dnn_neurons = 512, ... return_hidden = True, ... ) >>> pretrainer = Pretrainer( ... collect_in = getfixture("tmp_path"), ... loadables = { ... "lm" : lm_model, ... "tokenizer" : tokenizer, ... }, ... paths = { ... "lm" : lm_model_path, ... "tokenizer" : tokenizer_path, ... }) >>> _ = pretrainer.collect_files() >>> pretrainer.load_collected() >>> from speechbrain.decoders.scorer import RNNLMRescorer, RescorerBuilder >>> rnnlm_rescorer = RNNLMRescorer( ... language_model = lm_model, ... tokenizer = tokenizer, ... temperature = 1.0, ... bos_index = 0, ... eos_index = 0, ... pad_index = 0, ... ) >>> # Define a rescorer builder >>> rescorer = RescorerBuilder( ... rescorers=[rnnlm_rescorer], ... weights={"rnnlm":1.0} ... ) >>> # topk hyps >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]] >>> topk_scores = [[-2, -2, -2]] >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores) >>> # NOTE: the returned hypotheses are already sorted by score. >>> rescored_hyps # doctest: +SKIP [['HELLO', 'H E L L O', 'HE LLO']] >>> # NOTE: as we are returning log-probs, the closer the score is to 0, the better. >>> rescored_scores # doctest: +SKIP [[-17.863974571228027, -25.12890625, -26.075977325439453]] """ def __init__( self, language_model, tokenizer, device="cuda", temperature=1.0, bos_index=0, eos_index=0, pad_index=0, ): self.lm = language_model self.lm.eval() self.tokenizer = tokenizer self.temperature = temperature self.softmax = sb.nnet.activations.Softmax(apply_log=True) self.device = device self.bos_index = bos_index self.eos_index = eos_index self.pad_index = pad_index
[docs] def normalize_text(self, text): """This method should implement the normalization of the text before scoring. Defaults to uppercasing the text because the (current) language models are trained on LibriSpeech, which is all uppercase. Arguments --------- text : str The text to be normalized. Returns ------- str The normalized text. """ return text.upper()
[docs] def to_device(self, device=None): """This method moves the scorer to a device. If device is None, the scorer is moved to the default device provided in the constructor. Arguments --------- device : str The device to move the scorer to. """ if device is None: self.lm.to(self.device) else: self.lm.to(device)
[docs] def preprocess_func(self, topk_hyps): """This method preprocesses the hypotheses before scoring. Arguments --------- topk_hyps : list of list of str The hypotheses to be preprocessed. Returns ------- padded_hyps : torch.Tensor The padded hypotheses. enc_hyps_length : list of int The length of each hypothesis. """ # 1. normalize text decoded_seq = [] for batch in topk_hyps: for seq in batch: decoded_seq.append(self.normalize_text(seq)) # 2. encode text enc_hyps = [] for seq in decoded_seq: enc_hyps.append( torch.tensor( [self.bos_index] + self.tokenizer.encode_as_ids(seq) + [self.eos_index] ) ) enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps] # 3. pad sequences padded_hyps = torch.nn.utils.rnn.pad_sequence( enc_hyps, batch_first=True, padding_value=self.pad_index ).to(self.lm.parameters().__next__().device) return padded_hyps, enc_hyps_length
[docs] @torch.no_grad() def rescore_hyps(self, topk_hyps): """This method implements the rescoring of the hypotheses. Arguments --------- topk_hyps : list of list of str The hypotheses to be rescored. Returns ------- log_probs_scores : torch.Tensor[B * Topk, 1] The rescored hypotheses scores. """ # preprocess hypotheses padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps) bool_mask = [ [1 if i < length else 0 for i in range(max(enc_hyps_length))] for length in enc_hyps_length ] bool_mask_tensor = torch.tensor( bool_mask, dtype=torch.bool, device=padded_hyps.device ) if not next(self.lm.parameters()).is_cuda: self.lm.to(padded_hyps.device) # compute scores logits, _ = self.lm(padded_hyps) log_probs = self.softmax(logits / self.temperature) target_log_probs = ( log_probs[:, :-1] .gather(2, padded_hyps[:, 1:].unsqueeze(2)) .squeeze(2) ) log_probs_scores = torch.nansum( target_log_probs * bool_mask_tensor[:, 1:], dim=-1 ) return log_probs_scores
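The gather-and-mask pattern used in rescore_hyps above can be seen in isolation with made-up tensors: pick the log probability of each target token (shifted by one step) and sum only over the non-padded positions.

import torch

# (batch, time, vocab) log probs and padded token ids (0 is padding here).
log_probs = torch.log_softmax(torch.randn(2, 4, 6), dim=-1)
tokens = torch.tensor([[1, 3, 2, 0], [4, 2, 0, 0]])
mask = tokens.ne(0)

# log_probs[:, t] predicts tokens[:, t + 1], hence the one-step shift.
target_lp = log_probs[:, :-1].gather(2, tokens[:, 1:].unsqueeze(2)).squeeze(2)
sentence_scores = (target_lp * mask[:, 1:]).sum(dim=-1)
print(sentence_scores.shape)  # torch.Size([2])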
[docs] class TransformerLMRescorer(BaseRescorerInterface): """ A wrapper of TransformerLM based on the BaseRescorerInterface. Arguments --------- language_model : torch.nn.Module A Transformer-based language model. tokenizer : SentencePieceProcessor A SentencePiece tokenizer. device : str The device to move the scorer to. temperature : float Temperature factor applied to softmax. It changes the probability distribution, being softer when T>1 and sharper with T<1. (default: 1.0) bos_index : int The index of the beginning-of-sequence (bos) token. eos_index : int The index of the end-of-sequence (eos) token. pad_index : int The index of the padding token. NOTE ---- This class is intented to be used with a pretrained TransformerLM model. Please see: https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech By default, this model is using SentencePiece tokenizer. Example ------- >>> import torch >>> from sentencepiece import SentencePieceProcessor >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM >>> from speechbrain.utils.parameter_transfer import Pretrainer >>> source = "speechbrain/asr-transformer-transformerlm-librispeech" >>> lm_model_path = source + "/lm.ckpt" >>> tokenizer_path = source + "/tokenizer.ckpt" >>> tokenizer = SentencePieceProcessor() >>> lm_model = TransformerLM( ... vocab=5000, ... d_model=768, ... nhead=12, ... num_encoder_layers=12, ... num_decoder_layers=0, ... d_ffn=3072, ... dropout=0.0, ... activation=torch.nn.GELU, ... normalize_before=False, ... ) >>> pretrainer = Pretrainer( ... collect_in = getfixture("tmp_path"), ... loadables={ ... "lm": lm_model, ... "tokenizer": tokenizer, ... }, ... paths={ ... "lm": lm_model_path, ... "tokenizer": tokenizer_path, ... } ... ) >>> _ = pretrainer.collect_files() >>> pretrainer.load_collected() >>> from speechbrain.decoders.scorer import TransformerLMRescorer, RescorerBuilder >>> transformerlm_rescorer = TransformerLMRescorer( ... language_model=lm_model, ... tokenizer=tokenizer, ... temperature=1.0, ... bos_index=1, ... eos_index=2, ... pad_index=0, ... ) >>> rescorer = RescorerBuilder( ... rescorers=[transformerlm_rescorer], ... weights={"transformerlm": 1.0} ... ) >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]] >>> topk_scores = [[-2, -2, -2]] >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores) >>> # NOTE: the returned hypotheses are already sorted by score. >>> rescored_hyps # doctest: +SKIP [["HELLO", "HE L L O", "HE LLO"]] >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better. >>> rescored_scores # doctest: +SKIP [[-17.863974571228027, -25.12890625, -26.075977325439453]] """ def __init__( self, language_model, tokenizer, device="cuda", temperature=1.0, bos_index=0, eos_index=0, pad_index=0, ): self.lm = language_model self.lm.eval() self.tokenizer = tokenizer self.temperature = temperature self.softmax = sb.nnet.activations.Softmax(apply_log=True) self.device = device self.bos_index = bos_index self.eos_index = eos_index self.pad_index = pad_index
[docs] def normalize_text(self, text): """This method should implement the normalization of the text before scoring. Defaults to uppercasing the text because the language models are trained on LibriSpeech. Arguments --------- text : str The text to be normalized. Returns ------- str The normalized text. """ return text.upper()
[docs] def to_device(self, device=None): """This method moves the scorer to a device. If device is None, the scorer is moved to the default device provided in the constructor. This method is dynamically called in the recipes when the stage is equal to TEST. Arguments --------- device : str The device to move the scorer to. """ if device is None: self.lm.to(self.device) else: self.lm.to(device)
[docs] def preprocess_func(self, topk_hyps): """This method preprocesses the hypotheses before scoring. Arguments --------- topk_hyps : list of list of str The hypotheses to be preprocessed. Returns ------- padded_hyps : torch.Tensor The padded hypotheses. enc_hyps_length : list of int The length of each hypothesis. """ # 1. normalize decoded_seq = [] for batch in topk_hyps: for seq in batch: decoded_seq.append(self.normalize_text(seq)) # 2. encode text enc_hyps = [] for seq in decoded_seq: enc_hyps.append( torch.tensor( [self.bos_index] + self.tokenizer.encode_as_ids(seq) + [self.eos_index] ) ) enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps] # 3. pad sequences padded_hyps = torch.nn.utils.rnn.pad_sequence( enc_hyps, batch_first=True, padding_value=self.pad_index ).to(self.lm.parameters().__next__().device) return padded_hyps, enc_hyps_length
[docs] @torch.no_grad() def rescore_hyps(self, topk_hyps): """This method implements the rescoring of the hypotheses. Arguments --------- topk_hyps : list of list of str The hypotheses to be rescored. Returns ------- log_probs_scores : torch.Tensor[B * Topk, 1] The rescored hypotheses scores. """ # preprocess hypotheses padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps) bool_mask = [ [1 if i < length else 0 for i in range(max(enc_hyps_length))] for length in enc_hyps_length ] bool_mask_tensor = torch.tensor( bool_mask, dtype=torch.bool, device=padded_hyps.device ) if not next(self.lm.parameters()).is_cuda: self.lm.to(padded_hyps.device) # compute scores logits = self.lm(padded_hyps) log_probs = self.softmax(logits / self.temperature) log_probs[:, :, self.pad_index] = float("-inf") target_log_probs = ( log_probs[:, :-1] .gather(2, padded_hyps[:, 1:].unsqueeze(2)) .squeeze(2) ) target_log_probs = target_log_probs - log_probs[:, :-1].logsumexp( dim=-1 ) log_probs_scores = torch.nansum( target_log_probs * bool_mask_tensor[:, 1:], dim=-1 ) return log_probs_scores
[docs] class HuggingFaceLMRescorer(BaseRescorerInterface): """ A wrapper of HuggingFace's TransformerLM based on the BaseRescorerInterface. Arguments --------- model_name : str The name of the model to be loaded. device : str The device to be used for scoring. (default: "cuda") Example ------- >>> from speechbrain.decoders.scorer import HuggingFaceLMRescorer, RescorerBuilder >>> source = "gpt2-medium" >>> huggingfacelm_rescorer = HuggingFaceLMRescorer( ... model_name=source, ... ) >>> rescorer = RescorerBuilder( ... rescorers=[huggingfacelm_rescorer], ... weights={"huggingfacelm": 1.0} ... ) >>> topk_hyps = [["Hello everyone.", "Hell o every one.", "Hello every one"]] >>> topk_scores = [[-2, -2, -2]] >>> rescored_hyps, rescored_scores = rescorer.rescore(topk_hyps, topk_scores) >>> # NOTE: the returned hypotheses are already sorted by score. >>> rescored_hyps # doctest: +SKIP [['Hello everyone.', 'Hello every one', 'Hell o every one.']] >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better. >>> rescored_scores # doctest: +SKIP [[-20.03631591796875, -27.615638732910156, -42.662353515625]] """ def __init__( self, model_name, device="cuda", ): self.model_name = model_name self.device = device try: from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: raise ImportError( "Please install transformers with: pip install transformers" ) self.lm = AutoModelForCausalLM.from_pretrained( self.model_name, is_decoder=True ).eval() self.tokenizer = AutoTokenizer.from_pretrained( self.model_name, use_fast=True, add_special_tokens=False ) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = "<|pad|>" self.tokenizer.add_special_tokens( {"additional_special_tokens": [self.tokenizer.pad_token]} ) self.lm.resize_token_embeddings( len(self.tokenizer), pad_to_multiple_of=32 ) self.bos_token = self.tokenizer.bos_token self.eos_token = self.tokenizer.eos_token
[docs] def to_device(self, device=None): """This method moves the scorer to a device. If device is None, the scorer is moved to the default device provided in the constructor. This method is dynamically called in the recipes when the stage is equal to TEST. Arguments --------- device : str The device to move the scorer to. """ if device is None: self.lm.to(self.device) else: self.lm.to(device)
[docs] def normalize_text(self, text): """This method should implement the normalization of the text before scoring. Arguments --------- text : str The text to be normalized. Returns ------- normalized_text : str The normalized text. In this case we do not apply any normalization. However, this method can be overridden to apply any normalization. """ return text
def _add_special_tokens(self, text): """This method adds the special tokens to the text. Arguments --------- text : str The text to be augmented. Returns ------- augmented_text : str The augmented text. """ return self.bos_token + text + self.eos_token
[docs] def preprocess_func(self, topk_hyps): """This method preprocesses the hypotheses before scoring. Arguments --------- topk_hyps : list of str The hypotheses to be preprocessed. Returns ------- encoding : tensor The encoding of the hypotheses. """ # 1. normalize normalized_hyps = [] for batch in topk_hyps: for seq in batch: normalized_hyps.append(self.normalize_text(seq)) text_augmented_with_tokens = list( map(self._add_special_tokens, normalized_hyps) ) encoding = self.tokenizer.batch_encode_plus( text_augmented_with_tokens, return_tensors="pt", padding=True ) return encoding
[docs] @torch.no_grad() def rescore_hyps(self, topk_hyps): """This method implements the rescoring of the hypotheses. Arguments --------- topk_hyps : list of list of str The hypotheses to be rescored. Returns ------- log_probs_scores : torch.Tensor[B * Topk, 1] The rescored hypotheses scores. """ encoding = self.preprocess_func(topk_hyps) ids = encoding["input_ids"].to(self.lm.device) attention_mask = encoding["attention_mask"].to(self.lm.device) logits = self.lm(ids, attention_mask=attention_mask)[0] logits[:, :, self.tokenizer.pad_token_id :] = float("-inf") target_log_probs = ( logits[:, :-1].gather(2, ids[:, 1:].unsqueeze(2)).squeeze(2) ) target_log_probs = target_log_probs - logits[:, :-1].logsumexp(dim=-1) log_probs_scores = torch.nansum( target_log_probs * attention_mask[:, 1:], dim=-1 ) return log_probs_scores
[docs] class RescorerBuilder: """ Builds rescorer instance for beamsearch. The RescorerBuilder class is responsible for building a rescorer instance for beam search. It takes weights and rescorer classes. It combines the rescorers based on the weights specified and provides methods for rescoring text. This is the class to be used for building rescorer instances for beam search. Arguments --------- weights : dict Weights of rescorers specified. rescorers : list Rescorers that re-rank the topk hypotheses. """ def __init__( self, weights=dict(), rescorers=list(), ): assert len(weights) == len( rescorers ), "Weights and rescorers are not matched." self.weights = weights all_rescorer_names = [ k.lower().split("rescorer")[0] for k in globals().keys() if k.endswith("Rescorer") ] full_rescorer_names = [ impl.__class__.__name__.lower().split("rescorer")[0] for impl in rescorers ] # Have a default 0.0 weight for scorer not specified init_weights = {k: 0.0 for k in all_rescorer_names} self.weights = {**init_weights, **weights} self.rescorers = dict(zip(full_rescorer_names, rescorers)) self._validate_scorer(all_rescorer_names)
[docs] def rescore(self, topk_candidates, topk_scores): """This method rescores the topk candidates. Arguments --------- topk_candidates : list of list of str The topk candidates to be rescored. topk_scores : list of list of float The scores of the topk candidates. Returns ------- output_candidates : list of list of str The rescored candidates. output_scores : list of list of float The rescored scores. """ new_scores = topk_scores.copy() for k, impl in self.rescorers.items(): scores = impl.rescore_hyps(topk_candidates) index_scores = 0 for i in range(len(new_scores)): for j in range(len(new_scores[i])): new_scores[i][j] += ( self.weights[k] * scores[index_scores].item() ) index_scores += 1 sorted_candidates = [ list( zip( *sorted( zip(sublist, score), key=lambda x: x[1], reverse=True ) ) ) for sublist, score in zip(topk_candidates, new_scores) ] output_candidates = [] output_scores = [] for sublist in sorted_candidates: for item in sublist: texts, scores = item output_candidates.append(list(texts)) output_scores.append(list(scores)) return output_candidates, output_scores
def _validate_scorer(self, rescorer_names): """These error messages indicate rescorers are not properly set. Arguments --------- rescorer_names : list Prefix of rescorers defined in speechbrain.decoders.scorer. """ if len(self.weights) > len(rescorer_names): raise ValueError( "The keys of weights should be named in {}".format( rescorer_names ) )
[docs] def move_rescorers_to_device(self, device=None): """Moves rescorers to device. Useful to avoid keeping the rescorers on the GPU during the TRAIN and VALID stages. Arguments --------- device : str The device to be used for scoring. (default: None) """ for _, impl in self.rescorers.items(): impl.to_device(device)
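Finally, the combination rule applied in rescore (new_score = beam_search_score + weight * rescorer_score, as described in BaseRescorerInterface) can be checked with plain numbers; the values below are made up.

beam_scores = [-2.0, -2.0]
rescorer_scores = [-25.0, -18.0]
weight = 1.0

combined = [b + weight * r for b, r in zip(beam_scores, rescorer_scores)]
print(combined)  # [-27.0, -20.0] -> the second hypothesis now ranks first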