speechbrain.lobes.models.discrete.dac module
This lobe enables the integration of a pretrained discrete DAC (Descript Audio Codec) model.
Reference: http://arxiv.org/abs/2306.06546
Reference: https://descript.notion.site/Descript-Audio-Codec-11389fce0ce2419891d6591a68f814d5
Reference: https://github.com/descriptinc/descript-audio-codec
- Author
Shubham Gupta 2023
Summary
Classes:
- DAC – Discrete Autoencoder Codec (DAC) for audio data encoding and decoding.
- Decoder – A PyTorch module for the Decoder part of DAC.
- DecoderBlock – A PyTorch module representing a block within the Decoder architecture.
- Encoder – A PyTorch module for the Encoder part of DAC.
- EncoderBlock – An encoder block module for convolutional neural networks.
- ResidualUnit – A residual unit module for convolutional neural networks.
- ResidualVectorQuantize – Residual vector quantizer, introduced in SoundStream: An End-to-End Neural Audio Codec (https://arxiv.org/abs/2107.03312).
- Snake1d – A PyTorch module implementing the Snake activation function in 1D.
- VectorQuantize – An implementation for Vector Quantization.
Functions:
- WNConv1d – Apply weight normalization to a 1D convolutional layer.
- WNConvTranspose1d – Apply weight normalization to a 1D transposed convolutional layer.
- download – Download a specified model file based on model type, bitrate, and tag, saving it to a local path.
- init_weights – Initialize the weights of a 1D convolutional layer.
Reference
- speechbrain.lobes.models.discrete.dac.WNConv1d(*args, **kwargs)[source]
Apply weight normalization to a 1D convolutional layer.
- speechbrain.lobes.models.discrete.dac.WNConvTranspose1d(*args, **kwargs)[source]
Apply weight normalization to a 1D transposed convolutional layer.
- speechbrain.lobes.models.discrete.dac.init_weights(m)[source]
Initialize the weights of a 1D convolutional layer.
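Example
A minimal sketch of how these helpers compose, assuming WNConv1d forwards its arguments to torch.nn.Conv1d and that init_weights is meant to be used with Module.apply (both are assumptions, not guarantees from this page):
>>> import torch
>>> conv = WNConv1d(16, 32, kernel_size=3, padding=1)  # weight-normalized Conv1d
>>> _ = conv.apply(init_weights)  # initialize weights in place
>>> conv(torch.randn(1, 16, 100)).shape  # [Batch, Channels, Time]
torch.Size([1, 32, 100])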
- speechbrain.lobes.models.discrete.dac.download(model_type: str = '44khz', model_bitrate: str = '8kbps', tag: str = 'latest', local_path: Path = None)[source]
Downloads a specified model file based on model type, bitrate, and tag, saving it to a local path.
- Parameters:
model_type (str, optional) – The type of model to download. Can be '44khz', '24khz', or '16khz'. Default is '44khz'.
model_bitrate (str, optional) – The bitrate of the model. Can be '8kbps' or '16kbps'. Default is '8kbps'.
tag (str, optional) – A specific version tag for the model. Default is 'latest'.
local_path (Path, optional) – The local file path where the model will be saved. If not provided, a default path will be used.
- Returns:
The local path where the model is saved.
- Return type:
Path
- Raises:
ValueError – If the model type or bitrate is not supported, or if the model cannot be found or downloaded.
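Example
A minimal usage sketch (requires network access; per the DAC parameters below, the resulting path can be passed to DAC via load_path):
>>> model_path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")  # doctest: +SKIP
>>> dac = DAC(load_pretrained=True, load_path=str(model_path))  # doctest: +SKIP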
- class speechbrain.lobes.models.discrete.dac.VectorQuantize(input_dim: int, codebook_size: int, codebook_dim: int)[source]
Bases:
Module
An implementation for Vector Quantization
Implementation of VQ similar to Karpathy's repo: https://github.com/karpathy/deep-vector-quantization
Additionally uses the following tricks from Improved VQGAN (https://arxiv.org/pdf/2110.04627.pdf):
- Factorized codes: Perform nearest neighbor lookup in a low-dimensional space for improved codebook usage
- l2-normalized codes: Converts euclidean distance to cosine similarity, which improves training stability
- Parameters:
input_dim (int) – Dimensionality of the input tensor.
codebook_size (int) – Number of entries in the codebook.
codebook_dim (int) – Dimensionality of each codebook entry.
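Example
A minimal usage sketch (dimensions are illustrative; the unpacking follows the forward documentation below):
>>> import torch
>>> vq = VectorQuantize(input_dim=64, codebook_size=512, codebook_dim=8)
>>> z = torch.randn(1, 64, 100)  # [Batch, Dim, Time]
>>> z_q, commitment_loss, codebook_loss, indices, latents = vq(z)
>>> z_q.shape, indices.shape
(torch.Size([1, 64, 100]), torch.Size([1, 100]))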
- forward(z: Tensor)[source]
Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
- Parameters:
z (torch.Tensor[B x D x T])
- Returns:
torch.Tensor[B x D x T] – Quantized continuous representation of input
torch.Tensor[1] – Commitment loss to train the encoder to predict vectors closer to codebook entries
torch.Tensor[1] – Codebook loss to update the codebook
torch.Tensor[B x T] – Codebook indices (quantized discrete representation of input)
torch.Tensor[B x D x T] – Projected latents (continuous representation of input before quantization)
- embed_code(embed_id: Tensor)[source]
Embeds an ID using the codebook weights.
This method utilizes the codebook weights to embed the given ID.
- Parameters:
embed_id (torch.Tensor) – The tensor containing IDs that need to be embedded.
- Returns:
The embedded output tensor after applying the codebook weights.
- Return type:
torch.Tensor
- decode_code(embed_id: Tensor)[source]
Decodes the embedded ID by transposing the dimensions.
This method decodes the embedded ID by applying a transpose operation to the dimensions of the output tensor from the embed_code method.
- Parameters:
embed_id (torch.Tensor) – The tensor containing embedded IDs.
- Returns:
The decoded tensor
- Return type:
torch.Tensor
- decode_latents(latents: Tensor)[source]
Decodes latent representations into discrete codes by comparing with the codebook.
- Parameters:
latents (torch.Tensor) – The latent tensor representations to be decoded.
- Returns:
A tuple containing the decoded latent tensor (z_q) and the indices of the codes.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
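Example
A minimal sketch of the round trip through decode_latents and decode_code (the latent shape assumes codebook_dim=8 as in the constructor above; treat the shapes as assumptions):
>>> import torch
>>> vq = VectorQuantize(input_dim=64, codebook_size=512, codebook_dim=8)
>>> latents = torch.randn(1, 8, 100)  # [Batch, codebook_dim, Time]
>>> z_q, indices = vq.decode_latents(latents)
>>> vq.decode_code(indices).shape  # embed the indices back: [Batch, codebook_dim, Time]
torch.Size([1, 8, 100])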
- class speechbrain.lobes.models.discrete.dac.ResidualVectorQuantize(input_dim: int = 512, n_codebooks: int = 9, codebook_size: int = 1024, codebook_dim: int | list = 8, quantizer_dropout: float = 0.0)[source]
Bases:
Module
Residual vector quantizer, introduced in SoundStream: An End-to-End Neural Audio Codec (https://arxiv.org/abs/2107.03312)
- Parameters:
input_dim (int) – Dimensionality of the input tensor.
n_codebooks (int) – Number of codebooks to apply in sequence.
codebook_size (int) – Number of entries in each codebook.
codebook_dim (Union[int, list]) – Dimensionality of each codebook entry.
quantizer_dropout (float) – Dropout probability for the quantizers during training.
Example
Using a pretrained RVQ unit.
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> quantizer = dac.quantizer
>>> continuous_embeddings = torch.randn(1, 1024, 100) # Example shape: [Batch, Channels, Time]
>>> discrete_embeddings, codes, _, _, _ = quantizer(continuous_embeddings)
- forward(z, n_quantizers: int = None)[source]
Quantizes the input tensor using a fixed set of n codebooks and returns the corresponding codebook vectors.
- Parameters:
z (torch.Tensor) – Shape [B x D x T]
n_quantizers (int, optional) – Number of quantizers to use (n_quantizers < self.n_codebooks, e.g., for quantizer dropout). Note: if self.quantizer_dropout is True, this argument is ignored when in training mode, and a random number of quantizers is used.
- Returns:
z (torch.Tensor[B x D x T]) – Quantized continuous representation of input
codes (torch.Tensor[B x N x T]) – Codebook indices for each codebook (quantized discrete representation of input)
latents (torch.Tensor[B x N*D x T]) – Projected latents (continuous representation of input before quantization)
vq/commitment_loss (torch.Tensor[1]) – Commitment loss to train the encoder to predict vectors closer to codebook entries
vq/codebook_loss (torch.Tensor[1]) – Codebook loss to update the codebook
- from_codes(codes: Tensor)[source]
Given the quantized codes, reconstruct the continuous representation.
- Parameters:
codes (torch.Tensor[B x N x T]) – Quantized discrete representation of input
- Returns:
Quantized continuous representation of input
- Return type:
torch.Tensor[B x D x T]
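Example
A minimal sketch following the return type documented above (the codes are random placeholders; real codes would come from a forward pass):
>>> import torch
>>> rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
>>> codes = torch.randint(0, 1024, (1, 9, 100))  # [Batch, N codebooks, Time]
>>> z_q = rvq.from_codes(codes)  # quantized continuous representation, [Batch, 512, Time]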
- from_latents(latents: Tensor)[source]
Given the unquantized latents, reconstruct the continuous representation after quantization.
- Parameters:
latents (torch.Tensor[B x N x T]) – Continuous representation of input after projection
- Returns:
torch.Tensor[B x D x T] – Quantized representation of the full projected space
torch.Tensor[B x D x T] – Quantized representation of the latent space
- class speechbrain.lobes.models.discrete.dac.Snake1d(channels)[source]
Bases:
Module
A PyTorch module implementing the Snake activation function in 1D.
- Parameters:
channels (int) – The number of channels in the input tensor.
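Example
A minimal usage sketch. The Snake function is commonly defined as x + sin²(αx)/α with a learnable per-channel α; treat that exact parametrization as an assumption here. The activation is elementwise, so shapes are preserved:
>>> import torch
>>> snake = Snake1d(channels=16)
>>> snake(torch.randn(1, 16, 100)).shape  # [Batch, Channels, Time]
torch.Size([1, 16, 100])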
- class speechbrain.lobes.models.discrete.dac.ResidualUnit(dim: int = 16, dilation: int = 1)[source]
Bases:
Module
A residual unit module for convolutional neural networks.
- Parameters:
dim (int) – Number of input and output channels.
dilation (int) – Dilation factor of the inner convolution.
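Example
A minimal usage sketch (shape preservation is assumed from the residual design):
>>> import torch
>>> unit = ResidualUnit(dim=16, dilation=3)
>>> unit(torch.randn(1, 16, 100)).shape  # [Batch, Channels, Time]
torch.Size([1, 16, 100])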
- class speechbrain.lobes.models.discrete.dac.EncoderBlock(dim: int = 16, stride: int = 1)[source]
Bases:
Module
An encoder block module for convolutional neural networks.
This module constructs an encoder block consisting of a series of ResidualUnits and a final Snake1d activation followed by a weight-normalized 1D convolution. This block can be used as part of an encoder in architectures like autoencoders.
- Parameters:
dim (int) – Number of output channels of the block.
stride (int) – Stride of the final downsampling convolution.
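Example
A minimal usage sketch. The channel and time arithmetic assumes the reference DAC layout, in which the block maps dim // 2 input channels to dim output channels and downsamples the time axis by stride; treat those shapes as assumptions:
>>> import torch
>>> block = EncoderBlock(dim=32, stride=2)
>>> x = torch.randn(1, 16, 100)  # [Batch, dim // 2, Time]
>>> y = block(x)  # expected shape under the assumptions above: [1, 32, 50]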
- class speechbrain.lobes.models.discrete.dac.Encoder(d_model: int = 64, strides: list = [2, 4, 8, 8], d_latent: int = 64)[source]
Bases:
Module
A PyTorch module for the Encoder part of DAC.
- Parameters:
d_model (int) – Initial number of channels in the encoder.
strides (list) – Downsampling strides of the successive encoder blocks.
d_latent (int) – Dimensionality of the output latent space.
Example
Creating an Encoder instance
>>> encoder = Encoder()
>>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
>>> continuous_embedding = encoder(audio_input)
Using a pretrained encoder.
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> encoder = dac.encoder
>>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
>>> continuous_embeddings = encoder(audio_input)
- class speechbrain.lobes.models.discrete.dac.DecoderBlock(input_dim: int = 16, output_dim: int = 8, stride: int = 1)[source]
Bases:
Module
A PyTorch module representing a block within the Decoder architecture.
- Parameters:
input_dim (int) – Number of input channels.
output_dim (int) – Number of output channels.
stride (int) – Stride of the upsampling transposed convolution.
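Example
A minimal usage sketch (upsampling of the time axis by stride is assumed from the transposed-convolution design):
>>> import torch
>>> block = DecoderBlock(input_dim=16, output_dim=8, stride=2)
>>> x = torch.randn(1, 16, 100)  # [Batch, input_dim, Time]
>>> y = block(x)  # expected shape under the assumption above: [1, 8, 200]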
- class speechbrain.lobes.models.discrete.dac.Decoder(input_channel: int, channels: int, rates: List[int], d_out: int = 1)[source]
Bases:
Module
A PyTorch module for the Decoder part of DAC.
- Parameters:
input_channel (int) – Number of channels in the input latent representation.
channels (int) – Initial number of channels in the decoder.
rates (List[int]) – Upsampling rates of the successive decoder blocks.
d_out (int) – Number of output channels (1 for mono audio).
Example
Creating a Decoder instance
>>> decoder = Decoder(256, 1536, [8, 8, 4, 2])
>>> discrete_embeddings = torch.randn(2, 256, 200) # Example shape: [Batch, Channels, Time]
>>> recovered_audio = decoder(discrete_embeddings)
Using a pretrained decoder. Note that the actual input should be a proper discrete representation; randomly generated input is used here only to illustrate usage.
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> decoder = dac.decoder
>>> discrete_embeddings = torch.randn(1, 1024, 500) # Example shape: [Batch, Channels, Time]
>>> recovered_audio = decoder(discrete_embeddings)
- class speechbrain.lobes.models.discrete.dac.DAC(encoder_dim: int = 64, encoder_rates: List[int] = [2, 4, 8, 8], latent_dim: int = None, decoder_dim: int = 1536, decoder_rates: List[int] = [8, 8, 4, 2], n_codebooks: int = 9, codebook_size: int = 1024, codebook_dim: int | list = 8, quantizer_dropout: bool = False, sample_rate: int = 44100, model_type: str = '44khz', model_bitrate: str = '8kbps', tag: str = 'latest', load_path: str = None, strict: bool = False, load_pretrained: bool = False)[source]
Bases:
Module
Discrete Autoencoder Codec (DAC) for audio data encoding and decoding.
This class implements an autoencoder architecture with quantization for efficient audio processing. It includes an encoder, quantizer, and decoder for transforming audio data into a compressed latent representation and reconstructing it back into audio. This implementation supports both initializing a new model and loading a pretrained model.
- Parameters:
encoder_dim (int) – Dimensionality of the encoder.
encoder_rates (List[int]) – Downsampling rates for each encoder layer.
latent_dim (int, optional) – Dimensionality of the latent space, automatically calculated if None.
decoder_dim (int) – Dimensionality of the decoder.
decoder_rates (List[int]) – Upsampling rates for each decoder layer.
n_codebooks (int) – Number of codebooks for vector quantization.
codebook_size (int) – Size of each codebook.
codebook_dim (Union[int, list]) – Dimensionality of each codebook entry.
quantizer_dropout (bool) – Whether to use dropout in the quantizer.
sample_rate (int) – Sample rate of the audio data.
model_type (str) – Type of the model to load (if pretrained).
model_bitrate (str) – Bitrate of the model to load (if pretrained).
tag (str) – Specific tag of the model to load (if pretrained).
load_path (str, optional) – Path to load the pretrained model from, automatically downloaded if None.
strict (bool) – Whether to strictly enforce the state dictionary match.
load_pretrained (bool) – Whether to load a pretrained model.
Example
Creating a new DAC instance:
>>> dac = DAC()
>>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
>>> tokens, embeddings = dac(audio_data)
Loading a pretrained DAC instance:
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
>>> tokens, embeddings = dac(audio_data)
The tokens and the discrete embeddings obtained above, or from other sources, can be decoded:
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
>>> tokens, embeddings = dac(audio_data)
>>> decoded_audio = dac.decode(embeddings)
- encode(audio_data: Tensor, n_quantizers: int = None)[source]
Encode given audio data and return quantized latent codes.
- Parameters:
audio_data (torch.Tensor[B x 1 x T]) – Audio data to encode
n_quantizers (int, optional) – Number of quantizers to use, by default None. If None, all quantizers are used.
- Returns:
"z" (torch.Tensor[B x D x T]) – Quantized continuous representation of input
"codes" (torch.Tensor[B x N x T]) – Codebook indices for each codebook (quantized discrete representation of input)
"latents" (torch.Tensor[B x N*D x T]) – Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" (torch.Tensor[1]) – Commitment loss to train the encoder to predict vectors closer to codebook entries
"vq/codebook_loss" (torch.Tensor[1]) – Codebook loss to update the codebook
"length" (int) – Number of samples in the input audio
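Example
A minimal usage sketch (starred unpacking keeps the example agnostic about the trailing loss and length entries listed above):
>>> import torch
>>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
>>> audio_data = torch.randn(1, 1, 16000)  # Example shape: [Batch, Channels, Time]
>>> z, codes, latents, *rest = dac.encode(audio_data)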
- decode(z: Tensor)[source]
Decode given latent codes and return audio data.
- Parameters:
z (torch.Tensor) – Shape [B x D x T]. Quantized continuous representation of input
- Returns:
Decoded audio data of shape [B x 1 x length]
- Return type:
torch.Tensor
- forward(audio_data: Tensor, sample_rate: int = None, n_quantizers: int = None)[source]
Model forward pass.
- Parameters:
audio_data (torch.Tensor[B x 1 x T]) – Audio data to encode
sample_rate (int, optional) – Sample rate of the audio data in Hz. If None, the model's default sample rate is used.
n_quantizers (int, optional) – Number of quantizers to use, by default None. If None, all quantizers are used.
- Returns:
"tokens" (torch.Tensor[B x N x T]) – Codebook indices for each codebook (quantized discrete representation of input)
"embeddings" (torch.Tensor[B x D x T]) – Quantized continuous representation of input