speechbrain.lobes.models.huggingface_transformers.weighted_ssl module
This lobe enables the integration of huggingface pretrained wav2vec2 models.
Reference: https://arxiv.org/abs/2006.11477
Reference: https://arxiv.org/abs/1904.05862
Reference: https://arxiv.org/abs/2110.13900
Transformers from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
- Authors
Salah Zaiem 2023
Adel Moumen 2023, 2024
Summary
Classes:
This lobe enables the use of weighted-sum representations from different layers of an SSL encoder.
Reference
- class speechbrain.lobes.models.huggingface_transformers.weighted_ssl.WeightedSSLModel(hub, save_path='', layernorm=False, freeze=False, **kwargs)[source]
Bases:
HFTransformersInterface
This lobe enables the use of weighted-sum representations from different layers of an SSL encoder.
The model can be used as a fixed feature extractor for SSL benchmarking. It will automatically download the model from HuggingFace or use a local path.
More details in recipes/SSL_benchmark
- Parameters:
hub (str) – HuggingFace hub name, e.g. "facebook/wav2vec2-large-lv60"
save_path (str) – Path (dir) of the downloaded model.
layernorm (bool, (default: False)) – Whether layer representations should be layer-normalized before the weighted sum.
freeze (bool (default: False)) – If True, the model is frozen. If False, the model will be trained alongside the rest of the pipeline.
**kwargs (dict) – Additional arguments to pass to HFTransformersInterface
Example
>>> inputs = torch.rand([10, 600])
>>> model_hub = "facebook/wav2vec2-base-960h"
>>> save_path = "savedir"
>>> model = WeightedSSLModel(model_hub, save_path)
>>> outputs = model(inputs)
- forward(wav, wav_lens=None)[source]
This method outputs a weighted sum of the layer representations of the SSL encoder.
- Parameters:
wav (torch.Tensor) – The audio waveforms.
wav_lens (torch.Tensor) – The relative wav lengths.
- Returns:
weighted_feats β The weighted sum of layer representations.
- Return type:
torch.Tensor
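The weighted-sum step performed in forward can be sketched as follows. This is a minimal standalone illustration, not the actual SpeechBrain internals: the variable names and tensor sizes are hypothetical, and it assumes the HuggingFace convention of returning hidden states from every encoder layer (plus the embedding output) when output_hidden_states=True.

```python
import torch

# Hypothetical sizes: wav2vec2-base exposes 12 transformer layers plus the
# embedding output, i.e. 13 hidden states of 768 features each.
num_layers, batch, time, feat = 13, 2, 50, 768

# Stand-in for the stacked per-layer hidden states of the SSL encoder.
hidden_states = torch.rand(num_layers, batch, time, feat)

# One learnable scalar weight per layer; a softmax keeps the weights
# positive and summing to one.
layer_weights = torch.nn.Parameter(torch.zeros(num_layers))
norm_weights = torch.softmax(layer_weights, dim=-1)

# Weighted sum across the layer dimension yields one feature tensor.
weighted_feats = (norm_weights.view(-1, 1, 1, 1) * hidden_states).sum(dim=0)
print(weighted_feats.shape)  # torch.Size([2, 50, 768])
```

Because the per-layer weights are a Parameter, they are learned jointly with the downstream task even when the SSL encoder itself is frozen, which is what makes this useful as a fixed feature extractor for benchmarking.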