speechbrain.lobes.models.huggingface_transformers.weighted_ssl module

This lobe enables the integration of HuggingFace pretrained wav2vec2 models.

References:
  • https://arxiv.org/abs/2006.11477

  • https://arxiv.org/abs/1904.05862

  • https://arxiv.org/abs/2110.13900

The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html

Authors
  • Salah Zaiem 2023

  • Adel Moumen 2023, 2024

Summary

Classes:

WeightedSSLModel

This lobe enables the use of a weighted sum of the representations from different layers of an SSL encoder.

Reference

class speechbrain.lobes.models.huggingface_transformers.weighted_ssl.WeightedSSLModel(hub, save_path='', layernorm=False, freeze=False)

Bases: HFTransformersInterface

This lobe enables the use of a weighted sum of the representations from different layers of an SSL encoder.

The model can be used as a fixed feature extractor for SSL benchmarking. It automatically downloads the model from HuggingFace, or loads it from a local path.

More details can be found in recipes/SSL_benchmark.
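The idea of a weighted sum over layers can be sketched in plain PyTorch. This is a minimal illustration, not the SpeechBrain implementation: the class name `WeightedLayerSum` is hypothetical, and the zero-initialized, softmax-normalized weights are an assumption (a common way to keep the per-layer weights positive and summing to one, so training starts from a plain average of all layers).

```python
import torch


class WeightedLayerSum(torch.nn.Module):
    """Hypothetical stand-alone sketch of a learnable weighted sum over layers."""

    def __init__(self, num_layers: int):
        super().__init__()
        # One learnable scalar per layer, zero-initialized so that the
        # softmax starts out as a uniform average over layers.
        self.weights = torch.nn.Parameter(torch.zeros(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_layers, batch, time, dim)
        norm_weights = torch.softmax(self.weights, dim=-1)
        # Broadcast one weight per layer, then sum the layer axis away.
        return (hidden_states * norm_weights[:, None, None, None]).sum(dim=0)


# e.g. 13 hidden states, as a base-sized wav2vec2 returns
# (transformer input + 12 transformer layers); batch 2, 50 frames, 768 dims
layers = torch.randn(13, 2, 50, 768)
pool = WeightedLayerSum(num_layers=13)
out = pool(layers)
print(out.shape)  # torch.Size([2, 50, 768])
```

Because the weights start at zero, the initial output equals the mean over layers; training then shifts weight toward the layers most useful for the downstream task.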

Parameters:
  • hub (str) – HuggingFace hub name, e.g. "facebook/wav2vec2-large-lv60".

  • save_path (str) – Path (directory) of the downloaded model.

  • layernorm (bool, default: False) – Whether each layer's representation should be layer-normalized before the weighted sum.

  • freeze (bool, default: False) – If True, the model is frozen. If False, the model is trained alongside the rest of the pipeline.
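To illustrate why the `layernorm` option can matter, the sketch below assumes it applies a parameter-free layer normalization over the feature dimension of each layer (via `torch.nn.functional.layer_norm`) before the weighted sum. The helper `layernorm_layers` and the two synthetic "layers" are hypothetical; the point is only that normalization removes scale differences between layers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two hypothetical layer outputs with very different scales: (batch, time, dim)
shallow = torch.randn(2, 50, 768)
deep = 100.0 * shallow


def layernorm_layers(hidden_states):
    """Normalize each layer's features over the last (feature) dimension."""
    # No learnable affine parameters: only mean/variance normalization.
    return [F.layer_norm(h, (h.shape[-1],)) for h in hidden_states]


normed_shallow, normed_deep = layernorm_layers([shallow, deep])

# After normalization the scale difference is gone, so neither layer
# dominates the weighted sum purely because of its magnitude.
print(torch.allclose(normed_shallow, normed_deep, atol=1e-4))  # True
```

Without this step, a layer whose activations are much larger than the others can dominate the sum even with a small weight.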

Example

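>>> import torch
>>> from speechbrain.lobes.models.huggingface_transformers.weighted_ssl import WeightedSSLModel
>>> inputs = torch.rand([10, 600])  # batch of 10 waveforms, 600 samples each
>>> model_hub = "facebook/wav2vec2-base-960h"
>>> save_path = "savedir"
>>> model = WeightedSSLModel(model_hub, save_path)  # downloads the model if needed
>>> outputs = model(inputs)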

forward(wav, wav_lens=None)

This method returns a weighted sum of the layer representations of the SSL encoder.

Parameters:
  • wav (tensor) – A batch of input waveforms.

  • wav_lens (tensor, optional) – The lengths of the waveforms in the batch.
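The data flow through `forward` can be sketched with random tensors standing in for the encoder outputs. This is an illustration under assumptions, not SpeechBrain's code: the helper `weighted_forward` is hypothetical, the hidden states are assumed to arrive as a tuple of `(batch, frames, dim)` tensors (the form HuggingFace models return when hidden states are requested), and `freeze=True` is assumed to be realized by detaching the stacked hidden states so that only the layer weights receive gradients.

```python
import torch

torch.manual_seed(0)

num_layers, batch, frames, dim = 13, 2, 49, 768

# Stand-in for the per-layer outputs of an SSL encoder: a tuple of
# num_layers tensors, each (batch, frames, dim).
fake_hidden_states = tuple(
    torch.randn(batch, frames, dim, requires_grad=True) for _ in range(num_layers)
)
layer_weights = torch.nn.Parameter(torch.zeros(num_layers))


def weighted_forward(hidden_states, weights, freeze=True):
    """Hypothetical sketch of the forward pass described above."""
    stacked = torch.stack(hidden_states, dim=0)  # (num_layers, batch, frames, dim)
    if freeze:
        # Cut the graph so no gradient flows back into the encoder outputs.
        stacked = stacked.detach()
    norm_weights = torch.softmax(weights, dim=-1)
    return (stacked * norm_weights[:, None, None, None]).sum(dim=0)


out = weighted_forward(fake_hidden_states, layer_weights, freeze=True)
out.sum().backward()

print(out.shape)                           # torch.Size([2, 49, 768])
print(layer_weights.grad is not None)      # True: the layer weights still learn
print(fake_hidden_states[0].grad is None)  # True: the encoder outputs got no gradient
```

The output keeps the `(batch, frames, dim)` shape of a single layer, so it can replace a single-layer feature in a downstream model without other changes.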

override_config(config)

If the config needs to be overridden, this is the place to do it.

Parameters:

config (Wav2Vec2Config) – The original config to be overridden.

Returns:

The overridden config.
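This page does not show what the override changes. Since a weighted sum needs the output of every layer, the sketch below assumes the override enables `output_hidden_states` on the config; the function shown is a hypothetical stand-in, not SpeechBrain's method. It builds a tiny, randomly initialized `Wav2Vec2Model` from HuggingFace Transformers (the sizes are arbitrary choices so nothing is downloaded) to show that the encoder then returns `num_hidden_layers + 1` hidden states.

```python
import torch
from transformers import Wav2Vec2Config, Wav2Vec2Model


def override_config(config):
    """Hypothetical override: ask the encoder to return every layer's output."""
    config.output_hidden_states = True
    return config


# A deliberately tiny, randomly initialized encoder (arbitrary sizes).
config = Wav2Vec2Config(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
config = override_config(config)
model = Wav2Vec2Model(config).eval()  # eval() disables dropout and masking

with torch.no_grad():
    outputs = model(torch.randn(1, 16000))  # 1 s of 16 kHz audio

# One entry for the transformer input plus one per transformer layer.
print(len(outputs.hidden_states))       # 3
print(outputs.hidden_states[-1].shape)  # torch.Size([1, 49, 32])
```

This count (number of transformer layers plus one) is the number of per-layer weights a weighted sum over the encoder would need.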
