"""This lobe enables the integration of the HuggingFace pretrained Mimi model.
Mimi is a state-of-the-art neural audio codec developed by Kyutai.
It combines semantic and acoustic information into audio tokens running at 12.5Hz with a bitrate of 1.1kbps.
Note that you need to install `transformers>=4.45.1` to use this module.
Repository: https://huggingface.co/kyutai/mimi
Paper: https://kyutai.org/Moshi.pdf
Authors
* Pooneh Mousavi 2024
"""
import torch
from speechbrain.dataio.dataio import clean_padding_, length_to_mask
from speechbrain.integrations.huggingface.huggingface import (
HFTransformersInterface,
)
from speechbrain.utils.logger import get_logger
# Module-level logger named after this module; it is not referenced by any
# of the code visible in this file.
logger = get_logger(__name__)
[docs]
class Mimi(HFTransformersInterface):
    """This lobe enables the integration of the HuggingFace pretrained Mimi model.

    Mimi is a state-of-the-art neural audio codec developed by Kyutai.
    It combines semantic and acoustic information into audio tokens running
    at 12.5Hz with a bitrate of 1.1kbps.

    Source paper:
        https://kyutai.org/Moshi.pdf

    Transformers>=4.45.1 from HuggingFace needs to be installed:
        https://huggingface.co/transformers/installation.html

    The code is adapted from the official HF Kyutai repository:
        https://huggingface.co/kyutai/mimi

    Arguments
    ---------
    source : str
        A HuggingFace repository identifier or a path
    save_path : str
        The location where the pretrained model will be saved
    sample_rate : int (default: 24000)
        The audio sampling rate. It is stored on the instance only; inputs
        are not resampled by this class.
    freeze : bool (default: True)
        Whether the model will be frozen (i.e. not trainable if used as part
        of training another model)
    num_codebooks : int (default: 8)
        Number of codebooks (quantizer levels) to use. It could be
        [2,3,4,5,6,7,8]

    Example
    -------
    >>> model_hub = "kyutai/mimi"
    >>> save_path = "savedir"
    >>> model = Mimi(model_hub, save_path)
    >>> audio = torch.randn(4, 48000)
    >>> length = torch.tensor([1.0, 0.5, 0.75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> tokens.shape
    torch.Size([4, 8, 25])
    >>> emb.shape
    torch.Size([4, 8, 25, 256])
    >>> rec = model.decode(tokens, length)
    >>> rec.shape
    torch.Size([4, 1, 48000])
    """

    def __init__(
        self,
        source,
        save_path,
        sample_rate=24000,
        freeze=True,
        num_codebooks=8,
    ):
        super().__init__(source=source, save_path=save_path, freeze=freeze)
        self.num_codebooks = num_codebooks
        self.sample_rate = sample_rate
        # Codebook lookup table, built lazily on the first call to `encode`
        # (see `_compute_embedding`).
        self.embeddings = None

    @torch.no_grad()
    def _compute_embedding(self):
        """Stacks the codebook vectors of the first `num_codebooks` quantizers.

        The semantic quantizer layers are taken first, followed by the
        acoustic ones (presumably the same order in which the model emits
        codes; confirm against the HF Mimi implementation).

        Returns
        -------
        embs : torch.Tensor
            A [K, C, D] tensor: K codebooks, C entries per codebook and
            D-dimensional codebook vectors.
        """
        semantic_layers = (
            self.model.quantizer.semantic_residual_vector_quantizer.layers
        )
        acoustic_layers = (
            self.model.quantizer.acoustic_residual_vector_quantizer.layers
        )
        layers = (semantic_layers + acoustic_layers)[: self.num_codebooks]
        embs = [layer.codebook.embed for layer in layers]
        return torch.stack(embs)  # [K, C, D]

    def forward(self, inputs, length=None):
        """Encodes the input audio as tokens and embeddings and decodes audio from tokens

        Arguments
        ---------
        inputs : torch.Tensor
            A (Batch x Samples) or (Batch x Channel x Samples)
            tensor of audio
        length : torch.Tensor, optional
            A tensor of relative lengths. If None, no padding is assumed.

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x num_codebooks x Length) tensor of audio tokens
        emb : torch.Tensor
            A (Batch x num_codebooks x Length x Dim) tensor of raw vector
            embeddings from the model's quantizers
        audio : torch.Tensor
            The reconstructed audio
        """
        tokens, embedding = self.encode(inputs, length)
        audio = self.decode(tokens, length)
        return tokens, embedding, audio

    def encode(self, inputs, length=None):
        """Encodes the input audio as tokens and embeddings

        Arguments
        ---------
        inputs : torch.Tensor
            A (Batch x Samples) or (Batch x Channel x Samples)
            tensor of audio
        length : torch.Tensor, optional
            A tensor of relative lengths (fraction of the padded length).
            If None, no padding mask is passed to the model.

        Returns
        -------
        tokens : torch.Tensor
            A (Batch x num_codebooks x Length) tensor of audio tokens
        emb : torch.Tensor
            A (Batch x num_codebooks x Length x Dim) tensor of raw vector
            embeddings from the model's quantizers
        """
        if inputs.dim() == 2:
            # Add the channel axis: [B, T] -> [B, 1, T]
            inputs = inputs.unsqueeze(1)

        padding_mask = None
        if length is not None:
            max_len = inputs.size(-1)
            padding_mask = length_to_mask(
                length * max_len, max_len, device=inputs.device
            ).unsqueeze(1)  # [B, 1, T]

        tokens = self.model.encode(
            inputs, padding_mask, num_quantizers=self.num_codebooks
        )[0]

        if self.embeddings is None:
            self.embeddings = self._compute_embedding()
        # The cached table is a plain attribute, so `module.to(device)` does
        # not move it; keep it on the same device as the tokens.
        self.embeddings = self.embeddings.to(tokens.device)

        # Look up each token in its own codebook:
        #   embeddings[b, k, t] = self.embeddings[k, tokens[b, k, t]]
        book_idx = torch.arange(tokens.size(1), device=tokens.device)
        embeddings = self.embeddings[book_idx.view(1, -1, 1), tokens]
        return tokens, embeddings  # [B, K, T], [B, K, T, D]

    def decode(self, tokens, length=None):
        """Decodes audio from tokens

        Arguments
        ---------
        tokens : torch.Tensor
            A (Batch x num_codebooks x Length) tensor of audio tokens
        length : torch.Tensor, optional
            A 1-D tensor of relative lengths. If given, samples beyond each
            relative length are set to zero.

        Returns
        -------
        audio : torch.Tensor
            The reconstructed audio, (Batch x Channel x Samples)
        """
        audio = self.model.decode(tokens).audio_values
        if length is not None:
            # Time is the last axis of [B, C, T]; the default `len_dim=1`
            # would target the (size-1) channel axis and leave padding intact.
            clean_padding_(audio, length.to(audio.device), len_dim=-1)
        return audio