Source code for speechbrain.integrations.audio_tokenizers.discrete_ssl

"""This lobe enables the integration of pretrained discrete SSL models (HuBERT, WavLM, wav2vec 2.0) for extracting semantic tokens from the outputs of SSL layers.

Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html

Author
 * Pooneh Mousavi 2024
 * Jarod Duret 2024
"""

import os
from glob import glob

import joblib
import torch
from huggingface_hub import snapshot_download
from torch import nn

from speechbrain.inference.vocoders import UnitHIFIGAN
from speechbrain.tokenizers.discrete_SSL_tokenizer import DiscreteSSLTokenizer
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class DiscreteSSL(nn.Module):
    """This lobe enables the integration of HuggingFace and SpeechBrain pretrained Discrete SSL models.

    Transformer from HuggingFace needs to be installed:
    https://huggingface.co/transformers/installation.html

    The model can be used as a fixed Discrete feature extractor or can be finetuned. It
    will download automatically the model from HuggingFace or use a local path.

    The following table summarizes the compatible SSL models, their respective HF encoders, k-means training details, supported layers, and pretrained vocoder:

    | SSL Model  | HF Encoder                             | K-Means Dataset | K-Means Size | SSL Layers           | Vocoder Model                               |
    |------------|----------------------------------------|-----------------|--------------|----------------------|---------------------------------------------|
    | WavLM      | microsoft/wavlm-large                  | LibriSpeech960  | 1000         | 1, 3, 7, 12, 18, 23  | speechbrain/hifigan-wavlm-k1000-LibriTTS    |
    | HuBERT     | facebook/hubert-large-ll60k            | LibriSpeech960  | 1000         | 1, 3, 7, 12, 18, 23  | speechbrain/hifigan-hubert-k1000-LibriTTS   |
    | Wav2Vec2   | facebook/wav2vec2-large                | LibriSpeech960  | 1000         | 1, 3, 7, 12, 18, 23  | speechbrain/hifigan-wav2vec2-k1000-LibriTTS |

    Arguments
    ---------
    save_path : str
        Path (dir) of the downloaded model.
    ssl_model : str
        SSL model to extract semantic tokens from its layers' output. Note that output_all_hiddens should be set to True to enable multi-layer discretization.
    kmeans_dataset : str
        Name of the dataset that Kmeans model on HF repo is trained with.
    vocoder_repo_id: str
        Huggingface repository that contains the pre-trained HiFi-GAN model.
        The k-means checkpoints are downloaded from this same repository.
    num_clusters : int or List[int] (default: 1000)
        Determine the number of clusters of the targeted kmeans models to be downloaded. It could be varying for each layer.
    layers_num : List[int] (Optional)
        Determine layers to be downloaded from HF repo. If it is not provided, all layers with num_clusters(int) is loaded from HF repo.
        If num_clusters is a list, the layers_num should be provided to determine the cluster number for each layer.
    device : str (default 'cpu')
        Device identifier ('cpu' or 'cuda') stored on the module as ``self.device``.
    sample_rate : int (default: 16000)
        Sample rate of the input audio.

    Example
    -------
    >>> import torch
    >>> from speechbrain.integrations.huggingface.wavlm import WavLM
    >>> inputs = torch.rand([3, 2000])
    >>> model_hub = "microsoft/wavlm-large"
    >>> save_path = "savedir"
    >>> ssl_layer_num = [7, 23]
    >>> deduplicate = [False, True]
    >>> bpe_tokenizers = [None, None]
    >>> vocoder_repo_id = "speechbrain/hifigan-wavlm-k1000-LibriTTS"
    >>> kmeans_dataset = "LibriSpeech"
    >>> num_clusters = 1000
    >>> ssl_model = WavLM(model_hub, save_path, output_all_hiddens=True)
    >>> model = DiscreteSSL(
    ...     save_path,
    ...     ssl_model,
    ...     vocoder_repo_id=vocoder_repo_id,
    ...     kmeans_dataset=kmeans_dataset,
    ...     num_clusters=num_clusters,
    ... )
    >>> tokens, _, _ = model.encode(
    ...     inputs,
    ...     SSL_layers=ssl_layer_num,
    ...     deduplicates=deduplicate,
    ...     bpe_tokenizers=bpe_tokenizers,
    ... )
    >>> print(tokens.shape)
    torch.Size([3, 6, 2])
    >>> sig = model.decode(tokens, ssl_layer_num)
    >>> print(sig.shape)
    torch.Size([3, 1, 1920])
    """

    def __init__(
        self,
        save_path,
        ssl_model,
        kmeans_dataset,
        vocoder_repo_id="speechbrain/hifigan-wavlm-k1000-LibriTTS",
        num_clusters=1000,
        layers_num=None,
        device="cpu",
        sample_rate=16000,
    ):
        super().__init__()
        self.device = device
        self.ssl_model = ssl_model
        # The k-means file names embed the lower-cased encoder class name.
        model_name = ssl_model.__class__.__name__.lower()

        # Validates the arguments and stores the (possibly expanded)
        # cluster sizes in ``self.num_clusters``.
        self.check_if_input_is_compatible(layers_num, num_clusters)

        self.kmeans_models, self.ssl_layer_ids, self.num_clusters = (
            self.load_kmeans(
                vocoder_repo_id,
                kmeans_dataset,
                model_name,
                self.num_clusters,
                save_path,
                layers_num,
            )
        )

        # One codebook (cluster centers) per loaded k-means model.
        self.vocabularies = [
            model.cluster_centers_ for model in self.kmeans_models
        ]

        self.tokenizer = DiscreteSSLTokenizer(self.num_clusters)
        self.codec_vocoder = UnitHIFIGAN.from_hparams(
            source=vocoder_repo_id,
            savedir=save_path,
        )
        # decode() feeds already-offset unit ids to the vocoder.
        self.codec_vocoder.tokenize = False
        self.sample_rate = sample_rate

    def check_if_input_is_compatible(self, layers_num, num_clusters):
        """Check that layers_num and num_clusters are consistent with each other.

        The validated value is stored in ``self.num_clusters``: a list with one
        entry per layer when ``layers_num`` is given, the plain int otherwise.

        Arguments
        ---------
        layers_num: List[int] (Optional)
            If num_clusters is a list, the layers_num should be provided to determine the cluster number for each layer.
        num_clusters: int or List[int]
            determine the number of clusters of the targeted kmeans models to be downloaded. It could be varying for each layer.
        """
        if layers_num:
            if isinstance(num_clusters, int):
                num_clusters = [num_clusters] * len(layers_num)
            assert len(num_clusters) == len(layers_num), (
                "length of num_clusters and layers_num should be the same!!!"
            )
        else:
            # An empty list is treated like ``None``: without explicit layers
            # there is nothing a per-layer list of sizes could be matched to.
            assert isinstance(num_clusters, int), (
                "num_clusters is expected to be int since the layers_num is not provided."
            )
        self.num_clusters = num_clusters

    def load_kmeans(
        self,
        repo_id,
        kmeans_dataset,
        encoder_name,
        num_clusters,
        cache_dir,
        layers_num=None,
    ):
        """Load pretrained kmeans models from HF.

        Arguments
        ---------
        repo_id : str
            The HuggingFace repo id that contains the model.
        kmeans_dataset : str
            Name of the dataset that Kmeans model are trained with in HF repo that need to be downloaded.
        encoder_name : str
            Name of the encoder for locating files.
        num_clusters : int or List[int]
            determine the number of clusters of the targeted kmeans models to be downloaded. It could be varying for each layer.
        cache_dir : str
            Path (dir) of the downloaded model.
        layers_num : List[int] (Optional)
            If num_clusters is a list, the layers_num should be provided to determine the cluster number for each layer.

        Returns
        -------
        kmeans_models : tuple
            Pretrained k-means models loaded from the HF repo, sorted by layer id.
        layer_ids : tuple of int
            Supported layer nums for kmeans (extracted from the name of kmeans model),
            in ascending order.
        num_clusters : tuple of int
            Number of clusters of each returned model, aligned with ``layer_ids``.
        """
        # Each request pairs a file pattern with the cluster size it asks for,
        # so the size stays attached to the files that actually matched.
        if layers_num:
            if isinstance(num_clusters, int):
                num_clusters = [num_clusters] * len(layers_num)
            requests = [
                (
                    f"kmeans/{kmeans_dataset}_{encoder_name}_k{num_clusters[i]}_L{layer}.pt",
                    num_clusters[i],
                )
                for i, layer in enumerate(layers_num)
            ]
        else:
            requests = [
                (
                    f"kmeans/{kmeans_dataset}_{encoder_name}_k{num_clusters}*.pt",
                    num_clusters,
                )
            ]

        kmeans_dir = snapshot_download(
            repo_id=repo_id,
            allow_patterns=[pattern for pattern, _ in requests],
            cache_dir=cache_dir,
        )

        seen_files = set()
        entries = []
        for pattern, cluster_count in requests:
            for file in glob(os.path.join(kmeans_dir, pattern)):
                if file in seen_files:
                    continue
                seen_files.add(file)
                # File names end in "..._L<layer>.pt"; strip the leading "L".
                layer_id = int(
                    os.path.basename(file).split("_")[-1].split(".")[0][1:]
                )
                # NOTE(security): joblib.load unpickles the file, which can
                # execute arbitrary code. Only use trusted repositories.
                entries.append((layer_id, joblib.load(file), cluster_count))

        assert len(entries) > 0, (
            f"There is no trained k-means model available for {repo_id}"
        )

        # Sort on the layer id only, so model objects are never compared.
        entries.sort(key=lambda entry: entry[0])
        layer_ids, kmeans_models, num_clusters = zip(*entries)

        return kmeans_models, layer_ids, num_clusters

    def forward(
        self,
        wav,
        wav_lens=None,
        SSL_layers=None,
        deduplicates=None,
        bpe_tokenizers=None,
    ):
        """Takes an input waveform and return its corresponding tokens and reconstructed signal.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        wav_lens : tensor
            The relative length of the wav given in SpeechBrain format.
        SSL_layers: List[int]:
            determine which layers of SSL should be used to extract information.
        deduplicates: List[boolean]:
            determine to apply deduplication(remove duplicate subsequent tokens) on the tokens extracted for the corresponding layer.
        bpe_tokenizers: List:
            determine to apply subwording on the tokens extracted for the corresponding layer if the sentencePiece tokenizer is trained for that layer.

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x Seq x num_SSL_layers) tensor of audio tokens
        waveforms: torch.Tensor
            Batch of waveforms produced by the vocoder from ``tokens``.
        """
        tokens = self.encode(
            wav, wav_lens, SSL_layers, deduplicates, bpe_tokenizers
        )[0]
        sig = self.decode(tokens, SSL_layers=SSL_layers)
        return tokens, sig

    def encode(
        self,
        wav,
        wav_lens=None,
        SSL_layers=None,
        deduplicates=None,
        bpe_tokenizers=None,
    ):
        """Takes an input waveform and return its corresponding encoding.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        wav_lens : tensor
            The relative length of the wav given in SpeechBrain format.
        SSL_layers: List[int]:
            determine which layers of SSL should be used to extract information.
            Defaults to every layer a k-means model was loaded for.
        deduplicates: List[boolean]:
            determine to apply deduplication(remove duplicate subsequent tokens) on the tokens extracted for the corresponding layer.
        bpe_tokenizers: List:
            determine to apply subwording on the tokens extracted for the corresponding layer if the sentencePiece tokenizer is trained for that layer.

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x Seq x num_SSL_layers) tensor of audio tokens.
            The last axis follows the order of ``SSL_layers``.
        emb : torch.Tensor
            A (Batch x Seq x num_SSL_layers x embedding_dim ) cluster_centers embeddings for each tokens
        processed_tokens : torch.Tensor
            A (Batch x Seq x num_SSL_layers) tensor of audio tokens after applying deduplication and subwording if necessary.

        Raises
        ------
        ValueError
            If a requested layer has no loaded k-means model.
        """
        if SSL_layers is None:
            SSL_layers = self.ssl_layer_ids
        if deduplicates is None:
            deduplicates = [False] * len(SSL_layers)
        if bpe_tokenizers is None:
            bpe_tokenizers = [None] * len(SSL_layers)

        assert len(deduplicates) == len(SSL_layers) == len(bpe_tokenizers), (
            "length of SSL_layers,deduplicates,bpe_tokenizers should be the same!!!"
        )

        for layer in SSL_layers:
            if layer not in self.ssl_layer_ids:
                raise ValueError(
                    f"Layer {layer} is not among trained layers for k-means. Supported layers are: {self.ssl_layer_ids}."
                )

        embeddings = []
        token_ids = []
        batch_size = wav.shape[0]

        with torch.no_grad():
            feats = self.ssl_model.extract_features(wav, wav_lens)
            # Walk the layers in the caller's order so that column i of the
            # output matches SSL_layers[i], deduplicates[i], bpe_tokenizers[i]
            # and the positional mapping used by decode().
            for layer in SSL_layers:
                position = self.ssl_layer_ids.index(layer)
                model = self.kmeans_models[position]
                vocabulary = self.vocabularies[position]

                # k-means runs on CPU over all frames of the batch at once.
                tokens = model.predict(
                    feats[layer].flatten(end_dim=-2).cpu()
                )
                embs = vocabulary[tokens]
                embeddings.append(
                    torch.tensor(
                        embs.reshape(batch_size, -1, embs.shape[-1]),
                        dtype=torch.float,
                        device=wav.device,
                    )
                )
                token_ids.append(
                    torch.tensor(
                        tokens.reshape(batch_size, -1),
                        dtype=torch.long,
                        device=wav.device,
                    )
                )

        org_tokens = torch.stack(token_ids, 2)
        org_embedding = torch.stack(embeddings, 2)

        processed_tokens = self.tokenizer.encode(
            org_tokens, SSL_layers, deduplicates, bpe_tokenizers
        )
        return org_tokens, org_embedding, processed_tokens

    def decode(self, tokens, SSL_layers=None):
        """Takes a batch of discrete tokens and return the corresponding waveform.

        Original source:
        https://github.com/speechbrain/benchmarks/blob/c87beb61d4747909a133d3e1b3a3df7c8eda1f08/
        benchmarks/DASB/Libri2Mix/separation/conformer/train_discrete_ssl.py#L44

        Arguments
        ---------
        tokens : torch.Tensor
            A (Batch, codes, layers) tensor of discrete units
        SSL_layers: List[int]:
            determine which layers of SSL should be used by the vocoder.
            Column i of ``tokens`` is taken to belong to ``SSL_layers[i]``.

        Returns
        -------
        waveforms: torch.Tensor
            Batch of waveforms produced by the vocoder.
        """
        assert all(
            cluster == self.num_clusters[0] for cluster in self.num_clusters
        ), "All values in num_clusters must be equal."

        num_clusters = self.num_clusters[0]
        # Build the offsets next to the tokens: ``self.device`` is only the
        # constructor argument and may differ from where the tokens live.
        offsets = torch.arange(
            0,
            len(self.ssl_layer_ids) * num_clusters,
            num_clusters,
            device=tokens.device,
        )

        layers = self.ssl_layer_ids
        if SSL_layers is not None:
            layers = SSL_layers

        offset_idxes = [self.ssl_layer_ids.index(x) for x in layers]
        offsets = offsets[offset_idxes]
        # Each layer gets its own id range; the +1 keeps 0 free, which is the
        # fill value used below for layers that were not supplied.
        tokens = tokens + offsets + 1

        if len(layers) < len(self.ssl_layer_ids):
            full_tokens = torch.zeros(
                *tokens.shape[:2],
                len(self.ssl_layer_ids),
                dtype=tokens.dtype,
                device=tokens.device,
            )
            for i, idx in enumerate(offset_idxes):
                full_tokens[..., idx] = tokens[..., i]
            tokens = full_tokens

        return self.codec_vocoder(tokens)