Source code for speechbrain.nnet.autoencoders

"""Autoencoder implementation. Can be used for Latent Diffusion or in isolation

Authors
 * Artem Ploujnikov 2022
"""

import torch
from torch import nn
from collections import namedtuple
from speechbrain.dataio.dataio import clean_padding
from speechbrain.utils.data_utils import trim_as
from speechbrain.processing.features import GlobalNorm


class Autoencoder(nn.Module):
    """A standard interface for autoencoders

    Subclasses are expected to override `encode` and `decode`; the base
    implementations only raise `NotImplementedError`.

    Example
    -------
    >>> import torch
    >>> from torch import nn
    >>> from speechbrain.nnet.linear import Linear
    >>> class SimpleAutoencoder(Autoencoder):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.enc = Linear(n_neurons=16, input_size=128)
    ...         self.dec = Linear(n_neurons=128, input_size=16)
    ...     def encode(self, x, length=None):
    ...         return self.enc(x)
    ...     def decode(self, x, length=None):
    ...         return self.dec(x)
    >>> autoencoder = SimpleAutoencoder()
    >>> x = torch.randn(4, 10, 128)
    >>> x_enc = autoencoder.encode(x)
    >>> x_enc.shape
    torch.Size([4, 10, 16])
    >>> x_enc_fw = autoencoder(x)
    >>> x_enc_fw.shape
    torch.Size([4, 10, 16])
    >>> x_rec = autoencoder.decode(x_enc)
    >>> x_rec.shape
    torch.Size([4, 10, 128])
    """

    def encode(self, x, length=None):
        """Converts a sample from an original space (e.g. pixel or waveform)
        to a latent space

        Arguments
        ---------
        x: torch.Tensor
            the original data representation
        length: torch.Tensor
            a tensor of relative lengths

        Returns
        -------
        latent: torch.Tensor
            the latent representation

        Raises
        ------
        NotImplementedError
            always, in this base class; subclasses must override it
        """
        raise NotImplementedError

    def decode(self, latent):
        """Decodes the sample from a latent representation

        Arguments
        ---------
        latent: torch.Tensor
            the latent representation

        Returns
        -------
        result: torch.Tensor
            the decoded sample

        Raises
        ------
        NotImplementedError
            always, in this base class; subclasses must override it
        """
        raise NotImplementedError

    def forward(self, x, length=None):
        """Performs the forward pass, which is the encoding step only

        Arguments
        ---------
        x: torch.Tensor
            the input tensor
        length: torch.Tensor
            a tensor of relative lengths (optional). It is passed on to
            `encode` only when it is provided.

        Returns
        -------
        result: torch.Tensor
            the result of `encode`
        """
        # Only forward `length` when the caller supplied it, so subclasses
        # whose `encode` accepts just `x` keep working exactly as before.
        if length is None:
            return self.encode(x)
        return self.encode(x, length=length)
class VariationalAutoencoder(Autoencoder):
    """A Variational Autoencoder (VAE) implementation.

    Paper reference: https://arxiv.org/abs/1312.6114

    Arguments
    ---------
    encoder: torch.Module
        the encoder network
    decoder: torch.Module
        the decoder network
    mean: torch.Module
        the module that computes the mean
    log_var: torch.Module
        the module that computes the log variance
    len_dim: int
        the length dimension
    latent_padding: callable
        an optional callable invoked as ``latent_padding(latent, length=length)``
        that returns a ``(latent, latent_length)`` pair. If None, latents
        and lengths are used as they are.
    mask_latent: bool
        whether to apply the length mask to the latent representation
    mask_out: bool
        whether to apply the length mask to the output
    out_mask_value: float
        the mask value used for the output
    latent_mask_value: float
        the mask value used for the latent representation
    latent_stochastic: bool
        if true, the "latent" parameter of VariationalAutoencoderOutput
        will be the latent space sample
        if false, it will be the mean

    Example
    -------
    The example below shows a very simple implementation of
    VAE, not suitable for actual experiments:

    >>> import torch
    >>> from torch import nn
    >>> from speechbrain.nnet.linear import Linear
    >>> vae_enc = Linear(n_neurons=16, input_size=128)
    >>> vae_dec = Linear(n_neurons=128, input_size=16)
    >>> vae_mean = Linear(n_neurons=16, input_size=16)
    >>> vae_log_var = Linear(n_neurons=16, input_size=16)
    >>> vae = VariationalAutoencoder(
    ...     encoder=vae_enc,
    ...     decoder=vae_dec,
    ...     mean=vae_mean,
    ...     log_var=vae_log_var,
    ... )
    >>> x = torch.randn(4, 10, 128)

    `train_sample` encodes a single batch and then reconstructs
    it

    >>> vae_out = vae.train_sample(x)
    >>> vae_out.rec.shape
    torch.Size([4, 10, 128])
    >>> vae_out.latent.shape
    torch.Size([4, 10, 16])
    >>> vae_out.mean.shape
    torch.Size([4, 10, 16])
    >>> vae_out.log_var.shape
    torch.Size([4, 10, 16])
    >>> vae_out.latent_sample.shape
    torch.Size([4, 10, 16])

    .encode() will return the mean corresponding
    to the sample provided

    >>> x_enc = vae.encode(x)
    >>> x_enc.shape
    torch.Size([4, 10, 16])

    .reparameterize() performs the reparameterization
    trick

    >>> x_enc = vae.encoder(x)
    >>> mean = vae.mean(x_enc)
    >>> log_var = vae.log_var(x_enc)
    >>> x_repar = vae.reparameterize(mean, log_var)
    >>> x_repar.shape
    torch.Size([4, 10, 16])
    """

    def __init__(
        self,
        encoder,
        decoder,
        mean,
        log_var,
        len_dim=1,
        latent_padding=None,
        mask_latent=True,
        mask_out=True,
        out_mask_value=0.0,
        latent_mask_value=0.0,
        latent_stochastic=True,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mean = mean
        self.log_var = log_var
        self.len_dim = len_dim
        self.latent_padding = latent_padding
        self.mask_latent = mask_latent
        self.mask_out = mask_out
        self.out_mask_value = out_mask_value
        self.latent_mask_value = latent_mask_value
        self.latent_stochastic = latent_stochastic

    def encode(self, x, length=None):
        """Converts a sample from an original space (e.g. pixel or waveform)
        to a latent space

        Arguments
        ---------
        x: torch.Tensor
            the original data representation
        length: torch.Tensor
            a tensor of relative lengths (accepted for interface
            compatibility; not used by this implementation)

        Returns
        -------
        latent: torch.Tensor
            the latent representation (the output of the mean module)
        """
        encoder_out = self.encoder(x)
        return self.mean(encoder_out)

    def decode(self, latent):
        """Decodes the sample from a latent representation

        Arguments
        ---------
        latent: torch.Tensor
            the latent representation

        Returns
        -------
        result: torch.Tensor
            the decoded sample
        """
        return self.decoder(latent)

    def reparameterize(self, mean, log_var):
        """Applies the VAE reparameterization trick to get a single
        latent space sample for decoding

        Arguments
        ---------
        mean: torch.Tensor
            the latent representation mean
        log_var: torch.Tensor
            the logarithm of the latent representation variance

        Returns
        -------
        sample: torch.Tensor
            a latent space sample
        """
        epsilon = torch.randn_like(log_var)
        # exp(0.5 * log_var) is the standard deviation
        return mean + epsilon * torch.exp(0.5 * log_var)

    def _pad_latent(self, latent, length):
        """Applies the optional latent padding; a no-op when none is set

        Arguments
        ---------
        latent: torch.Tensor
            the latent representation to pad
        length: torch.Tensor
            the relative lengths (may be None)

        Returns
        -------
        latent: torch.Tensor
            the (possibly padded) latent representation
        latent_length: torch.Tensor
            the lengths returned by the padding callable, or `length`
            unchanged when no padding callable is configured
        """
        if self.latent_padding is None:
            return latent, length
        return self.latent_padding(latent, length=length)

    def train_sample(
        self, x, length=None, out_mask_value=None, latent_mask_value=None
    ):
        """Provides a data sample for training the autoencoder

        Arguments
        ---------
        x: torch.Tensor
            the source data (in the sample space)
        length: torch.Tensor
            the length (optional). If provided, latents and
            outputs will be masked
        out_mask_value: float
            the mask value used for the output; defaults to the value
            given to the constructor
        latent_mask_value: float
            the mask value used for the latent sample; defaults to the
            value given to the constructor

        Returns
        -------
        result: VariationalAutoencoderOutput
            a named tuple with the following values
            rec: torch.Tensor
                the reconstruction
            latent: torch.Tensor
                the latent space sample if `latent_stochastic` is true,
                otherwise the (padded) mean
            mean: torch.Tensor
                the mean of the latent representation
            log_var: torch.Tensor
                the logarithm of the variance
                of the latent representation
            latent_sample: torch.Tensor
                the latent space sample that was decoded
            latent_length: torch.Tensor
                the lengths associated with the latent representation
        """
        if out_mask_value is None:
            out_mask_value = self.out_mask_value
        if latent_mask_value is None:
            latent_mask_value = self.latent_mask_value

        encoder_out = self.encoder(x)
        mean = self.mean(encoder_out)
        log_var = self.log_var(encoder_out)
        latent_sample = self.reparameterize(mean, log_var)
        latent_sample, latent_length = self._pad_latent(latent_sample, length)
        if self.mask_latent and length is not None:
            latent_sample = clean_padding(
                latent_sample, latent_length, self.len_dim, latent_mask_value
            )

        x_rec = self.decode(latent_sample)
        x_rec = trim_as(x_rec, x)
        if self.mask_out and length is not None:
            x_rec = clean_padding(x_rec, length, self.len_dim, out_mask_value)

        if self.latent_stochastic:
            latent = latent_sample
        else:
            # Previously this called self.latent_padding unconditionally,
            # which raised a TypeError with the default latent_padding=None.
            latent, latent_length = self._pad_latent(mean, length)

        return VariationalAutoencoderOutput(
            x_rec, latent, mean, log_var, latent_sample, latent_length
        )
# Result container returned by VariationalAutoencoder.train_sample
VariationalAutoencoderOutput = namedtuple(
    "VariationalAutoencoderOutput",
    (
        "rec",  # the reconstruction
        "latent",  # the latent exposed to the caller
        "mean",  # the mean of the latent representation
        "log_var",  # the log variance of the latent representation
        "latent_sample",  # the latent space sample
        "latent_length",  # the lengths associated with the latent
    ),
)

# Result container returned by NormalizingAutoencoder.train_sample
AutoencoderOutput = namedtuple(
    "AutoencoderOutput",
    (
        "rec",  # the reconstruction
        "latent",  # the latent representation
        "latent_length",  # the lengths associated with the latent
    ),
)
class NormalizingAutoencoder(Autoencoder):
    """A classical (non-variational) autoencoder that
    does not use reparameterization but instead uses
    an ordinary normalization technique to constrain
    the latent space

    Arguments
    ---------
    encoder: torch.nn.Module
        the encoder to be used
    decoder: torch.nn.Module
        the decoder to be used
    latent_padding: callable
        an optional callable invoked as ``latent_padding(latent, length=length)``
        that returns a ``(latent, latent_length)`` pair
    norm: torch.nn.Module
        the normalization module; a new GlobalNorm is created
        when it is not provided
    len_dim: int
        the length dimension
    mask_out: bool
        whether to apply the length mask to the output
    mask_latent: bool
        whether to apply the length mask to the latent representation
    out_mask_value: float
        the mask value used for the output
    latent_mask_value: float
        the mask value used for the latent representation

    Examples
    --------
    >>> import torch
    >>> from torch import nn
    >>> from speechbrain.nnet.linear import Linear
    >>> ae_enc = Linear(n_neurons=16, input_size=128)
    >>> ae_dec = Linear(n_neurons=128, input_size=16)
    >>> ae = NormalizingAutoencoder(
    ...     encoder=ae_enc,
    ...     decoder=ae_dec,
    ... )
    >>> x = torch.randn(4, 10, 128)
    >>> x_enc = ae.encode(x)
    >>> x_enc.shape
    torch.Size([4, 10, 16])
    >>> x_dec = ae.decode(x_enc)
    >>> x_dec.shape
    torch.Size([4, 10, 128])
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_padding=None,
        norm=None,
        len_dim=1,
        mask_out=True,
        mask_latent=True,
        out_mask_value=0.0,
        latent_mask_value=0.0,
    ):
        super().__init__()
        # Module-valued attributes are assigned in a fixed order
        # (encoder, decoder, latent_padding, norm).
        self.encoder = encoder
        self.decoder = decoder
        self.latent_padding = latent_padding
        self.norm = GlobalNorm() if norm is None else norm
        self.len_dim = len_dim
        self.mask_out, self.mask_latent = mask_out, mask_latent
        self.out_mask_value = out_mask_value
        self.latent_mask_value = latent_mask_value

    def encode(self, x, length=None):
        """Maps a sample from the original space (e.g. pixel or waveform)
        to the normalized latent space

        Arguments
        ---------
        x: torch.Tensor
            the original data representation
        length: torch.Tensor
            a tensor of relative lengths, handed to the normalization
            module as `lengths`

        Returns
        -------
        latent: torch.Tensor
            the latent representation
        """
        return self.norm(self.encoder(x), lengths=length)

    def decode(self, latent):
        """Reconstructs a sample from its latent representation

        Arguments
        ---------
        latent: torch.Tensor
            the latent representation

        Returns
        -------
        result: torch.Tensor
            the decoded sample
        """
        decoded = self.decoder(latent)
        return decoded

    def train_sample(
        self, x, length=None, out_mask_value=None, latent_mask_value=None
    ):
        """Encodes and reconstructs a batch for training the autoencoder

        Arguments
        ---------
        x: torch.Tensor
            the source data (in the sample space)
        length: torch.Tensor
            the length (optional). If provided, latents and
            outputs will be masked
        out_mask_value: float
            the mask value used for the output; falls back to the
            value given to the constructor
        latent_mask_value: float
            the mask value used for the latent; falls back to the
            value given to the constructor

        Returns
        -------
        result: AutoencoderOutput
            a named tuple with the following values
            rec: torch.Tensor
                the reconstruction
            latent: torch.Tensor
                the latent representation
            latent_length: torch.Tensor
                the lengths associated with the latent representation
        """
        out_fill = (
            self.out_mask_value if out_mask_value is None else out_mask_value
        )
        latent_fill = (
            self.latent_mask_value
            if latent_mask_value is None
            else latent_mask_value
        )
        has_length = length is not None

        latent = self.encode(x, length=length)
        # Without a padding callable the latent keeps the input lengths
        latent_length = length
        if self.latent_padding is not None:
            latent, latent_length = self.latent_padding(latent, length=length)
        if has_length and self.mask_latent:
            latent = clean_padding(
                latent, latent_length, self.len_dim, latent_fill
            )

        # trim_as receives the reconstruction together with the input
        # (presumably to match the input's size - confirm against trim_as)
        reconstruction = trim_as(self.decode(latent), x)
        if has_length and self.mask_out:
            reconstruction = clean_padding(
                reconstruction, length, self.len_dim, out_fill
            )
        return AutoencoderOutput(reconstruction, latent, latent_length)