# Source code for speechbrain.integrations.huggingface.w2v_bert

"""This lobe enables the integration of HuggingFace pretrained w2v-bert-2.0 models.

Reference: https://arxiv.org/abs/2312.05187
The HuggingFace Transformers library needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Maryem Bouziane 2025
 * Salima Mdhaffar 2025
 * Yannick Estève 2025
"""

from typing import Optional

import torch
import torch.nn.functional as F

from speechbrain.integrations.huggingface.huggingface import (
    HFTransformersInterface,
)
from speechbrain.utils.data_utils import undo_padding
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class W2VBert(HFTransformersInterface):
    """This lobe enables the integration of HuggingFace and SpeechBrain
    pretrained w2v-bert-2.0 models.

    Source paper w2v-BERT: https://arxiv.org/abs/2312.05187
    The HuggingFace Transformers library needs to be installed:
    https://huggingface.co/transformers/installation.html

    The model can be used as a fixed feature extractor or can be finetuned.
    It will download automatically the model from HuggingFace or use a local
    path.

    Arguments
    ---------
    source : str
        HuggingFace hub name or local path, e.g. "facebook/w2v-bert-2.0".
    save_path : str
        Path (dir) used to cache / save the model.
    output_norm : bool (default: False)
        If True, a layer_norm is applied to the output features.
    freeze : bool (default: True)
        If True, the model is frozen. If False, the model is trained
        alongside the rest of the pipeline.
    freeze_feature_extractor : bool (default: False)
        When ``freeze`` is False and this flag is True, only the front-end
        module of the HF model is frozen: its ``feature_extractor`` if it
        defines one, otherwise its ``feature_projection`` (the HF w2v-BERT
        2.0 model consumes filterbank features and has no convolutional
        feature encoder). If neither exists, a warning is logged and
        nothing is frozen.
    apply_spec_augment : bool (default: False)
        If True, the internal SpecAugment of the HF model is enabled.
    output_all_hiddens : bool (default: False)
        If True, the forward method outputs the hidden states from all
        transformer layers, stacked along a new leading dimension.
    sample_rate : int or None (default: None)
        Expected sampling rate of the input waveforms. If None, the sampling
        rate is read from the HF feature extractor when available, otherwise
        it defaults to 16000.
    **kwargs
        Extra keyword arguments passed to the `from_pretrained` function.

    Example
    -------
    >>> inputs = torch.rand([2, 16000])
    >>> model_hub = "facebook/w2v-bert-2.0"
    >>> save_path = "savedir"
    >>> model = W2VBert(model_hub, save_path)
    >>> outputs = model(inputs)
    """

    def __init__(
        self,
        source: str,
        save_path: str,
        output_norm: bool = False,
        freeze: bool = True,
        freeze_feature_extractor: bool = False,
        apply_spec_augment: bool = False,
        output_all_hiddens: bool = False,
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            source=source,
            save_path=save_path,
            freeze=freeze,
            **kwargs,
        )

        # Load the HF feature extractor that turns raw waveforms into the
        # model's input features.
        self.load_feature_extractor(source, cache_dir=save_path)

        # Determine the sampling rate handed to the HF feature extractor.
        if sample_rate is not None:
            self.sample_rate = sample_rate
        else:
            self.sample_rate = getattr(
                self.feature_extractor, "sampling_rate", 16000
            )
            logger.info(
                f"[W2VBert] feature_extractor sample_rate = {self.sample_rate}"
            )

        self.model.config.apply_spec_augment = apply_spec_augment
        self.output_norm = output_norm
        self.output_all_hiddens = output_all_hiddens
        self.freeze_feature_extractor = freeze_feature_extractor

        if not self.freeze and self.freeze_feature_extractor:
            self._freeze_frontend()

    def _freeze_frontend(self):
        """Freezes the front-end module of the HF model, if it has one.

        Looks for ``feature_extractor`` first (wav2vec2-style models), then
        ``feature_projection``. Accessing ``self.model.feature_extractor``
        unconditionally raised ``AttributeError`` on models without it.
        """
        frontend = getattr(self.model, "feature_extractor", None)
        if frontend is None:
            frontend = getattr(self.model, "feature_projection", None)

        if frontend is None:
            logger.warning(
                "speechbrain.integrations.huggingface.w2v_bert - "
                "no feature extractor module found on the HF model; "
                "freeze_feature_extractor has no effect."
            )
            return

        logger.warning(
            "speechbrain.integrations.huggingface.w2v_bert - "
            "w2v-bert feature extractor is frozen."
        )
        frontend.eval()
        for param in frontend.parameters():
            param.requires_grad = False

    def forward(
        self,
        wav: torch.Tensor,
        wav_lens: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Takes an input waveform and returns its corresponding w2v-BERT
        encoding.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        wav_lens : torch.Tensor or None
            The relative length of the wav given in SpeechBrain format.

        Returns
        -------
        torch.Tensor
            w2v-BERT encoded features.
        """
        # When the lobe is frozen, skip building the autograd graph.
        if self.freeze:
            with torch.no_grad():
                return self._forward_hf(wav, wav_lens)
        return self._forward_hf(wav, wav_lens)

    @staticmethod
    def _split_batch(
        wav: torch.Tensor,
        wav_lens: Optional[torch.Tensor] = None,
    ) -> list:
        """Splits a padded 2-D batch into a list of unpadded 1-D float32
        numpy arrays, one per utterance.

        Lengths are recovered from the relative ``wav_lens`` with the same
        rounding rule as ``undo_padding`` (``round(rel_len * max_len)``).
        Without ``wav_lens`` every row is kept at full length. Numpy arrays
        are used in both cases because the HF feature extractor reliably
        treats a list of numpy arrays as a batch, and this avoids
        converting the audio to Python float lists.

        Arguments
        ---------
        wav : torch.Tensor
            Padded batch of waveforms; must be 2-D (batch, time).
        wav_lens : torch.Tensor or None
            Relative lengths in SpeechBrain format.

        Returns
        -------
        list
            One float32 ``numpy.ndarray`` per utterance.
        """
        wav = wav.detach().to("cpu", torch.float32)
        batch_size, max_len = wav.shape
        if wav_lens is None:
            lengths = [max_len] * batch_size
        else:
            lengths = (
                torch.round(wav_lens.detach().cpu() * max_len).long().tolist()
            )
        return [wav[idx, :length].numpy() for idx, length in enumerate(lengths)]

    def _forward_hf(
        self,
        wav: torch.Tensor,
        wav_lens: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Takes an input waveform and returns its corresponding w2v-BERT
        encoding.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of padded audio signals to transform to features.
        wav_lens : torch.Tensor or None
            The relative length of the wav given in SpeechBrain format.

        Returns
        -------
        torch.Tensor
            w2v-BERT encoded features. If ``output_all_hiddens`` is True,
            the entries of the HF ``hidden_states`` tuple are stacked along
            a new leading dimension.
        """
        device = wav.device

        # Feature extraction runs on CPU on unpadded utterances; the HF
        # extractor re-pads them and returns the matching attention mask.
        inputs = self.feature_extractor(
            self._split_batch(wav, wav_lens),
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        out = self.model(
            **inputs,
            output_hidden_states=self.output_all_hiddens,
        )

        if self.output_all_hiddens:
            out_tensor = torch.stack(list(out.hidden_states), dim=0)
        else:
            out_tensor = out.last_hidden_state

        if self.output_norm:
            out_tensor = F.layer_norm(out_tensor, out_tensor.shape[-1:])

        return out_tensor