speechbrain.lobes.models.huggingface_wav2vec module

This lobe enables the integration of HuggingFace pretrained wav2vec2.0 models.

Reference: https://arxiv.org/abs/2006.11477
Reference: https://arxiv.org/abs/1904.05862
The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html

Authors
  • Titouan Parcollet 2021

Summary

Classes:

HuggingFaceWav2Vec2

This lobe enables the integration of HuggingFace pretrained wav2vec2.0 models.

Reference

class speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2(source, save_path, output_norm=True, freeze=True, pretrain=True)[source]

Bases: torch.nn.modules.module.Module

This lobe enables the integration of HuggingFace pretrained wav2vec2.0 models.

Source paper: https://arxiv.org/abs/2006.11477
The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html

The model can be used as a fixed feature extractor or can be finetuned. The model is downloaded automatically from HuggingFace.

Parameters
  • source (str) – HuggingFace hub name, e.g. “facebook/wav2vec2-large-lv60”.

  • save_path (str) – Path (directory) where the downloaded model is saved.

  • output_norm (bool (default: True)) – If True, a layer_norm (affine) will be applied to the output obtained from the wav2vec model.

  • freeze (bool (default: True)) – If True, the model is frozen. If False, the model is trained alongside the rest of the pipeline.

  • pretrain (bool (default: True)) – If True, the model is loaded with the pretrained weights from the specified source. If False, a randomly initialized model is instantiated.
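The effect of freeze and output_norm can be illustrated with a minimal PyTorch sketch. This is not the SpeechBrain implementation: TinyEncoder and encode are hypothetical names, the encoder is a tiny stand-in for the pretrained model, and both the choice of normalizing over the last axis and the absence of learnable affine parameters are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for the pretrained wav2vec 2.0 model."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = torch.nn.Linear(1, dim)

    def forward(self, wav):
        # (batch, time) -> (batch, time, dim)
        return self.proj(wav.unsqueeze(-1))


def encode(encoder, wav, freeze=True, output_norm=True):
    """Sketch of how the freeze and output_norm flags could be applied."""
    if freeze:
        # No gradients are tracked, so the encoder weights stay fixed.
        with torch.no_grad():
            out = encoder(wav)
    else:
        # Gradients flow into the encoder, so it is trained with the pipeline.
        out = encoder(wav)
    if output_norm:
        # Assumption: normalize each frame over its feature dimension.
        out = F.layer_norm(out, out.shape[-1:])
    return out
```

With freeze=True the returned tensor does not require gradients, so an optimizer applied to the rest of the pipeline leaves the encoder untouched; with freeze=False the output stays attached to the encoder parameters.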

Example

>>> inputs = torch.rand([10, 600])
>>> model_hub = "facebook/wav2vec2-base-960h"
>>> save_path = "savedir"
>>> model = HuggingFaceWav2Vec2(model_hub, save_path)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([10, 1, 768])
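The time dimension of 1 in the output above follows from the convolutional feature encoder, which downsamples the raw waveform heavily. The sketch below reproduces that arithmetic in plain Python. The kernel widths and strides are taken from the wav2vec 2.0 paper, and num_output_frames is a hypothetical helper, not part of SpeechBrain or Transformers. The last dimension (768) is the hidden size of the base model and is not computed here.

```python
# Kernel widths and strides of the seven temporal convolutions in the
# wav2vec 2.0 feature encoder, as described in the paper.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples):
    """Return how many feature frames an unpadded convolution stack yields."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Standard output-length formula for a convolution without padding.
        length = (length - kernel) // stride + 1
    return length


print(num_output_frames(600))    # 1
print(num_output_frames(16000))  # 49
```

A 600-sample input therefore yields a single frame, while one second of 16 kHz audio yields 49 frames, roughly one frame every 20 ms.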
forward(wav)[source]

Takes an input waveform and returns its corresponding wav2vec encoding.

Parameters

wav (torch.Tensor (signal)) – A batch of audio signals to transform to features.

extract_features(wav)[source]

Takes an input waveform and returns its corresponding wav2vec encoding.

Parameters

wav (torch.Tensor (signal)) – A batch of audio signals to transform to features.

training: bool