Source code for speechbrain.lobes.models.huggingface_transformers.weighted_ssl

"""This lobe enables the integration of huggingface pretrained wav2vec2 models.

Reference: https://arxiv.org/abs/2006.11477
Reference: https://arxiv.org/abs/1904.05862
Reference: https://arxiv.org/abs/2110.13900
Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Salah Zaiem 2023
 * Adel Moumen 2023, 2024
"""

import torch
import logging
import torch.nn.functional as F
from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    HFTransformersInterface,
)

logger = logging.getLogger(__name__)


[docs] class WeightedSSLModel(HFTransformersInterface): """This lobe enables the integration of use of weighted sum representations from different layers in a SSL encoder. The model can be used as a fixed feature extractor for SSL benchmarking. It will download automatically the model from HuggingFace or use a local path. More details in recipes/SSL_benchmark Arguments --------- hub : str HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60" save_path : str Path (dir) of the downloaded model. layernorm: bool, (default: False) Whether layer representations should be layernormed before sum freeze : bool (default: True) If True, the model is frozen. If False, the model will be trained alongside with the rest of the pipeline. Example ------- >>> inputs = torch.rand([10, 600]) >>> model_hub = "facebook/wav2vec2-base-960h" >>> save_path = "savedir" >>> model = WeightedSSLModel(model_hub, save_path) >>> outputs = model(inputs) """ def __init__(self, hub, save_path="", layernorm=False, freeze=False): super().__init__(source=hub, save_path=save_path, freeze=freeze) self.model.eval() self.num_layers = self.config.num_hidden_layers + 1 # Initializing the learnable weights zero_init = torch.cat([torch.zeros(self.num_layers)]) self.weights = torch.nn.Parameter(zero_init, requires_grad=True) self.layernorm = layernorm
[docs] def forward(self, wav, wav_lens=None): """This method outputs a weighted sum of the layers representations of the SSL encoder Arguments --------- wav : tensor The wavs wav_lens : tensor The wav lengths """ feats = self.model(wav) hidden_states = torch.stack(feats.hidden_states, dim=0).detach() # First dimension should be equal to the number of layers in the hparams assert ( self.num_layers == hidden_states.shape[0] ), "Num layers not equal to num hidden states" norm_weights = torch.nn.functional.softmax(self.weights, dim=-1) # Layernorming the layers representations if asked if self.layernorm: hidden_states = [ F.layer_norm(t, (t.shape[-1],)) for t in hidden_states ] # Summing the weighted layers weighted_feats = ( hidden_states * norm_weights[:, None, None, None] ).sum(axis=0) return weighted_feats
[docs] def override_config(self, config): """If the config needs to be overrided, here is the place Arguments --------- config : Wav2Vec2Config The original config needs to be overrided. Returns ------- Overridded config """ config.output_hidden_states = True return config