# Source code for speechbrain.integrations.audio_tokenizers.wavtokenizer_interface

"""This lobe enables the integration of pretrained WavTokenizer.

Note that you need to pip install `git+https://github.com/Tomiinek/WavTokenizer` to use this module.

Repository: https://github.com/jishengpeng/WavTokenizer/
Paper: https://arxiv.org/abs/2408.16532

Authors
 * Pooneh Mousavi 2024
"""

import os

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download


class WavTokenizer(nn.Module):
    """This lobe enables the integration of the pretrained WavTokenizer model,
    a discrete codec model with a single codebook for Audio Language Modeling.

    Source paper:
        https://arxiv.org/abs/2408.16532

    You need to pip install `git+https://github.com/Tomiinek/WavTokenizer`
    to use this module.

    The code is adapted from the official WavTokenizer repository:
    https://github.com/jishengpeng/WavTokenizer/

    Arguments
    ---------
    source : str
        A HuggingFace repository identifier (forwarded to
        `huggingface_hub.snapshot_download` as `repo_id`)
    save_path : str
        The location where the pretrained model will be saved (forwarded to
        `huggingface_hub.snapshot_download` as `cache_dir`)
    config : str
        The name of the HF config file.
    checkpoint : str
        The name of the HF checkpoint file.
    sample_rate : int (default: 24000)
        The audio sampling rate
    freeze : bool
        whether the model will be frozen (e.g. not trainable if used
        as part of training another model). When True, gradients are
        disabled for every parameter of the wrapped model.

    Example
    -------
    >>> model_hub = "novateur/WavTokenizer"
    >>> save_path = "savedir"
    >>> config = "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
    >>> checkpoint = "WavTokenizer_small_600_24k_4096.ckpt"
    >>> model = WavTokenizer(
    ...     model_hub, save_path, config=config, checkpoint=checkpoint
    ... )
    >>> audio = torch.randn(4, 48000)
    >>> length = torch.tensor([1.0, 0.5, 0.75, 1.0])
    >>> tokens, embs = model.encode(audio)
    >>> tokens.shape
    torch.Size([4, 1, 80])
    >>> embs.shape
    torch.Size([4, 80, 512])
    >>> rec = model.decode(tokens)
    >>> rec.shape
    torch.Size([4, 48000])
    """

    def __init__(
        self,
        source,
        save_path=None,
        config="wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
        checkpoint="WavTokenizer_small_600_24k_4096.ckpt",
        sample_rate=24000,
        freeze=True,
    ):
        # Lazy import to avoid circular dependency issues
        try:
            import wavtokenizer
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the WavTokenizer module using: "
                "`pip install git+https://github.com/Tomiinek/WavTokenizer`"
            ) from err

        super().__init__()
        # Stored only after nn.Module is initialised so attribute registration
        # goes through a fully set-up module.
        self.wavtokenizer = wavtokenizer

        path = snapshot_download(repo_id=source, cache_dir=save_path)
        checkpoint_path = os.path.join(path, checkpoint)
        config_path = os.path.join(path, config)
        self.model = self.wavtokenizer.WavTokenizer.from_pretrained0802(
            config_path, checkpoint_path
        )
        # NOTE(review): this keeps a reference to the codebook as it exists at
        # construction time; whether it follows later device/dtype moves of
        # `self.model` depends on how the codebook is stored -- confirm.
        self.embeddings = self._compute_embedding()
        self.sample_rate = sample_rate

        # Honour the documented `freeze` flag (it was previously ignored).
        self.freeze = freeze
        if self.freeze:
            self._freeze_model()

    def _freeze_model(self):
        """Disables gradient tracking for all parameters of the wrapped model."""
        self.model.requires_grad_(False)

    def forward(self, inputs):
        """Encodes the input audio as tokens and embeddings and decodes audio
        from tokens

        Arguments
        ---------
        inputs : torch.Tensor
            A (Batch x Samples) tensor of audio

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x NQ x Length) tensor of audio tokens
        emb : torch.Tensor
            Raw vector embeddings from the model's quantizers
        audio : torch.Tensor
            the reconstructed audio
        """
        tokens, embedding = self.encode(inputs)
        audio = self.decode(tokens)
        return tokens, embedding, audio

    @torch.no_grad()
    def _compute_embedding(self):
        """Returns the codebook of the first quantizer layer of the model."""
        embs = self.model.feature_extractor.encodec.quantizer.vq.layers[
            0
        ].codebook
        return embs

    def encode(self, inputs):
        """Encodes the input audio as tokens and embeddings

        Arguments
        ---------
        inputs : torch.Tensor
            A (Batch x Samples) or (Batch x Channel x Samples)
            tensor of audio

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x NQ x Length) tensor of audio tokens
        emb : torch.Tensor
            Raw vector embeddings from the model's quantizers
        """
        emb, tokens = self.model.encode(inputs, bandwidth_id=0)
        # The model's first two token dims are swapped to give the documented
        # (Batch x NQ x Length) layout; the embedding's dim 1 is moved last.
        return tokens.movedim(0, 1), emb.movedim(1, -1)

    def decode(
        self,
        tokens,
    ):
        """Decodes audio from tokens

        Arguments
        ---------
        tokens : torch.Tensor
            A (Batch x NQ x Length) tensor of audio tokens

        Returns
        -------
        audio : torch.Tensor
            the reconstructed audio
        """
        # Undo the dim swap applied in `encode` before calling the model.
        feats = self.model.codes_to_features(tokens.movedim(1, 0))
        sig = self.model.decode(
            feats, bandwidth_id=torch.tensor(0, device=tokens.device)
        )
        return sig