speechbrain.lobes.models.huggingface_transformers.discrete_ssl module
This lobe integrates pretrained discrete SSL models (HuBERT, WavLM, wav2vec) for extracting semantic tokens from the output of SSL layers.
The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html
- Author
Pooneh Mousavi 2024
Summary
Classes:
DiscreteSSL | This lobe enables the integration of HuggingFace and SpeechBrain pretrained Discrete SSL models.
Reference
- class speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL(save_path, ssl_model, kmeans_dataset, kmeans_repo_id='speechbrain/SSL_Quantization', num_clusters=1000, layers_num=None)
Bases: Module
This lobe enables the integration of HuggingFace and SpeechBrain pretrained Discrete SSL models.
The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html
The model can be used as a fixed discrete feature extractor or can be fine-tuned. It automatically downloads the model from HuggingFace or uses a local path.
- Parameters:
save_path (str) – Path (dir) of the downloaded model.
ssl_model (str) – SSL model whose layers' outputs are used to extract semantic tokens. Note that output_all_hiddens should be set to True to enable multi-layer discretization.
kmeans_dataset (str) – Name of the dataset that the k-means models in the HF repo were trained on.
kmeans_repo_id (str) – HuggingFace repository that contains the pretrained k-means models.
num_clusters (int or List[int] (default: 1000)) – Number of clusters of the targeted k-means models to be downloaded. It may vary per layer.
layers_num (List[int] (Optional)) – Layers to download from the HF repo. If not provided, all layers with num_clusters (int) are loaded from the HF repo. If num_clusters is a list, layers_num should be provided to determine the cluster number for each layer.
Example
>>> import torch
>>> from speechbrain.lobes.models.huggingface_transformers.hubert import HuBERT
>>> from speechbrain.lobes.models.huggingface_transformers.discrete_ssl import DiscreteSSL
>>> inputs = torch.rand([3, 2000])
>>> model_hub = "facebook/hubert-large-ll60k"
>>> save_path = "savedir"
>>> ssl_layer_num = [7, 23]
>>> deduplicate = [False, True]
>>> bpe_tokenizers = [None, None]
>>> kmeans_repo_id = "speechbrain/SSL_Quantization"
>>> kmeans_dataset = "LJSpeech"
>>> num_clusters = 1000
>>> ssl_model = HuBERT(model_hub, save_path, output_all_hiddens=True)
>>> model = DiscreteSSL(save_path, ssl_model, kmeans_repo_id=kmeans_repo_id, kmeans_dataset=kmeans_dataset, num_clusters=num_clusters)
>>> tokens, embs, pr_tokens = model(inputs, SSL_layers=ssl_layer_num, deduplicates=deduplicate, bpe_tokenizers=bpe_tokenizers)
>>> print(tokens.shape)
torch.Size([3, 6, 2])
>>> print(embs.shape)
torch.Size([3, 6, 2, 1024])
>>> print(pr_tokens.shape)
torch.Size([3, 6, 2])
- check_if_input_is_compatible(layers_num, num_clusters)
Check whether layers_num and num_clusters are consistent with each other.
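- Parameters:
layers_num (List[int] (Optional)) – If num_clusters is a list, layers_num should be provided to determine the cluster number for each layer.
num_clusters (int or List[int]) – Number of clusters of the targeted k-means models to be downloaded. It may vary per layer.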
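The consistency rule described for layers_num and num_clusters can be sketched as a small standalone function. This is a minimal illustration, not the library's implementation: the function name `check_layers_and_clusters`, the equal-length requirement, and the error messages are assumptions made for the example.

```python
def check_layers_and_clusters(layers_num, num_clusters):
    """Hypothetical sketch of the layers_num / num_clusters consistency rule."""
    # A single int applies the same cluster count to every layer,
    # so layers_num is optional in that case.
    if isinstance(num_clusters, int):
        return
    # A list of cluster counts only makes sense with an explicit layer list.
    if layers_num is None:
        raise ValueError(
            "layers_num must be provided when num_clusters is a list"
        )
    # Assumption: one cluster count is expected per requested layer.
    if len(layers_num) != len(num_clusters):
        raise ValueError(
            "layers_num and num_clusters must have the same length"
        )


# Same cluster count for all layers: accepted without layers_num.
check_layers_and_clusters(None, 1000)
# Per-layer cluster counts: accepted when the layers are listed.
check_layers_and_clusters([7, 23], [1000, 2000])
```

Passing a list for num_clusters without layers_num raises ValueError in this sketch, since there would be no way to tell which cluster count belongs to which layer.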
- load_kmeans(repo_id, kmeans_dataset, encoder_name, num_clusters, cache_dir, layers_num=None)
Load pretrained k-means models from the HuggingFace Hub.
- Parameters:
repo_id (str) – The HuggingFace repo id that contains the model.
kmeans_dataset (str) – Name of the dataset that the k-means models to be downloaded from the HF repo were trained on.
encoder_name (str) – Name of the encoder for locating files.
num_clusters (int or List[int]) – Number of clusters of the targeted k-means models to be downloaded. It may vary per layer.
cache_dir (str) – Path (dir) of the downloaded model.
layers_num (List[int] (Optional)) – If num_clusters is a list, layers_num should be provided to determine the cluster number for each layer.
- Returns:
kmeans_model (MiniBatchKMeans) – Pretrained k-means model loaded from HF.
layer_ids (List[int]) – Supported layer numbers for k-means (extracted from the names of the k-means models).
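To illustrate what a loaded k-means model contributes, the sketch below assigns each feature frame to its nearest cluster center and then looks up that center as the frame's embedding. It uses plain Python and toy 2-D data rather than MiniBatchKMeans and real SSL features; the helper `assign_tokens` and all values are hypothetical.

```python
import math


def assign_tokens(frames, centers):
    """Return, for each frame, the index of the nearest center (Euclidean)."""
    tokens = []
    for frame in frames:
        distances = [math.dist(frame, center) for center in centers]
        tokens.append(distances.index(min(distances)))
    return tokens


# Toy data (assumed for illustration): 3 cluster centers, 4 feature frames.
centers = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
frames = [[0.1, -0.1], [0.9, 1.2], [4.8, 5.1], [0.2, 0.0]]

tokens = assign_tokens(frames, centers)
# The embedding of a token is simply its cluster center.
embeddings = [centers[token] for token in tokens]

print(tokens)  # [0, 1, 2, 0]
```

With one such model per SSL layer, stacking the per-layer token sequences gives one token per frame and per layer.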
- forward(wav, wav_lens=None, SSL_layers=None, deduplicates=None, bpe_tokenizers=None)
Takes an input waveform and returns its corresponding discrete tokens and cluster-center embeddings.
- Parameters:
wav (torch.Tensor (signal)) – A batch of audio signals to transform to features.
wav_lens (tensor) – The relative length of the wav given in SpeechBrain format.
SSL_layers (List[int]) – Layers of the SSL model to extract information from.
deduplicates (List[boolean]) – Whether to apply deduplication (removing consecutive duplicate tokens) to the tokens extracted for the corresponding layer.
bpe_tokenizers (List[int]) – Whether to apply subwording to the tokens extracted for the corresponding layer, if a SentencePiece tokenizer is trained for that layer.
- Returns:
tokens (torch.Tensor) – A (Batch x Seq x num_SSL_layers) tensor of audio tokens.
emb (torch.Tensor) – A (Batch x Seq x num_SSL_layers x embedding_dim) tensor of cluster-center embeddings, one for each token.
processed_tokens (torch.Tensor) – A (Batch x Seq x num_SSL_layers) tensor of audio tokens after applying deduplication and subwording if necessary.
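The deduplication step mentioned for deduplicates (removing consecutive duplicate tokens) can be sketched on a single layer's token sequence as follows. This is a standalone illustration on a plain Python list, not the code used inside forward; the helper name `deduplicate` and the token values are assumptions.

```python
from itertools import groupby


def deduplicate(tokens):
    """Collapse runs of identical consecutive tokens into a single token."""
    return [token for token, _run in groupby(tokens)]


# Hypothetical token sequence for one utterance and one SSL layer.
layer_tokens = [12, 12, 12, 7, 7, 12, 3, 3]

print(deduplicate(layer_tokens))  # [12, 7, 12, 3]
```

Only adjacent repeats are merged, so a token that reappears later in the sequence (12 above) is kept.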