"""Library for Byte-pair-encoding (BPE) tokenization.
Authors
* Abdelwahab Heba 2020
* Loren Lugosch 2020
"""
import os.path
import torch
import logging
import csv
import json
from dataclasses import dataclass
from typing import List
import sentencepiece as spm
from speechbrain.dataio.dataio import merge_char
from speechbrain.utils import edit_distance
from speechbrain.utils.distributed import run_on_main
logger = logging.getLogger(__name__)
class SentencePiece:
    """BPE class call the SentencePiece unsupervised text tokenizer from Google.

    Reference: https://github.com/google/sentencepiece

    SentencePiece lib is an unsupervised text tokenizer and detokenizer.
    It implements subword units like Byte-pair-encoding (BPE),
    Unigram language model and char/word tokenizer.

    On construction the tokenizer model ``<model_dir>/<vocab_size>_<model_type>.model``
    is trained if it does not exist yet, and is then loaded.

    Arguments
    ---------
    model_dir : str
        The directory where the model will be saved (or already stored).
    vocab_size : int
        Vocab size for the chosen tokenizer type (BPE, Unigram).
        It must always be an int because it is part of the model file name;
        it is only passed to the trainer for BPE & unigram tokenization
        (not for char).
    annotation_train : str
        Path of the annotation file which is used to learn the tokenizer. It
        can be in JSON or csv format. May be None when the model (or the
        text file) already exists. (default: None)
    annotation_read : str
        The data entry which contains the word sequence in the annotation file.
    model_type : str
        (bpe, char, unigram). (default: "unigram")
        If "bpe", train unsupervised tokenization of piece of words. see:
        https://www.aclweb.org/anthology/P16-1162/
        If "char", build a character vocabulary (vocab_size is not passed
        to the trainer).
        If "unigram" do piece of word tokenization using unigram language
        model, see: https://arxiv.org/abs/1804.10959
    char_format_input : bool
        Whether the read entry contains characters format input.
        (default: False)
        (e.g., a p p l e _ i s _ g o o d)
    character_coverage : float
        Amount of characters covered by the model, good defaults
        are: 0.9995 for languages with a rich character set like Japanese or
        Chinese and 1.0 for other languages with small character set.
        (default: 1.0)
    user_defined_symbols : string
        String contained a list of symbols separated by a comma.
        User-defined symbols are handled as one piece in any context.
        (default: None)
    max_sentencepiece_length : int
        Maximum number of characters for the tokens. (default: 10)
    bos_id : int
        Passed to the trainer as ``--bos_id``; SentencePiece treats -1 as
        "no BOS token". (default: -1)
    eos_id : int
        Passed to the trainer as ``--eos_id``; SentencePiece treats -1 as
        "no EOS token". (default: -1)
    pad_id : int
        Passed to the trainer as ``--pad_id``. (default: -1)
    unk_id : int
        Passed to the trainer as ``--unk_id``. (default: 0)
    split_by_whitespace : bool
        If False, allow the sentencepiece to extract piece crossing multiple
        words. This feature is important for : Chinese/Japanese/Korean.
        (default: True)
    num_sequences : int
        If not none, use at most this many sequences to train the tokenizer
        (for large datasets). (default: None)
    annotation_list_to_check : list
        List of the annotation file which is used for checking the accuracy of
        recovering words from the tokenizer.
    annotation_format : str
        The format of the annotation file. JSON or csv are the formats supported.
    text_file : str
        An alternate path to the text file (needed when multiple models are trained on
        the same data file)
    add_dummy_prefix : bool
        If True the tokenizer adds dummy whitespace at the beginning of text. (default: True)

    Example
    -------
    >>> import torch
    >>> dict_int2lab = {1: "HELLO", 2: "MORNING"}
    >>> model_dir = getfixture('tmpdir') / "tokenizer_data"
    >>> # Example with csv
    >>> annotation_train = "tests/samples/annotation/dev-clean.csv"
    >>> annotation_read = "wrd"
    >>> model_type = "bpe"
    >>> bpe = SentencePiece(str(model_dir), 100, annotation_train, annotation_read, model_type)
    >>> batch_seq = torch.Tensor([[1, 2, 2, 1],[1, 2, 1, 0]])
    >>> batch_lens = torch.Tensor([1.0, 0.75])
    >>> encoded_seq_ids, encoded_seq_pieces = bpe(
    ...     batch_seq, batch_lens, dict_int2lab, task="encode"
    ... )
    >>> # Example using JSON
    >>> annotation_train = str(model_dir + "/dev-clean.json")
    >>> annotation_read = "wrd"
    >>> bpe = SentencePiece(model_dir, 100, annotation_train, annotation_read, model_type, annotation_format = 'json')
    >>> encoded_seq_ids, encoded_seq_pieces = bpe(
    ...     batch_seq, batch_lens, dict_int2lab, task="encode"
    ... )
    """

    def __init__(
        self,
        model_dir,
        vocab_size,
        annotation_train=None,
        annotation_read=None,
        model_type="unigram",
        char_format_input=False,
        character_coverage=1.0,
        user_defined_symbols=None,
        max_sentencepiece_length=10,
        bos_id=-1,
        eos_id=-1,
        pad_id=-1,
        unk_id=0,
        split_by_whitespace=True,
        num_sequences=None,
        annotation_list_to_check=None,
        annotation_format="csv",
        text_file=None,
        add_dummy_prefix=True,
    ):
        if model_type not in ["unigram", "bpe", "char"]:
            raise ValueError("model_type must be one of : [unigram, bpe, char]")
        # exist_ok avoids a check-then-create race when several processes
        # build the tokenizer at the same time.
        os.makedirs(model_dir, exist_ok=True)
        if not isinstance(vocab_size, int):
            raise ValueError("vocab_size must be integer.")

        self.annotation_train = annotation_train
        self.annotation_read = annotation_read
        self.annotation_format = annotation_format

        if self.annotation_train is not None and text_file is None:
            text_file = self._default_text_file(
                model_dir, self.annotation_train
            )
        self.text_file = text_file

        self.prefix_model_file = os.path.join(
            model_dir, str(vocab_size) + "_" + model_type
        )
        # The trainer options are stored as strings because they are
        # concatenated into a command-line style query in _train_BPE.
        self.vocab_size = str(vocab_size)
        self.model_type = model_type
        self.char_format_input = char_format_input
        self.character_coverage = str(character_coverage)
        self.max_sentencepiece_length = str(max_sentencepiece_length)
        self.bos_id = str(bos_id)
        self.eos_id = str(eos_id)
        self.pad_id = str(pad_id)
        self.unk_id = str(unk_id)
        self.num_sequences = num_sequences
        self.split_by_whitespace = split_by_whitespace
        self.user_defined_symbols = user_defined_symbols
        self.add_dummy_prefix = str(add_dummy_prefix)

        model_file = self.prefix_model_file + ".model"
        if not os.path.isfile(model_file):
            logger.info("Train tokenizer with type:" + self.model_type)
            if self.text_file is None or not os.path.isfile(self.text_file):
                if self.annotation_train is None:
                    raise ValueError(
                        "No trained tokenizer found at "
                        + model_file
                        + " and neither an existing text_file nor "
                        "annotation_train was provided to train one."
                    )
                if annotation_format == "csv":
                    run_on_main(self._csv2text)
                elif annotation_format == "json":
                    run_on_main(self._json2text)
                else:
                    raise ValueError(
                        "Annotation format not supported. Supported formats are csv and json. Got "
                        + annotation_format
                    )
            run_on_main(self._train_BPE)
        else:
            logger.info("Tokenizer is already trained.")

        logger.info("==== Loading Tokenizer ===")
        logger.info("Tokenizer path: " + model_file)
        logger.info("Tokenizer vocab_size: " + str(self.vocab_size))
        logger.info("Tokenizer type: " + self.model_type)
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(model_file)

        if annotation_list_to_check is not None:
            run_on_main(
                self._check_coverage_from_bpe,
                kwargs={"list_annotation_files": annotation_list_to_check},
            )

    @staticmethod
    def _default_text_file(model_dir, annotation_train):
        """Return ``<model_dir>/<annotation file stem>.txt``.

        Only the final extension is swapped, so names without an extension
        or with the extension repeated inside the name are handled.
        """
        stem = os.path.splitext(os.path.basename(annotation_train))[0]
        return os.path.join(model_dir, stem + ".txt")

    def _label_index(self, headers, annotation_file):
        """Return the column index of ``annotation_read`` in a CSV header.

        Raises ValueError when the header is missing (empty file) or does
        not contain the requested column.
        """
        if headers is None or self.annotation_read not in headers:
            raise ValueError(
                self.annotation_read + " must exist in:" + annotation_file
            )
        return headers.index(self.annotation_read)

    def _merge_chars_if_needed(self, sent):
        """Rebuild words from a char-format sentence when char_format_input is set."""
        if self.char_format_input:
            (sent,) = merge_char([sent.split()])
            sent = " ".join(sent)
        return sent

    def _require_annotation_train(self):
        """Raise ValueError unless ``annotation_train`` is an existing file."""
        if not os.path.isfile(os.path.abspath(self.annotation_train)):
            raise ValueError(
                self.annotation_train
                + " is not a file. please provide annotation file for training."
            )
        logger.info(
            "Extract "
            + self.annotation_read
            + " sequences from:"
            + self.annotation_train
        )

    def _write_text_file(self, sentences):
        """Write at most ``num_sequences`` sentences, one per line, to ``text_file``."""
        # SentencePiece reads its training text as UTF-8, so write it as such
        # regardless of the platform default encoding.
        with open(self.text_file, "w+", encoding="utf-8") as text_file:
            for row_idx, sent in enumerate(sentences):
                if (
                    self.num_sequences is not None
                    and row_idx >= self.num_sequences
                ):
                    logger.info(
                        "Using %d sequences to train the tokenizer.",
                        self.num_sequences,
                    )
                    break
                text_file.write(self._merge_chars_if_needed(sent) + "\n")
        logger.info("Text file created at: " + self.text_file)

    def _csv2text(self):
        """Read CSV file and convert specific data entries into text file."""
        self._require_annotation_train()
        with open(self.annotation_train, "r", newline="") as annotation_file:
            reader = csv.reader(annotation_file)
            # Validate the header before the text file is created, so a bad
            # annotation file never leaves an empty text file behind.
            index_label = self._label_index(
                next(reader, None), self.annotation_train
            )
            self._write_text_file(row[index_label] for row in reader)

    def _json2text(self):
        """Read JSON file and convert specific data entries into text file."""
        self._require_annotation_train()
        with open(self.annotation_train, "r") as f:
            out_json = json.load(f)
        self._write_text_file(
            entry[self.annotation_read] for entry in out_json.values()
        )

    def _train_BPE(self):
        """Train tokenizer with unsupervised techniques (BPE, Unigram) using
        SentencePiece Library. If you use "char" mode, the SentencePiece
        creates a char dict so the vocab_size attribute is not needed.
        """
        # NOTE(review): options are joined with spaces, so paths containing
        # whitespace are presumably not supported by this query format.
        options = [
            "--input=" + self.text_file,
            "--model_prefix=" + self.prefix_model_file,
            "--model_type=" + self.model_type,
            "--bos_id=" + self.bos_id,
            "--eos_id=" + self.eos_id,
            "--pad_id=" + self.pad_id,
            "--unk_id=" + self.unk_id,
            "--max_sentencepiece_length=" + self.max_sentencepiece_length,
            "--character_coverage=" + self.character_coverage,
            "--add_dummy_prefix=" + self.add_dummy_prefix,
        ]
        if self.model_type not in ["char"]:
            # include vocab_size
            options.append("--vocab_size=" + str(self.vocab_size))
        if self.user_defined_symbols is not None:
            options.append(
                "--user_defined_symbols=" + self.user_defined_symbols
            )
        if not self.split_by_whitespace:
            options.append("--split_by_whitespace=false")
        # Train tokenizer
        spm.SentencePieceTrainer.train(" ".join(options))

    def _iter_annotation_sentences(self, annotation_file):
        """Yield the ``annotation_read`` entry of each item in ``annotation_file``.

        The file is parsed as CSV when ``annotation_format`` is "csv" and as
        JSON otherwise.
        """
        if self.annotation_format == "csv":
            with open(annotation_file, "r", newline="") as csv_file:
                reader = csv.reader(csv_file)
                index_label = self._label_index(
                    next(reader, None), annotation_file
                )
                for row in reader:
                    yield row[index_label]
        else:
            with open(annotation_file, "r") as json_file:
                entries = json.load(json_file)
            for entry in entries.values():
                yield entry[self.annotation_read]

    def _check_coverage_from_bpe(self, list_annotation_files=None):
        """Logging the accuracy of the BPE model to recover words from the training text.

        Arguments
        ---------
        list_annotation_files : list
            List of the annotation file which is used for checking the
            accuracy of recovering words from the tokenizer. Entries that
            are not existing files are skipped (and logged).
        """
        for annotation_file in list_annotation_files or []:
            if not os.path.isfile(os.path.abspath(annotation_file)):
                logger.info(
                    "No accuracy recover checking for" + annotation_file
                )
                continue

            logger.info(
                "==== Accuracy checking for recovering text from tokenizer ==="
            )
            # Distinct reference words that were not recovered after an
            # encode/decode round trip.
            wrong_recover = set()
            for sent in self._iter_annotation_sentences(annotation_file):
                sent = self._merge_chars_if_needed(sent)
                sent = sent.split("\n")[0]
                encoded_id = self.sp.encode_as_ids(sent)
                decode_text = self.sp.decode_ids(encoded_id)
                (details,) = edit_distance.wer_details_for_batch(
                    ["utt1"],
                    [sent.split(" ")],
                    [decode_text.split(" ")],
                    compute_alignments=True,
                )
                if details["WER"] > 0:
                    for align in details["alignment"]:
                        if align[0] != "=" and align[1] is not None:
                            wrong_recover.add(align[1])

            logger.info("recover words from: " + annotation_file)
            if len(wrong_recover) > 0:
                logger.warning("Wrong recover words: " + str(len(wrong_recover)))
                logger.warning(
                    "Tokenizer vocab size: " + str(self.sp.vocab_size())
                )
                logger.warning(
                    "accuracy recovering words: "
                    + str(1 - float(len(wrong_recover)) / self.sp.vocab_size())
                )
            else:
                logger.info("Wrong recover words: 0")
                logger.warning("accuracy recovering words: " + str(1.0))

    def __call__(
        self, batch, batch_lens=None, ind2lab=None, task="encode",
    ):
        """This __call__ function implements the tokenizer encoder and decoder
        (restoring the string of word) for BPE, Regularized BPE (with unigram),
        and char.

        Arguments
        ---------
        batch : tensor.IntTensor or list
            List if ( batch_lens = None and task = "decode_from_list")
            Contains the original labels. Shape: [batch_size, max_length]
        batch_lens : tensor.LongTensor
            Containing the relative length of each label sequences. Must be 1D
            tensor of shape: [batch_size]. (default: None)
        ind2lab : dict
            Dictionary which maps the index from label sequences
            (batch tensor) to string label.
        task : str
            ("encode", "decode", "decode_from_list")
            "encode": convert the batch tensor into sequence of tokens.
                the output is the pair (tokens_seq, tokens_lens)
            "decode": convert a tensor of tokens to a list of word sequences.
            "decode_from_list": convert a list of token sequences to a list
                of word sequences.

        Returns
        -------
        tuple or list
            For "encode": ``(bpe_tensor, bpe_lens)`` where ``bpe_tensor`` is
            a zero-padded tensor of token ids of shape
            [batch_size, longest_token_sequence] and ``bpe_lens`` holds each
            sequence length relative to the longest one.
            For "decode" and "decode_from_list": a list (one item per
            sequence) of lists of words.

        Raises
        ------
        ValueError
            If task is "encode" and ind2lab is None, or if task is unknown.
        """
        if task == "encode":
            if ind2lab is None:
                raise ValueError(
                    "Tokenizer encoder must have the ind2lab function"
                )
            # Convert list of words/chars to bpe ids
            bpe = []
            # Relative lengths -> absolute number of valid labels per row.
            batch_lens = (batch_lens * batch.shape[1]).round().int()
            for i, utt_seq in enumerate(batch):
                tokens = [
                    ind2lab[int(index)] for index in utt_seq[: batch_lens[i]]
                ]
                if self.char_format_input:
                    (words_list,) = merge_char([tokens])
                    sent = " ".join(words_list)
                else:
                    sent = " ".join(tokens)
                bpe.append(self.sp.encode_as_ids(sent))
            # The longest bpe sequence fixes the padded width and is the
            # reference for the relative length of each utterance.
            max_bpe_len = max((len(bpe_utt) for bpe_utt in bpe), default=0)
            # Create bpe tensor
            bpe_tensor = torch.zeros(
                (batch.shape[0], max_bpe_len), device=batch.device
            )
            bpe_lens = torch.zeros((batch.shape[0]), device=batch.device)
            for i, bpe_utt in enumerate(bpe):
                bpe_tensor[i, : len(bpe_utt)] = torch.Tensor(bpe_utt)
                # When every sequence is empty the lengths stay at zero
                # instead of dividing by zero.
                if max_bpe_len > 0:
                    bpe_lens[i] = len(bpe_utt) / max_bpe_len
            return bpe_tensor, bpe_lens
        elif task == "decode_from_list":
            # From list of hyps (not padded outputs)
            # do decoding
            return [self.sp.decode_ids(utt_seq).split(" ") for utt_seq in batch]
        elif task == "decode":
            # From a batch tensor and a length tensor
            # find the absolute batch lengths and do decoding
            batch_lens = (batch_lens * batch.shape[1]).round().int()
            return [
                self.sp.decode_ids(
                    utt_seq[: batch_lens[i]].int().tolist()
                ).split(" ")
                for i, utt_seq in enumerate(batch)
            ]
        raise ValueError(
            "task must be one of: encode, decode, decode_from_list. Got "
            + str(task)
        )
def get_spm_tokens(model_path):
    """Fetch list of tokens, can be indexed by token id

    The resulting list can be used to map id to token.

    Arguments
    ---------
    model_path : str
        Path to SentencePiece model

    Returns
    -------
    list
        Tokens in order by id (can be indexed by id)
    """
    model = spm.SentencePieceProcessor()
    model.load(model_path)
    # `model` is the processor itself, so its methods are called directly
    # (it has no `.sp` attribute).
    return [model.id_to_piece(i) for i in range(model.vocab_size())]
@dataclass
class SentencePieceDecoderStreamingContext:
    """Mutable streaming context for a single SentencePiece streaming session.
    """

    emitted_symbol_count: int = 0
    """The number of symbols that have been emitted for this transcription."""
def spm_decode_preserve_leading_space(
    tokenizer: spm.SentencePieceProcessor,
    hyps: List[int],
    context: SentencePieceDecoderStreamingContext,
) -> str:
    """Assuming the tokenizer is sentencepiece, decodes the input hypothesis
    but avoids incorrectly stripping leading spaces when streaming.
    Operates on a single hypothesis, not a batch of hypotheses.

    Normally, the tokenizer always decodes full sentences at a time, with the
    consequence that the first space in decoding will get removed.
    However, when streaming, we might be decoding mid-utterance where spaces
    must not be removed mid-sentence. This function handles this case.

    e.g. if within the same streaming context, you decode `["▁how", "▁are"]`
    then `["▁you"]`, the decoder would normally return `"how areyou"` instead of
    `"how are you"` like this function does.

    Arguments
    ---------
    tokenizer : sentencepiece.SentencePieceProcessor
        The SentencePiece processor to use for decoding.
    hyps : list of output token hypotheses
        List of tokens to decode of any length `>=0`.
    context : SentencePieceDecoderStreamingContext
        Mutable streaming context for the sentencepiece decoder, which should be
        reused across calls for the same decoding stream. Its
        `emitted_symbol_count` is increased by the number of decoded pieces.

    Returns
    -------
    str
        Decoded text. Leading spaces are preserved, except at the start of a
        transcription.
    """
    # Decode as a one-item batch and keep the proto so that the individual
    # pieces (not only the final text) can be inspected.
    proto = tokenizer.decode([hyps], out_type="immutable_proto")[0]
    text = proto.text

    if len(proto.pieces) >= 1:
        should_preserve_space = context.emitted_symbol_count > 0
        # By default, SentencePiece tags spaces with `▁` i.e. \u2581
        # (unicode for "Lower One Eighth Block").
        if should_preserve_space and proto.pieces[0].piece.startswith("\u2581"):
            # We are mid-sentence and the decoder has nuked the first space,
            # as the decoder believes we are decoding a full sentence.
            # Insert it back.
            text = " " + text
        context.emitted_symbol_count += len(proto.pieces)

    return text