"""Tools for homograph disambiguation
Authors
* Artem Ploujnikov 2021
"""
import torch
from torch import nn
class SubsequenceLoss(nn.Module):
"""
A loss function for a specific word in the output, used in
the homograph disambiguation task. The approach is as follows:
1. Arrange only the target words from the original batch into a
single tensor
2. Find the word index of each target word
3. Compute the beginnings and endings of words in the predicted
sequences. The assumption is that the model has been trained well
enough to identify word boundaries with a simple argmax without
having to perform a beam search.
Important! This loss can be used for fine-tuning only. The model
is expected to already be able to correctly predict word
boundaries.
Arguments
---------
seq_cost: callable
the loss to be used on the extracted subsequences
word_separator: int
the index of the "space" character (in phonemes)
word_separator_base: int
the index of the word separator used in unprocessed
targets (if different; used with tokenizations)
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceLoss
>>> from speechbrain.nnet.losses import nll_loss
>>> loss = SubsequenceLoss(
... seq_cost=nll_loss
... )
>>> phns = torch.Tensor(
... [[1, 2, 0, 1, 3, 0, 2, 1, 0],
... [2, 1, 3, 0, 1, 2, 0, 3, 2]]
... )
>>> phn_lens = torch.IntTensor([8, 9])
>>> subsequence_phn_start = torch.IntTensor([3, 4])
>>> subsequence_phn_end = torch.IntTensor([5, 7])
>>> p_seq = torch.Tensor([
... [[0., 1., 0., 0.],
... [0., 0., 1., 0.],
... [1., 0., 0., 0.],
... [0., 1., 0., 0.],
... [0., 0., 0., 1.],
... [1., 0., 0., 0.],
... [0., 0., 1., 0.],
... [0., 1., 0., 0.],
... [1., 0., 0., 0.]],
... [[0., 0., 1., 0.],
... [0., 1., 0., 0.],
... [0., 0., 0., 1.],
... [1., 0., 0., 0.],
... [0., 1., 0., 0.],
... [0., 0., 1., 0.],
... [1., 0., 0., 0.],
... [0., 0., 0., 1.],
... [0., 0., 1., 0.]]
... ])
>>> loss_value = loss(
... phns,
... phn_lens,
... p_seq,
... subsequence_phn_start,
... subsequence_phn_end
... )
>>> loss_value
tensor(-0.8000)
"""
def __init__(self, seq_cost, word_separator=0, word_separator_base=0):
super().__init__()
self.seq_cost = seq_cost
self._subsequence_extractor = SubsequenceExtractor(
word_separator, word_separator_base
)
@property
def word_separator(self):
"""
The word separator being used
"""
return self._subsequence_extractor.word_separator
@word_separator.setter
def word_separator(self, value):
"""
Sets the word separator
"""
self._subsequence_extractor.word_separator = value
@property
def word_separator_base(self):
"""
The base word separator being used
"""
return self._subsequence_extractor.word_separator_base
@word_separator_base.setter
def word_separator_base(self, value):
"""
Sets the base word separator
"""
self._subsequence_extractor.word_separator_base = value
def forward(
self,
phns,
phn_lens,
p_seq,
subsequence_phn_start,
subsequence_phn_end,
phns_base=None,
phn_lens_base=None,
):
"""
Evaluates the subsequence loss
Arguments
---------
phns: torch.Tensor
the phoneme tensor (batch x length)
phn_lens: torch.Tensor
the phoneme length tensor
p_seq: torch.Tensor
the output phoneme probability tensor
(batch x length x phns)
subsequence_phn_start: torch.Tensor
the beginning of the target subsequence
(i.e. the homograph)
subsequence_phn_end: torch.Tensor
the end of the target subsequence
(i.e. the homograph)
phns_base: torch.Tensor
the phoneme tensor (not preprocessed)
phn_lens_base: torch.Tensor
the phoneme lengths (not preprocessed)
Returns
-------
loss: torch.Tensor
the loss tensor
"""
(
p_seq_subsequence,
phns_subsequence,
subsequence_lengths,
) = self._subsequence_extractor(
phns,
phn_lens,
p_seq,
subsequence_phn_start,
subsequence_phn_end,
phns_base,
phn_lens_base,
)
return self.seq_cost(
p_seq_subsequence, phns_subsequence, subsequence_lengths
)
class SubsequenceExtractor:
"""
A utility class to help extract subsequences out of a batch
of sequences
Arguments
---------
word_separator: int
the index of the word separator (used in p_seq)
word_separator_base: int
the index of word separators used in unprocessed
targets (if different)
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> phns = torch.Tensor(
... [[1, 2, 0, 1, 3, 0, 2, 1, 0],
... [2, 1, 3, 0, 1, 2, 0, 3, 2]]
... )
>>> phn_lens = torch.IntTensor([8, 9])
>>> subsequence_phn_start = torch.IntTensor([3, 4])
>>> subsequence_phn_end = torch.IntTensor([5, 7])
>>> p_seq = torch.Tensor([
... [[0., 1., 0., 0.],
... [0., 0., 1., 0.],
... [1., 0., 0., 0.],
... [0., 1., 0., 0.],
... [0., 0., 0., 1.],
... [1., 0., 0., 0.],
... [0., 0., 1., 0.],
... [0., 1., 0., 0.],
... [1., 0., 0., 0.]],
... [[0., 0., 1., 0.],
... [0., 1., 0., 0.],
... [0., 0., 0., 1.],
... [1., 0., 0., 0.],
... [0., 1., 0., 0.],
... [0., 0., 1., 0.],
... [1., 0., 0., 0.],
... [0., 0., 0., 1.],
... [0., 0., 1., 0.]]
... ])
>>> extractor.extract_seq(
... phns,
... phn_lens,
... p_seq,
... subsequence_phn_start,
... subsequence_phn_end
... )
(tensor([[[0., 1., 0., 0.],
[0., 0., 0., 1.],
[0., 0., 0., 0.]],
<BLANKLINE>
[[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 0.]]]), tensor([[1., 3., 0.],
[1., 2., 0.]]), tensor([0.6667, 1.0000]))
"""
def __init__(self, word_separator=0, word_separator_base=None):
self.word_separator = word_separator
if word_separator_base is None:
word_separator_base = word_separator
self.word_separator_base = word_separator_base
def __call__(self, *args, **kwargs):
return self.extract_seq(*args, **kwargs)
def extract_seq(
self,
phns,
phn_lens,
p_seq,
subsequence_phn_start,
subsequence_phn_end,
phns_base=None,
phn_base_lens=None,
):
"""
Extracts the subsequence from the complete sequence
Arguments
---------
phns: torch.Tensor
the phoneme tensor (batch x length)
phn_lens: torch.Tensor
the phoneme length tensor
p_seq: torch.Tensor
the output phoneme probability tensor
(batch x length x phns)
subsequence_phn_start: torch.Tensor
the beginning of the target subsequence
(i.e. the homograph)
subsequence_phn_end: torch.Tensor
the end of the target subsequence
(i.e. the homograph)
phns_base: torch.Tensor
the phoneme tensor (not preprocessed)
phn_base_lens: torch.Tensor
the phoneme lengths (not preprocessed)
Returns
-------
p_seq_subsequence: torch.Tensor
the output subsequence (of probabilities)
phns_subsequence: torch.Tensor
the target subsequence
subsequence_lengths: torch.Tensor
subsequence lengths, expressed as a fraction
of the tensor's last dimension
"""
has_base = False
if phns_base is None and phn_base_lens is None:
phns_base = phns
phn_base_lens = phn_lens
elif phns_base is None or phn_base_lens is None:
raise ValueError(
"phn_base and phn_lens_base, if provided, should be provided together"
)
else:
has_base = True
p_seq_edge = p_seq.size(1)
phns_edge = (phns.size(1) * phn_lens).long().unsqueeze(-1)
# Compute subsequence lengths and the longest length
subsequence_lengths = subsequence_phn_end - subsequence_phn_start
longest_subsequence = subsequence_lengths.max()
# Pad the sequence axis to make sure the "distance" from the start of
# each subsequence to the end of the sequence is at least as long
# as the longest subsequence (e.g. subsequence = homograph)
phns = self._pad_subsequence(phns, longest_subsequence)
phns_base = self._pad_subsequence(phns_base, longest_subsequence)
# Pad the predictions generously (by their full length) so that
# any extracted window of longest_subsequence steps stays in bounds
p_seq_pad = p_seq.size(1)
p_seq = torch.nn.functional.pad(p_seq, (0, 0, 0, p_seq_pad))
# Copy only the subsequences from the targets and inputs
# into new tensors
subsequence_phn_start_unsq = subsequence_phn_start.unsqueeze(-1)
range_phns_base = torch.arange(
phns_base.size(1), device=phns_base.device
).expand_as(phns_base)
range_phns_subsequence = torch.arange(
longest_subsequence, device=phns.device
).expand(phns.size(0), longest_subsequence)
# Count the words in predictions
target_word_indexes = self._get_target_word_indexes(
phns_base,
range_phns_base,
subsequence_phn_start_unsq,
self.word_separator_base,
phn_lens=phn_base_lens,
)
if has_base:
# Needed if tokenization or any other transformation was used
phns_subsequence, subsequence_lengths = self._get_phns_subsequence(
phns, target_word_indexes, longest_subsequence, phns_edge
)
else:
# If phns and phns_base are the same, there is no need to re-detect word boundaries
match = (range_phns_base >= subsequence_phn_start_unsq) & (
range_phns_base
< subsequence_phn_start_unsq + longest_subsequence
)
phns_subsequence = phns[match].reshape(range_phns_subsequence.shape)
phns_subsequence[
range_phns_subsequence >= subsequence_lengths.unsqueeze(-1)
] = 0.0
p_seq_subsequence = self._get_p_seq_subsequence(
p_seq, target_word_indexes, longest_subsequence, p_seq_edge
)
return (
p_seq_subsequence,
phns_subsequence,
subsequence_lengths / longest_subsequence,
)
def _pad_subsequence(self, sequence, longest_subsequence):
"""Pads a subsequence to the length of the longest subsequence
Arguments
---------
sequence: torch.Tensor
the sequence to be padded
longest_subsequence: int
the length of the longest subsequence
Returns
-------
sequence: torch.Tensor
The padded sequence
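Example
-------
A minimal illustration with hypothetical values; the method
right-pads the last dimension with zeros:
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> extractor._pad_subsequence(torch.tensor([[1, 2, 3]]), 2)
tensor([[1, 2, 3, 0, 0]])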
"""
if longest_subsequence > 0:
sequence = torch.nn.functional.pad(
sequence, (0, longest_subsequence)
)
return sequence
def _get_phns_subsequence(
self, phns, target_word_indexes, longest_subsequence, edge
):
"""Extracts a subsequence
Arguments
---------
phns: torch.Tensor
a tensor of phoneme indexes
target_word_indexes: torch.Tensor
a tensor of word indexes to extract, zero-based
(e.g.) torch.IntTensor([2, 3]) means extracting
the third word from the first sample and the
fourth word from the second sample
longest_subsequence: int
the length of the longest subsequence
edge: int
the index of the "edge" of the sequence
Returns
-------
phn_subsequence: torch.Tensor
a tensor with only the target words
subsequence_lengths: torch.Tensor
the lengths of the extracted words
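Example
-------
A minimal sketch with hypothetical values: the padded batch below
contains the words [1, 2] and [3, 4] delimited by the default
separator 0, the unpadded length is 6 (the edge), and the second
word (index 1) is extracted:
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> extractor._get_phns_subsequence(
...     torch.tensor([[1, 2, 0, 3, 4, 0, 0, 0]]),
...     torch.tensor([1]),
...     2,
...     torch.tensor([[6]]),
... )
(tensor([[3, 4]]), tensor([2]))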
"""
word_start, word_end = self._get_word_boundaries(
phns, target_word_indexes, edge
)
word_start_unsq = word_start.unsqueeze(-1)
word_end_unsq = word_end.unsqueeze(-1)
phns_range = (
torch.arange(phns.size(1), device=phns.device)
.unsqueeze(0)
.expand_as(phns)
)
phn_match = (phns_range >= word_start_unsq) & (
phns_range < word_start_unsq + longest_subsequence
)
phns_subsequence = phns[phn_match].view(
phns.size(0), longest_subsequence
)
phns_subsequence_range = (
torch.arange(
phns_subsequence.size(1), device=phns_subsequence.device
)
.unsqueeze(0)
.expand_as(phns_subsequence)
)
phns_subsequence[
phns_subsequence_range >= (word_end_unsq - word_start_unsq)
] = 0.0
subsequence_lengths = torch.minimum(
word_end - word_start, torch.tensor(phns_subsequence.size(1))
)
return phns_subsequence, subsequence_lengths
def _get_p_seq_subsequence(
self, p_seq, target_word_indexes, longest_subsequence, edge
):
"""Extracts a subsequence out of a tensor of probabilities
Arguments
---------
p_seq: torch.Tensor
a tensor of phoneme probabilities
(batch x sequence index x phoneme index)
target_word_indexes: torch.Tensor
a tensor of word indexes to extract, zero-based
(e.g.) torch.IntTensor([2, 3]) means extracting
the third word from the first sample and the
fourth word from the second sample
longest_subsequence: int
the length of the longest subsequence
edge: int
the index of the "edge" of the sequence
Returns
-------
p_seq_subsequence: torch.Tensor
a probability tensor composed of the phoneme
probabilities for target words only
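Example
-------
A minimal sketch with hypothetical values: one-hot probabilities
encode the token sequence [1, 2, 0, 2, 1, 0, 0, 0], the unpadded
length is 6 (the edge), and the second word (index 1) is extracted:
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> p_seq = torch.eye(3)[torch.tensor([[1, 2, 0, 2, 1, 0, 0, 0]])]
>>> extractor._get_p_seq_subsequence(p_seq, torch.tensor([1]), 2, 6)
tensor([[[0., 0., 1.],
[0., 1., 0.]]])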
"""
# Determine where the predicted subsequences start and end
word_start, word_end = self._get_word_boundaries(
p_seq, target_word_indexes, edge
)
p_seq_range = (
torch.arange(p_seq.size(1), device=p_seq.device)
.unsqueeze(0)
.unsqueeze(-1)
.expand_as(p_seq)
)
word_start_unsq = word_start.unsqueeze(-1).unsqueeze(-1)
word_end_unsq = word_end.unsqueeze(-1).unsqueeze(-1)
phn_match = (p_seq_range >= word_start_unsq) & (
p_seq_range < word_start_unsq + longest_subsequence
)
p_seq_subsequence = p_seq[phn_match].view(
p_seq.size(0), longest_subsequence, p_seq.size(-1)
)
p_seq_subsequence_range = (
torch.arange(
p_seq_subsequence.size(1), device=p_seq_subsequence.device
)
.unsqueeze(0)
.unsqueeze(-1)
.expand_as(p_seq_subsequence)
)
p_seq_subsequence[
p_seq_subsequence_range >= (word_end_unsq - word_start_unsq)
] = 0.0
return p_seq_subsequence
def _get_target_word_indexes(
self, phns, range_phns, start, word_separator, phn_lens=None
):
"""Computes the target word indexes
Arguments
---------
phns: torch.Tensor
a phoneme batch tensor
range_phns: torch.Tensor
a range tensor over the phoneme sequence
start: torch.Tensor
the beginning of the subsequence
word_separator: int
the word separator being used
phn_lens: torch.Tensor
Lengths corresponding to input phns
Returns
-------
word_indexes: torch.Tensor
the word index tensor
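Example
-------
A minimal sketch with hypothetical values: one separator (0)
occurs before position 3, so the target lies in word 1
(zero-based):
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> phns = torch.tensor([[1, 2, 0, 3, 4, 0]])
>>> range_phns = torch.arange(6).expand(1, 6)
>>> extractor._get_target_word_indexes(
...     phns, range_phns, torch.tensor([[3]]), 0
... )
tensor([1])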
"""
end_of_sequence = (
(range_phns == ((phn_lens).unsqueeze(-1) * phns.size(1)).long())
if phn_lens is not None
else False
)
word_boundaries = (range_phns < start) & (
(phns == word_separator) | end_of_sequence
)
word_indexes = word_boundaries.sum(dim=-1)
return word_indexes
def _get_word_boundaries(
self, seq, word_indexes, edge, word_separator=None
):
"""Determines the word boundaries for the specified
word indexes within a sequence
Arguments
---------
seq: torch.Tensor
a sequence (phonemes or graphemes)
word_indexes: torch.Tensor
the word indexes
edge: int or torch.Tensor
the position of the last element in each sequence
word_separator: int
the word separator token
Returns
-------
start: torch.Tensor
word start indexes
end: torch.Tensor
word end indexes
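Example
-------
A minimal sketch with hypothetical values: in a padded sequence
of unpadded length 6, the second word ([3, 4]) spans positions
3 (inclusive) to 5 (exclusive):
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> extractor._get_word_boundaries(
...     torch.tensor([[1, 2, 0, 3, 4, 0, 0, 0]]),
...     torch.tensor([1]),
...     torch.tensor([[6]]),
... )
(tensor([3]), tensor([5]))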
"""
if word_separator is None:
word_separator = self.word_separator
# If given probabilities (batch x length x phns), reduce them to
# token indexes with an argmax
tokens = seq.argmax(-1) if seq.dim() == 3 else seq
# Compute an auxiliary range tensor to help determine
# word boundaries
words_range = torch.arange(
tokens.size(-1), device=tokens.device
).expand_as(tokens)
word_boundaries = (tokens == word_separator) | (words_range == edge)
# Find which word a given position in the tensor belongs in
words = word_boundaries.cumsum(dim=-1)
index_match = words == word_indexes.unsqueeze(-1)
start = self._get_positions(index_match, words_range, torch.min, edge)
end = self._get_positions(index_match, words_range, torch.max, 0)
return start, end
def _get_positions(
self, index_match, words_range, aggregation, no_match_value
):
"""A helper method to calculate start or end positions corresponding
to specific words
Arguments
---------
index_match: torch.Tensor
a mask where positions matching the word index are
indicated as a 1 and the remaining positions are 0
words_range: torch.Tensor
a range tensor over the tokens
aggregation: callable
the aggregation to use (torch.min or torch.max)
no_match_value: int
the value to output if no match is found (this could
happen when searching in model outputs rather than
in source data)
Returns
-------
Start or end positions of specific words.
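Example
-------
A minimal sketch with hypothetical values: the match covers
positions 2-4; the minimum (2) is shifted by one so that the
returned start points just past the preceding separator:
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> index_match = torch.tensor([[False, False, True, True, True, False]])
>>> words_range = torch.arange(6).expand(1, 6)
>>> extractor._get_positions(index_match, words_range, torch.min, 6)
tensor([3])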
"""
positions = torch.where(index_match, words_range, no_match_value)
positions = aggregation(positions, dim=-1).values
return torch.where(positions == 0, 0, positions + 1)
def extract_hyps(
self, ref_seq, hyps, subsequence_phn_start, use_base=False
):
"""Extracts a subsequence from hypotheses (e.g. the result of a beam
search) based on a reference sequence, which can be either a sequence of phonemes (the target during training)
Arguments
---------
ref_seq: torch.Tensor
a reference sequence (e.g. phoneme targets)
hyps: list
a batch of hypotheses: a list of lists of
integer indexes (usually phonemes)
subsequence_phn_start: torch.Tensor
the index of the beginning of the subsequence to extract
use_base: bool
whether to use the raw (token) space for word separators
Returns
-------
result: torch.Tensor
The extracted subsequence.
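Example
-------
A minimal sketch with hypothetical values: the target starts at
position 3 of the reference, i.e. in its second word, so the
second word is extracted from each hypothesis:
>>> import torch
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> extractor.extract_hyps(
...     torch.tensor([[1, 2, 0, 3, 4, 0]]),
...     [[7, 8, 0, 9, 0]],
...     torch.tensor([3]),
... )
[[9]]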
"""
range_phns = torch.arange(
ref_seq.size(1), device=ref_seq.device
).expand_as(ref_seq)
target_word_indexes = self._get_target_word_indexes(
ref_seq,
range_phns,
subsequence_phn_start.unsqueeze(-1),
self.word_separator_base if use_base else self.word_separator,
)
separator_indexes = [
[-1]
+ [
idx
for idx, phn in enumerate(item_hyps)
if phn == self.word_separator
]
+ [None]
for item_hyps in hyps
]
result = [
self._extract_hyp_word(
item_hyps, item_separator_indexes, word_index
)
for item_hyps, item_separator_indexes, word_index in zip(
hyps, separator_indexes, target_word_indexes
)
]
return result
def _extract_hyp_word(self, hyps, separator_indexes, word_index):
"""Extracts a single word out of a hypothesis sequence
Arguments
---------
hyps: list
a single hypothesis (a list or tensor of token indexes)
separator_indexes: list
the word separator positions, with a leading -1 and a
trailing None acting as sentinels
word_index: int
the index of the word to be retrieved
Returns
-------
result: list|str
the extracted word
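Example
-------
A minimal sketch with hypothetical values, retrieving word 1
(bounded by the separator positions 2 and 4):
>>> from speechbrain.lobes.models.g2p.homograph import SubsequenceExtractor
>>> extractor = SubsequenceExtractor()
>>> extractor._extract_hyp_word([7, 8, 0, 9, 0], [-1, 2, 4, None], 1)
[9]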
"""
if word_index < len(separator_indexes):
left = separator_indexes[word_index]
if left is None:
return ""
left += 1
right = separator_indexes[word_index + 1]
result = hyps[left:right]
else:
result = []
return result