speechbrain.lobes.models.discrete.wavtokenizer module

This lobe enables the integration of the pretrained WavTokenizer model.

Note that you need to pip install git+https://github.com/Tomiinek/WavTokenizer to use this module.

Repository: https://github.com/jishengpeng/WavTokenizer/

Paper: https://arxiv.org/abs/2408.16532

Authors
  • Pooneh Mousavi 2024

Summary

Classes:

WavTokenizer

This lobe enables the integration of the pretrained WavTokenizer model, a discrete codec model with a single codebook for audio language modeling.

Reference

class speechbrain.lobes.models.discrete.wavtokenizer.WavTokenizer(source, save_path=None, config='wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml', checkpoint='WavTokenizer_small_600_24k_4096.ckpt', sample_rate=24000, freeze=True)[source]

Bases: Module

This lobe enables the integration of the pretrained WavTokenizer model, a discrete codec model with a single codebook for audio language modeling.

Source paper:

https://arxiv.org/abs/2408.16532

You need to pip install git+https://github.com/Tomiinek/WavTokenizer to use this module.

The code is adapted from the official WavTokenizer repository: https://github.com/jishengpeng/WavTokenizer/

Parameters:
  • source (str) – A HuggingFace repository identifier or a local path.

  • save_path (str) – The location where the pretrained model will be saved.

  • config (str) – The name of the HF config file.

  • checkpoint (str) – The name of the HF checkpoint file.

  • sample_rate (int (default: 24000)) – The audio sampling rate, in Hz.

  • freeze (bool) – Whether the model will be frozen (i.e. not trainable if used as part of training another model).

Example

>>> model_hub = "novateur/WavTokenizer"
>>> save_path = "savedir"
>>> config="wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
>>> checkpoint="WavTokenizer_small_600_24k_4096.ckpt"
>>> model = WavTokenizer(model_hub, save_path,config=config,checkpoint=checkpoint)
>>> audio = torch.randn(4, 48000)
>>> length = torch.tensor([1.0, .5, .75, 1.0])
>>> tokens, embs= model.encode(audio)
>>> tokens.shape
torch.Size([4, 1, 80])
>>> embs.shape
torch.Size([4, 80, 512])
>>> rec = model.decode(tokens)
>>> rec.shape
torch.Size([4, 48000])
forward(inputs)[source]

Encodes the input audio into tokens and embeddings, then decodes the audio back from the tokens.

Parameters:

inputs (torch.Tensor) – A (Batch x Samples) tensor of audio

Returns:

  • tokens (torch.Tensor) – A (Batch x NQ x Length) tensor of audio tokens, as returned by encode

  • emb (torch.Tensor) – Raw vector embeddings from the model’s quantizers

  • audio (torch.Tensor) – the reconstructed audio
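
For reference, a minimal round-trip sketch reusing the model and audio from the class example above; it assumes forward simply chains encode and decode, so the shapes match those shown there:

>>> tokens, emb, rec = model(audio)
>>> tokens.shape
torch.Size([4, 1, 80])
>>> rec.shape
torch.Size([4, 48000])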

encode(inputs)[source]

Encodes the input audio as tokens and embeddings

Parameters:

inputs (torch.Tensor) – A (Batch x Samples) or (Batch x Channel x Samples) tensor of audio

Returns:

  • tokens (torch.Tensor) – A (Batch x NQ x Length) tensor of audio tokens

  • emb (torch.Tensor) – Raw vector embeddings from the model’s quantizers
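
A short sketch of encode on its own, continuing from the class example above; the comments about the 40 Hz token rate and 512-dimensional embeddings are inferred from the config file name (frame40, dim512) rather than stated by the API:

>>> audio = torch.randn(4, 48000)  # 2 s of audio at 24 kHz
>>> tokens, emb = model.encode(audio)
>>> tokens.shape  # NQ=1: a single codebook, 40 tokens per second
torch.Size([4, 1, 80])
>>> emb.shape  # one 512-dimensional vector per token
torch.Size([4, 80, 512])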

decode(tokens)[source]

Decodes audio from tokens

Parameters:

tokens (torch.Tensor) – A (Batch x NQ x Length) tensor of audio tokens

Returns:

audio – the reconstructed audio

Return type:

torch.Tensor
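
Since decode only needs token indices, it can also resynthesize tokens produced by an external audio language model. A hypothetical sketch, assuming integer indices into the 4096-entry codebook (the code4096 in the config name) and the model from the class example above:

>>> lm_tokens = torch.randint(0, 4096, (4, 1, 80))  # hypothetical LM output
>>> rec = model.decode(lm_tokens)
>>> rec.shape
torch.Size([4, 48000])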