speechbrain.alignment.aligner module

Alignment code

Authors
  • Elena Rastorgueva 2020

  • Loren Lugosch 2020

Summary

Classes:

HMMAligner

This class calculates Viterbi alignments in the forward method.

Functions:

batch_log_matvecmul

For each ‘matrix’ and ‘vector’ pair in the batch, do matrix-vector multiplication in the log domain, i.e., logsumexp instead of add, add instead of multiply.

batch_log_maxvecmul

Similar to batch_log_matvecmul, but takes a maximum instead of logsumexp.

map_inds_to_intersect

Converts 2 lists containing indices for phonemes from different phoneme sets to a single phoneme set, so that comparing the equality of the indices of the resulting lists will yield the correct accuracy.

Reference

class speechbrain.alignment.aligner.HMMAligner(states_per_phoneme=1, output_folder='', neg_inf=-100000.0, batch_reduction='none', input_len_norm=False, target_len_norm=False, lexicon_path=None)

Bases: torch.nn.modules.module.Module

This class calculates Viterbi alignments in the forward method.

It also records alignments and creates batches of them for use in Viterbi training.

Parameters
  • states_per_phoneme (int) – Number of hidden states to use per phoneme.

  • output_folder (str) – The folder in which the alignments will be stored when saved to disk. Not yet implemented.

  • neg_inf (float) – The float used to represent a negative infinite log probability. Using -float(“Inf”) tends to cause numerical instability. Values more negative than -1e5 also sometimes caused errors when the genbmm library was used (it is currently not in use). (default: -1e5)

  • batch_reduction (string) – One of “none”, “sum” or “mean”: the kind of batch-level reduction to apply to the loss calculated in the forward method.

  • input_len_norm (bool) – Whether to normalize the loss in the forward method by the length of the inputs.

  • target_len_norm (bool) – Whether to normalize the loss in the forward method by the length of the targets.

  • lexicon_path (string) – The location of the lexicon.
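To illustrate why a finite value such as -1e5 can be safer than -float("Inf"), here is a small pure-Python sketch (not the library's code) of a max-shifted log-sum-exp, a common building block of log-domain forward computations. When every entry is -inf, the shift computes (-inf) - (-inf), which is nan; with -1e5 everything stays finite. This is one plausible mechanism for the instability mentioned above, not necessarily the one the authors encountered.

```python
import math

def naive_logsumexp(xs):
    """Max-shifted log-sum-exp (illustrative sketch, not HMMAligner's code)."""
    m = max(xs)
    # If every entry is -inf, m is -inf and x - m is (-inf) - (-inf) = nan,
    # which then propagates through exp, sum and log.
    return m + math.log(sum(math.exp(x - m) for x in xs))

inf_result = naive_logsumexp([-float("inf"), -float("inf")])
finite_result = naive_logsumexp([-1e5, -1e5])
print(inf_result)     # nan
print(finite_result)  # about -99999.307, i.e. -1e5 + log(2)
```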

Example

>>> log_posteriors = torch.tensor([[[ -1., -10., -10.],
...                                 [-10.,  -1., -10.],
...                                 [-10., -10.,  -1.]],
...
...                                [[ -1., -10., -10.],
...                                 [-10.,  -1., -10.],
...                                 [-10., -10., -10.]]])
>>> lens = torch.tensor([1., 0.66])
>>> phns = torch.tensor([[0, 1, 2],
...                      [0, 1, 0]])
>>> phn_lens = torch.tensor([1., 0.66])
>>> aligner = HMMAligner()
>>> forward_scores = aligner(
...        log_posteriors, lens, phns, phn_lens, 'forward'
... )
>>> forward_scores.shape
torch.Size([2])
>>> viterbi_scores, alignments = aligner(
...        log_posteriors, lens, phns, phn_lens, 'viterbi'
... )
>>> alignments
[[0, 1, 2], [0, 1]]
>>> viterbi_scores.shape
torch.Size([2])
use_lexicon(words, interword_sils=True, sample_pron=False)

Uses the lexicon to return the sequence of possible phonemes, the transition and initial (pi) probabilities, and the possible final states. Processing is done utterance by utterance: each utterance in the batch is handled by the helper method _use_lexicon.

Parameters
  • words (list) – List (one entry per utterance in the batch) of lists of the words in each transcript.

  • interword_sils (bool) – If True, optional silences will be inserted between every word. If False, optional silences will only be placed at the beginning and end of each utterance.

  • sample_pron (bool) – If True, it will sample a single possible sequence of phonemes. If False, it will return statistics for all possible sequences of phonemes.

Returns

  • poss_phns (torch.Tensor (batch, phoneme in possible phn sequence)) – The phonemes that are thought to be in each utterance.

  • poss_phn_lens (torch.Tensor (batch)) – The relative length of each possible phoneme sequence in the batch.

  • trans_prob (torch.Tensor (batch, from, to)) – Tensor containing transition (log) probabilities.

  • pi_prob (torch.Tensor (batch, state)) – Tensor containing initial (log) probabilities.

  • final_state (list of lists of ints) – A list of lists of possible final states for each utterance.

Example

>>> aligner = HMMAligner()
>>> aligner.lexicon = {
...                     "a": {0: "a"},
...                     "b": {0: "b", 1: "c"}
...                   }
>>> words = [["a", "b"]]
>>> aligner.lex_lab2ind = {
...                   "sil": 0,
...                   "a":  1,
...                   "b":  2,
...                   "c":  3,
...                 }
>>> poss_phns, poss_phn_lens, trans_prob, pi_prob, final_states = aligner.use_lexicon(
...     words,
...     interword_sils = True
... )
>>> poss_phns
tensor([[0, 1, 0, 2, 3, 0]])
>>> poss_phn_lens
tensor([1.])
>>> trans_prob
tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05,
          -1.0000e+05],
         [-1.0000e+05, -1.3863e+00, -1.3863e+00, -1.3863e+00, -1.3863e+00,
          -1.0000e+05],
         [-1.0000e+05, -1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00,
          -1.0000e+05],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -1.0000e+05,
          -6.9315e-01],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01,
          -6.9315e-01],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05,
           0.0000e+00]]])
>>> pi_prob
tensor([[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05,
         -1.0000e+05]])
>>> final_states
[[3, 4, 5]]
>>> # With no optional silences between words
>>> poss_phns_, _, trans_prob_, pi_prob_, final_states_ = aligner.use_lexicon(
...     words,
...     interword_sils = False
... )
>>> poss_phns_
tensor([[0, 1, 2, 3, 0]])
>>> trans_prob_
tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05],
         [-1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00, -1.0000e+05],
         [-1.0000e+05, -1.0000e+05, -6.9315e-01, -1.0000e+05, -6.9315e-01],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05,  0.0000e+00]]])
>>> pi_prob_
tensor([[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05]])
>>> final_states_
[[2, 3, 4]]
>>> # With sampling of a single possible pronunciation
>>> import random
>>> random.seed(0)
>>> poss_phns_, _, trans_prob_, pi_prob_, final_states_ = aligner.use_lexicon(
...     words,
...     sample_pron = True
... )
>>> poss_phns_
tensor([[0, 1, 0, 2, 0]])
>>> trans_prob_
tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05],
         [-1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00, -1.0000e+05],
         [-1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01, -1.0000e+05],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01],
         [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05,  0.0000e+00]]])
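The transition matrices in the examples above follow a simple pattern: each row is uniform (in log space) over the states that may follow, and every other entry is neg_inf. The sketch below rebuilds the interword_sils=False matrices from hand-written successor lists. The helper uniform_transitions and the successor lists are assumptions for illustration; the actual _use_lexicon helper may construct these matrices differently.

```python
import math
import torch

def uniform_transitions(successors, initial_states, neg_inf=-1e5):
    """Hypothetical helper: log transition/initial probabilities that are
    uniform over the allowed moves.

    successors[i] lists the states reachable from state i (including i itself).
    """
    n = len(successors)
    trans = torch.full((n, n), neg_inf)
    for i, nexts in enumerate(successors):
        trans[i, nexts] = -math.log(len(nexts))
    pi = torch.full((n,), neg_inf)
    pi[initial_states] = -math.log(len(initial_states))
    return trans, pi

# States for interword_sils=False: [sil, a, b, c, sil], where "b" and "c"
# are the alternative pronunciations of the second word.
successors = [[0, 1], [1, 2, 3], [2, 4], [3, 4], [4]]
trans, pi = uniform_transitions(successors, initial_states=[0, 1])
print(trans)
print(pi)
```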
forward(emission_pred, lens, phns, phn_lens, dp_algorithm, prob_matrices=None)

Prepares relevant (log) probability tensors and does dynamic programming: either the forward or the Viterbi algorithm. Applies reduction as specified during object initialization.
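As a rough illustration of how the batch_reduction, input_len_norm and target_len_norm options could combine, here is a hypothetical reduce_scores helper. It assumes that normalization divides each utterance's score by its absolute length (frames or phonemes) before the batch-level reduction; the order of operations and the exact lengths HMMAligner uses are assumptions.

```python
import torch

def reduce_scores(scores, abs_input_lens, abs_target_lens,
                  batch_reduction="none", input_len_norm=False,
                  target_len_norm=False):
    """Hypothetical sketch: per-utterance normalization, then batch reduction."""
    if input_len_norm:
        scores = scores / abs_input_lens
    if target_len_norm:
        scores = scores / abs_target_lens
    if batch_reduction == "sum":
        return scores.sum()
    if batch_reduction == "mean":
        return scores.mean()
    if batch_reduction == "none":
        return scores
    raise ValueError(f"unknown batch_reduction: {batch_reduction}")

scores = torch.tensor([-6.0, -4.0])   # per-utterance log likelihoods
frames = torch.tensor([3.0, 2.0])     # absolute lengths in frames
# Each normalized score is -2.0, so the batch mean is -2.0.
print(reduce_scores(scores, frames, frames, "mean", input_len_norm=True))
```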

Parameters
  • emission_pred (torch.Tensor (batch, time, phoneme in vocabulary)) – Posterior probabilities from our acoustic model.

  • lens (torch.Tensor (batch)) – The relative duration of each utterance sound file.

  • phns (torch.Tensor (batch, phoneme in phn sequence)) – The phonemes that are known/thought to be in each utterance.

  • phn_lens (torch.Tensor (batch)) – The relative length of each phoneme sequence in the batch.

  • dp_algorithm (string) – Either “forward” or “viterbi”.

  • prob_matrices (dict) – (Optional) Must contain the keys ‘trans_prob’, ‘pi_prob’ and ‘final_states’. Used to override the default forward and Viterbi operations, which force traversal over all of the states in the phns sequence.

Returns

  1. If dp_algorithm == “forward”:

    forward_scores : torch.Tensor (batch, or scalar)

    The (log) likelihood of each utterance in the batch, with reduction applied if specified.

  2. If dp_algorithm == “viterbi”:

    viterbi_scores : torch.Tensor (batch, or scalar)

    The (log) likelihood of the Viterbi path for each utterance, with reduction applied if specified.

    alignments : list of lists of int

    Viterbi alignments for the files in the batch.

Return type

torch.Tensor, or tuple (torch.Tensor, list of lists of int)
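The dynamic programming that forward performs can be pictured with the following single-utterance sketch, which is not the library's batched implementation. It assumes a state-level emission matrix (for example log_posteriors[t, phns[k]] for state k), a transition matrix with trans[i, j] = log P(j | i), an initial vector pi laid out like the use_lexicon outputs, and a list of allowed final states. The forward recursion takes a logsumexp over predecessor states; the Viterbi recursion replaces it with a max and keeps back-pointers so the best path can be traced back.

```python
import torch

def forward_and_viterbi(log_emissions, trans, pi, final_states):
    """Single-utterance forward and Viterbi recursions in the log domain.

    log_emissions: (time, states) log-likelihood of each frame under each state.
    trans: (states, states), trans[i, j] = log P(next = j | current = i).
    pi: (states,) initial log probabilities.
    final_states: states the path is allowed to end in.
    """
    T, _ = log_emissions.shape
    alpha = pi + log_emissions[0]   # forward variables
    delta = pi + log_emissions[0]   # Viterbi variables
    backptrs = []
    for t in range(1, T):
        # alpha_t[j] = logsumexp_i(alpha_{t-1}[i] + trans[i, j]) + e_t[j]
        alpha = torch.logsumexp(alpha.unsqueeze(1) + trans, dim=0) + log_emissions[t]
        # delta_t[j] = max_i(delta_{t-1}[i] + trans[i, j]) + e_t[j]
        best, arg = torch.max(delta.unsqueeze(1) + trans, dim=0)
        delta = best + log_emissions[t]
        backptrs.append(arg)
    finals = torch.tensor(final_states)
    forward_score = torch.logsumexp(alpha[finals], dim=0)
    viterbi_score, idx = torch.max(delta[finals], dim=0)
    # Trace the best path backwards from the best allowed final state.
    state = int(finals[idx])
    path = [state]
    for arg in reversed(backptrs):
        state = int(arg[state])
        path.append(state)
    return forward_score, viterbi_score, path[::-1]

# First utterance of the HMMAligner example: phns = [0, 1, 2], so state k
# emits phoneme k; each state may repeat or advance to the next one.
neg_inf = -1e5
log_emissions = torch.tensor([[ -1., -10., -10.],
                              [-10.,  -1., -10.],
                              [-10., -10.,  -1.]])
trans = torch.tensor([[-0.6931, -0.6931, neg_inf],
                      [neg_inf, -0.6931, -0.6931],
                      [neg_inf, neg_inf, 0.]])
pi = torch.tensor([0., neg_inf, neg_inf])
fwd, vit, path = forward_and_viterbi(log_emissions, trans, pi, final_states=[2])
print(path)  # [0, 1, 2]
```

The recovered path matches the first alignment in the class example; the scores will differ from HMMAligner's because the transition probabilities here are chosen by hand.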

expand_phns_by_states_per_phoneme(phns, phn_lens)

Expands each phoneme in the phn sequence by the number of hidden states per phoneme defined in the HMM.

Parameters
  • phns (torch.Tensor (batch, phoneme in phn sequence)) – The phonemes that are known/thought to be in each utterance.

  • phn_lens (torch.Tensor (batch)) – The relative length of each phoneme sequence in the batch.

Returns

expanded_phns

Return type

torch.Tensor (batch, phoneme in expanded phn sequence)

Example

>>> phns = torch.tensor([[0., 3., 5., 0.],
...                      [0., 2., 0., 0.]])
>>> phn_lens = torch.tensor([1., 0.75])
>>> aligner = HMMAligner(states_per_phoneme = 3)
>>> expanded_phns = aligner.expand_phns_by_states_per_phoneme(
...         phns, phn_lens
... )
>>> expanded_phns
tensor([[ 0.,  1.,  2.,  9., 10., 11., 15., 16., 17.,  0.,  1.,  2.],
        [ 0.,  1.,  2.,  6.,  7.,  8.,  0.,  1.,  2.,  0.,  0.,  0.]])
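The expected output above follows a simple rule: phoneme p becomes the states p * states_per_phoneme + 0, ..., p * states_per_phoneme + (states_per_phoneme - 1), and positions past each sequence's relative length are zeroed. A vectorized sketch of that rule reproduces the example; expand_states is a hypothetical name, and rounding relative lengths to absolute lengths is an assumption.

```python
import torch

def expand_states(phns, phn_lens, states_per_phoneme):
    """Hypothetical sketch: phoneme p -> states [p*k, ..., p*k + k - 1]."""
    batch, max_len = phns.shape
    offsets = torch.arange(states_per_phoneme)
    # (batch, max_len, k) -> (batch, max_len * k)
    expanded = (phns.long().unsqueeze(-1) * states_per_phoneme + offsets).reshape(batch, -1)
    # Zero out padding positions beyond each sequence's (rounded) length.
    abs_lens = torch.round(phn_lens * max_len).long() * states_per_phoneme
    mask = torch.arange(max_len * states_per_phoneme).unsqueeze(0) < abs_lens.unsqueeze(1)
    return (expanded * mask).to(phns.dtype)

phns = torch.tensor([[0., 3., 5., 0.],
                     [0., 2., 0., 0.]])
phn_lens = torch.tensor([1., 0.75])
print(expand_states(phns, phn_lens, states_per_phoneme=3))
```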
store_alignments(ids, alignments)

Records Viterbi alignments in self.align_dict.

Parameters
  • ids (list of str) – IDs of the files in the batch.

  • alignments (list of lists of int) – Viterbi alignments for the files in the batch, without padding.

Example

>>> aligner = HMMAligner()
>>> ids = ['id1', 'id2']
>>> alignments = [[0, 2, 4], [1, 2, 3, 4]]
>>> aligner.store_alignments(ids, alignments)
>>> aligner.align_dict.keys()
dict_keys(['id1', 'id2'])
>>> aligner.align_dict['id1']
tensor([0, 2, 4], dtype=torch.int16)
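Storing alignments per utterance ID and later turning them back into a zero-padded batch, as needed for Viterbi training, can be sketched with a plain dict and torch.nn.utils.rnn.pad_sequence. The helper names store and batch_from_store are hypothetical; the int16 dtype mirrors the example output above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

align_dict = {}

def store(ids, alignments):
    # Keep each unpadded alignment as a compact int16 tensor, keyed by utterance ID.
    for utt_id, ali in zip(ids, alignments):
        align_dict[utt_id] = torch.tensor(ali, dtype=torch.int16)

def batch_from_store(ids):
    # Zero-pad the stored alignments into a (batch, time) tensor.
    return pad_sequence([align_dict[i] for i in ids], batch_first=True)

store(["id1", "id2"], [[0, 2, 4], [1, 2, 3, 4]])
print(batch_from_store(["id1", "id2"]))  # rows [0, 2, 4, 0] and [1, 2, 3, 4]
```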
get_prev_alignments(ids, emission_pred, lens, phns, phn_lens)

Fetches previously recorded Viterbi alignments if they are available; otherwise, fetches flat start alignments. It currently assumes that if a Viterbi alignment is not available for the first utterance in the batch, it is not available for any of the other utterances either.

Parameters
  • ids (list of str) – IDs of the files in the batch.

  • emission_pred (torch.Tensor (batch, time, phoneme in vocabulary)) – Posterior probabilities from our acoustic model. Used to infer the duration of the longest utterance in the batch.

  • lens (torch.Tensor (batch)) – The relative duration of each utterance sound file.

  • phns (torch.Tensor (batch, phoneme in phn sequence)) – The phonemes that are known/thought to be in each utterance.

  • phn_lens (torch.Tensor (batch)) – The relative length of each phoneme sequence in the batch.

Returns

Zero-padded alignments.

Return type

torch.Tensor (batch, time)

Example

>>> ids = ['id1', 'id2']
>>> emission_pred = torch.tensor([[[ -1., -10., -10.],
...                                [-10.,  -1., -10.],
...                                [-10., -10.,  -1.]],
...
...                               [[ -1., -10., -10.],
...                                [-10.,  -1., -10.],
...                                [-10., -10., -10.]]])
>>> lens = torch.tensor([1., 0.66])
>>> phns = torch.tensor([[0, 1, 2],
...                      [0, 1, 0]])
>>> phn_lens = torch.tensor([1., 0.66])
>>> aligner = HMMAligner()
>>> alignment_batch = aligner.get_prev_alignments(
...        ids, emission_pred, lens, phns, phn_lens
... )
>>> alignment_batch
tensor([[0, 1, 2],
        [0, 1, 0]])
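Flat start alignments divide each utterance's frames evenly among its phonemes. The sketch below is one way to do that; the hypothetical flat_start helper, its floor(t * P / T) assignment rule and the rounding of relative lengths are assumptions chosen because they reproduce the example output above. HMMAligner's own rule for uneven splits may differ.

```python
import torch

def flat_start(phns, lens, phn_lens, max_frames):
    """Hypothetical sketch: spread each utterance's phonemes evenly over its frames."""
    batch, max_phns = phns.shape
    out = torch.zeros(batch, max_frames, dtype=torch.long)
    n_frames = torch.round(lens * max_frames).long()
    n_phns = torch.round(phn_lens * max_phns).long()
    for b in range(batch):
        T, P = int(n_frames[b]), int(n_phns[b])
        # Frame t is assigned to phoneme number floor(t * P / T).
        which = torch.arange(T) * P // T
        out[b, :T] = phns[b, which]
    return out

phns = torch.tensor([[0, 1, 2],
                     [0, 1, 0]])
print(flat_start(phns, torch.tensor([1., 0.66]), torch.tensor([1., 0.66]), max_frames=3))
```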
calc_accuracy(alignments, ends, phns, ind2labs=None)

Calculates mean accuracy between predicted alignments and ground truth alignments. Ground truth alignments are derived from ground truth phns and their ends in the audio sample.

Parameters
  • alignments (list of lists of ints/floats) – The predicted alignments for each utterance in the batch.

  • ends (list of lists of ints) – A list of lists of sample indices where each ground truth phoneme ends, according to the transcription. Note: the current implementation assumes that ‘ends’ mark the index where the next phoneme begins.

  • phns (list of lists of ints/floats) – The unpadded list of lists of ground truth phonemes in the batch.

  • ind2labs (tuple) – (Optional) Contains the original index-to-label dicts for the first and second sequence of phonemes.

Returns

mean_acc – The mean percentage of times that the upsampled predicted alignment matches the ground truth alignment.

Return type

float

Example

>>> aligner = HMMAligner()
>>> alignments = [[0., 0., 0., 1.]]
>>> phns = [[0., 1.]]
>>> ends = [[2, 4]]
>>> mean_acc = aligner.calc_accuracy(alignments, ends, phns)
>>> mean_acc.item()
75.0
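The 75.0 in the example above can be checked by hand: ends = [2, 4] expands to the ground truth [0, 0, 1, 1], which agrees with the prediction [0, 0, 0, 1] on 3 of 4 positions. A minimal per-utterance sketch of that comparison follows; frame_accuracy is a hypothetical name, and it assumes the prediction is already at the ground truth's resolution, skipping the upsampling step that calc_accuracy performs.

```python
import torch

def frame_accuracy(alignment, ends, phns):
    """Hypothetical sketch: percentage of positions where prediction == ground truth.

    Assumes ends[k] is the index where phoneme k+1 begins and that the
    prediction is already at the same resolution as the ground truth.
    """
    truth = []
    start = 0
    for phn, end in zip(phns, ends):
        truth.extend([phn] * (end - start))
        start = end
    pred = torch.tensor(alignment)
    truth = torch.tensor(truth, dtype=pred.dtype)
    return (pred == truth).float().mean() * 100

print(frame_accuracy([0., 0., 0., 1.], [2, 4], [0., 1.]).item())  # 75.0
```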
collapse_alignments(alignments)

Converts alignments to the 1-state-per-phoneme style.

Parameters

alignments (list of ints) – Predicted alignments for a single utterance.

Returns

sequence – The predicted alignments converted to a 1 state per phoneme style.

Return type

list of ints

Example

>>> aligner = HMMAligner(states_per_phoneme = 3)
>>> alignments = [0, 1, 2, 3, 4, 5, 3, 4, 5, 0, 1, 2]
>>> sequence = aligner.collapse_alignments(alignments)
>>> sequence
[0, 1, 1, 0]
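One rule consistent with the example above is that a new phoneme occurrence begins whenever the alignment enters a phoneme's first hidden state (a state index divisible by states_per_phoneme) from a different state; this is why the two consecutive 3, 4, 5 runs collapse to two separate 1s. A sketch of that rule follows; the collapse helper is hypothetical and reflects an assumption about the behavior, not the method's source.

```python
def collapse(alignments, states_per_phoneme):
    """Hypothetical sketch: one entry per phoneme occurrence."""
    sequence = []
    prev = None
    for state in alignments:
        # A phoneme occurrence begins when its first state is entered from elsewhere;
        # repeated frames in the same first state do not start a new occurrence.
        if state % states_per_phoneme == 0 and state != prev:
            sequence.append(state // states_per_phoneme)
        prev = state
    return sequence

print(collapse([0, 1, 2, 3, 4, 5, 3, 4, 5, 0, 1, 2], 3))  # [0, 1, 1, 0]
```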
speechbrain.alignment.aligner.map_inds_to_intersect(lists1, lists2, ind2labs)

Converts 2 lists containing indices for phonemes from different phoneme sets to a single phoneme set, so that comparing the equality of the indices of the resulting lists will yield the correct accuracy.

Parameters
  • lists1 (list of lists of ints) – Contains the indices of the first sequence of phonemes.

  • lists2 (list of lists of ints) – Contains the indices of the second sequence of phonemes.

  • ind2labs (tuple (dict, dict)) – Contains the original index-to-label dicts for the first and second sequence of phonemes.

Returns

  • lists1_new (list of lists of ints) – Contains the indices of the first sequence of phonemes, mapped to the new phoneme set.

  • lists2_new (list of lists of ints) – Contains the indices of the second sequence of phonemes, mapped to the new phoneme set.

Example

>>> lists1 = [[0, 1]]
>>> lists2 = [[0, 1]]
>>> ind2lab1 = {
...        0: "a",
...        1: "b",
...        }
>>> ind2lab2 = {
...        0: "a",
...        1: "c",
...        }
>>> ind2labs = (ind2lab1, ind2lab2)
>>> out1, out2 = map_inds_to_intersect(lists1, lists2, ind2labs)
>>> out1
[[0, 1]]
>>> out2
[[0, 2]]
speechbrain.alignment.aligner.batch_log_matvecmul(A, b)

For each ‘matrix’ and ‘vector’ pair in the batch, do matrix-vector multiplication in the log domain, i.e., logsumexp instead of add, add instead of multiply.

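Parameters
  • A (torch.Tensor (batch, dim1, dim2)) – Batch of matrices, in the log domain.

  • b (torch.Tensor (batch, dim2)) – Batch of vectors, in the log domain.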

Example

>>> A = torch.tensor([[[   0., 0.],
...                    [ -1e5, 0.]]])
>>> b = torch.tensor([[0., 0.,]])
>>> x = batch_log_matvecmul(A, b)
>>> x
tensor([[0.6931, 0.0000]])
>>>
>>> # non-log domain equivalent without batching functionality
>>> A_ = torch.tensor([[1., 1.],
...                    [0., 1.]])
>>> b_ = torch.tensor([1., 1.,])
>>> x_ = torch.matmul(A_, b_)
>>> x_
tensor([2., 1.])
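A short PyTorch sketch of the same log-domain matrix-vector product, assuming A has shape (batch, dim1, dim2) and b has shape (batch, dim2): broadcasting b across the rows of A plays the role of multiplication, and torch.logsumexp over the last dimension plays the role of addition. This is an illustration (log_matvec is a hypothetical name), not necessarily how batch_log_matvecmul is written.

```python
import torch

def log_matvec(A, b):
    """x[n, i] = logsumexp_j(A[n, i, j] + b[n, j]) for each batch element n."""
    return torch.logsumexp(A + b.unsqueeze(1), dim=2)

A = torch.tensor([[[0., 0.],
                   [-1e5, 0.]]])
b = torch.tensor([[0., 0.]])
x = log_matvec(A, b)
print(x)             # tensor([[0.6931, 0.0000]])
print(torch.exp(x))  # recovers A_ @ b_ = [2., 1.] up to rounding
```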
speechbrain.alignment.aligner.batch_log_maxvecmul(A, b)

Similar to batch_log_matvecmul, but takes a maximum instead of logsumexp. Returns both the max and the argmax.

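Parameters
  • A (torch.Tensor (batch, dim1, dim2)) – Batch of matrices, in the log domain.

  • b (torch.Tensor (batch, dim2)) – Batch of vectors, in the log domain.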

Example

>>> A = torch.tensor([[[   0., -1.],
...                    [ -1e5,  0.]]])
>>> b = torch.tensor([[0., 0.,]])
>>> x, argmax = batch_log_maxvecmul(A, b)
>>> x
tensor([[0., 0.]])
>>> argmax
tensor([[0, 1]])
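Replacing torch.logsumexp with torch.max in the same broadcasting pattern gives a sketch of the max-product version. torch.max with a dim argument already returns both values and indices, matching the (max, argmax) pair described above; the indices are what a Viterbi backtrace would follow. As before, log_maxvec is a hypothetical illustrative helper, not the library's code.

```python
import torch

def log_maxvec(A, b):
    """Max-product analogue: max_j and argmax_j of A[n, i, j] + b[n, j]."""
    return torch.max(A + b.unsqueeze(1), dim=2)

A = torch.tensor([[[0., -1.],
                   [-1e5, 0.]]])
b = torch.tensor([[0., 0.]])
x, argmax = log_maxvec(A, b)
print(x)       # tensor([[0., 0.]])
print(argmax)  # tensor([[0, 1]])
```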