speechbrain.nnet.autoencoders module
Autoencoder implementation. Can be used for Latent Diffusion or in isolation.
- Authors
Artem Ploujnikov 2022
Summary
Classes:
Autoencoder – A standard interface for autoencoders
NormalizingAutoencoder – A classical (non-variational) autoencoder that does not use reparameterization but instead uses an ordinary normalization technique to constrain the latent space
VariationalAutoencoder – A Variational Autoencoder (VAE) implementation.
Reference
- class speechbrain.nnet.autoencoders.Autoencoder(*args, **kwargs)[source]
Bases:
Module
A standard interface for autoencoders
Example
>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.linear import Linear
>>> class SimpleAutoencoder(Autoencoder):
...     def __init__(self):
...         super().__init__()
...         self.enc = Linear(n_neurons=16, input_size=128)
...         self.dec = Linear(n_neurons=128, input_size=16)
...     def encode(self, x, length=None):
...         return self.enc(x)
...     def decode(self, x, length=None):
...         return self.dec(x)
>>> autoencoder = SimpleAutoencoder()
>>> x = torch.randn(4, 10, 128)
>>> x_enc = autoencoder.encode(x)
>>> x_enc.shape
torch.Size([4, 10, 16])
>>> x_enc_fw = autoencoder(x)
>>> x_enc_fw.shape
torch.Size([4, 10, 16])
>>> x_rec = autoencoder.decode(x_enc)
>>> x_rec.shape
torch.Size([4, 10, 128])
- encode(x, length=None)[source]
Converts a sample from an original space (e.g. pixel or waveform) to a latent space
- Parameters:
x (torch.Tensor) – the original data representation
length (torch.Tensor) – a tensor of relative lengths
- Returns:
latent – the latent representation
- Return type:
torch.Tensor
- decode(latent)[source]
Decodes the sample from a latent representation
- Parameters:
latent (torch.Tensor) – the latent representation
- Returns:
result – the decoded sample
- Return type:
torch.Tensor
- forward(x)[source]
Performs the forward pass
- Parameters:
x (torch.Tensor) – the input tensor
- Returns:
result – the result
- Return type:
torch.Tensor
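As the class example above shows, calling the autoencoder directly returns a tensor of the same shape as encode(x), so the forward pass can be read as the encoding step; a subclass only needs to implement encode and decode. A small illustration reusing autoencoder and x from that example (the shape equality is taken from the example output, not from inspecting the implementation):
>>> out_fw = autoencoder(x)          # forward pass via __call__
>>> out_enc = autoencoder.encode(x)
>>> out_fw.shape == out_enc.shape    # both live in the latent space
True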
- class speechbrain.nnet.autoencoders.VariationalAutoencoder(encoder, decoder, mean, log_var, len_dim=1, latent_padding=None, mask_latent=True, mask_out=True, out_mask_value=0.0, latent_mask_value=0.0, latent_stochastic=True)[source]
Bases:
Autoencoder
A Variational Autoencoder (VAE) implementation.
Paper reference: https://arxiv.org/abs/1312.6114
- Parameters:
encoder (torch.nn.Module) – the encoder network
decoder (torch.nn.Module) – the decoder network
mean (torch.nn.Module) – the module that computes the mean
log_var (torch.nn.Module) – the module that computes the log variance
len_dim (int) – the length dimension
mask_latent (bool) – whether to apply the length mask to the latent representation
mask_out (bool) – whether to apply the length mask to the output
out_mask_value (float) – the mask value used for the output
latent_mask_value (float) – the mask value used for the latent representation
latent_stochastic (bool) – if true, the "latent" field of VariationalAutoencoderOutput will be the latent space sample; if false, it will be the mean
Example
The example below shows a very simple implementation of a VAE, not suitable for actual experiments:
>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.linear import Linear
>>> vae_enc = Linear(n_neurons=16, input_size=128)
>>> vae_dec = Linear(n_neurons=128, input_size=16)
>>> vae_mean = Linear(n_neurons=16, input_size=16)
>>> vae_log_var = Linear(n_neurons=16, input_size=16)
>>> vae = VariationalAutoencoder(
...     encoder=vae_enc,
...     decoder=vae_dec,
...     mean=vae_mean,
...     log_var=vae_log_var,
... )
>>> x = torch.randn(4, 10, 128)
.train_sample() encodes a single batch and then reconstructs it:
>>> vae_out = vae.train_sample(x)
>>> vae_out.rec.shape
torch.Size([4, 10, 128])
>>> vae_out.latent.shape
torch.Size([4, 10, 16])
>>> vae_out.mean.shape
torch.Size([4, 10, 16])
>>> vae_out.log_var.shape
torch.Size([4, 10, 16])
>>> vae_out.latent_sample.shape
torch.Size([4, 10, 16])
.encode() will return the mean corresponding to the sample provided:
>>> x_enc = vae.encode(x)
>>> x_enc.shape
torch.Size([4, 10, 16])
.reparameterize() performs the reparameterization trick
>>> x_enc = vae.encoder(x)
>>> mean = vae.mean(x_enc)
>>> log_var = vae.log_var(x_enc)
>>> x_repar = vae.reparameterize(mean, log_var)
>>> x_repar.shape
torch.Size([4, 10, 16])
- encode(x, length=None)[source]
Converts a sample from an original space (e.g. pixel or waveform) to a latent space
- Parameters:
x (torch.Tensor) – the original data representation
length (torch.Tensor) – a tensor of relative lengths (optional)
- Returns:
latent – the latent representation
- Return type:
torch.Tensor
- decode(latent)[source]
Decodes the sample from a latent representation
- Parameters:
latent (torch.Tensor) – the latent representation
- Returns:
result – the decoded sample
- Return type:
torch.Tensor
- reparameterize(mean, log_var)[source]
Applies the VAE reparameterization trick to get a single latent space sample for decoding
- Parameters:
mean (torch.Tensor) – the latent representation mean
log_var (torch.Tensor) – the logarithm of the latent representation variance
- Returns:
sample – a latent space sample
- Return type:
torch.Tensor
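For reference, the reparameterization trick from the cited paper draws eps ~ N(0, I) and forms the sample as mean + exp(0.5 * log_var) * eps. The sketch below reproduces that computation with plain torch operations; it is the standard formulation assumed here, not code extracted from this module:
>>> mean = torch.zeros(4, 10, 16)
>>> log_var = torch.zeros(4, 10, 16)
>>> eps = torch.randn_like(mean)                    # eps ~ N(0, I)
>>> sample = mean + torch.exp(0.5 * log_var) * eps  # std = exp(0.5 * log_var)
>>> sample.shape
torch.Size([4, 10, 16])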
- train_sample(x, length=None, out_mask_value=None, latent_mask_value=None)[source]
Provides a data sample for training the autoencoder
- Parameters:
x (torch.Tensor) – the source data (in the sample space)
length (torch.Tensor) – the length (optional). If provided, latents and outputs will be masked
out_mask_value (float) – the mask value to use for the output (optional)
latent_mask_value (float) – the mask value to use for the latent representation (optional)
- Returns:
result – a named tuple with the following values:
- rec (torch.Tensor) – the reconstruction
- latent (torch.Tensor) – the latent space sample
- mean (torch.Tensor) – the mean of the latent representation
- log_var (torch.Tensor) – the logarithm of the variance of the latent representation
- Return type:
VariationalAutoencoderOutput
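The fields of the returned VariationalAutoencoderOutput carry everything needed for the usual VAE objective: a reconstruction term plus the KL divergence between the approximate posterior and a standard normal prior. A minimal training-step sketch, reusing vae and x from the class example above; the MSE reconstruction loss, the mean reduction, and the unweighted sum are illustrative choices, not part of this module:
>>> import torch.nn.functional as F
>>> vae_out = vae.train_sample(x)
>>> rec_loss = F.mse_loss(vae_out.rec, x)  # reconstruction term
>>> kl = -0.5 * torch.mean(
...     1 + vae_out.log_var - vae_out.mean ** 2 - vae_out.log_var.exp()
... )                                      # KL(q(z|x) || N(0, I)) for a diagonal Gaussian
>>> loss = rec_loss + kl                   # optionally weight the KL term (beta-VAE style)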
- class speechbrain.nnet.autoencoders.VariationalAutoencoderOutput(rec, latent, mean, log_var, latent_sample, latent_length)
Bases:
tuple
- latent
Alias for field number 1
- latent_length
Alias for field number 5
- latent_sample
Alias for field number 4
- log_var
Alias for field number 3
- mean
Alias for field number 2
- rec
Alias for field number 0
- class speechbrain.nnet.autoencoders.AutoencoderOutput(rec, latent, latent_length)
Bases:
tuple
- latent
Alias for field number 1
- latent_length
Alias for field number 2
- rec
Alias for field number 0
- class speechbrain.nnet.autoencoders.NormalizingAutoencoder(encoder, decoder, latent_padding=None, norm=None, len_dim=1, mask_out=True, mask_latent=True, out_mask_value=0.0, latent_mask_value=0.0)[source]
Bases:
Autoencoder
A classical (non-variational) autoencoder that does not use reparameterization but instead uses an ordinary normalization technique to constrain the latent space
- Parameters:
encoder (torch.nn.Module) – the encoder to be used
decoder (torch.nn.Module) – the decoder to be used
norm (torch.nn.Module) – the normalization module
len_dim (int) – the length dimension
mask_latent (bool) – whether to apply the length mask to the latent representation
mask_out (bool) – whether to apply the length mask to the output
out_mask_value (float) – the mask value used for the output
latent_mask_value (float) – the mask value used for the latent representation
Examples
>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.linear import Linear
>>> ae_enc = Linear(n_neurons=16, input_size=128)
>>> ae_dec = Linear(n_neurons=128, input_size=16)
>>> ae = NormalizingAutoencoder(
...     encoder=ae_enc,
...     decoder=ae_dec,
... )
>>> x = torch.randn(4, 10, 128)
>>> x_enc = ae.encode(x)
>>> x_enc.shape
torch.Size([4, 10, 16])
>>> x_dec = ae.decode(x_enc)
>>> x_dec.shape
torch.Size([4, 10, 128])
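Relative lengths can also be passed to encode; under the masking options documented above (mask_latent, latent_mask_value), latent frames beyond each item's length are expected to be masked while the output shape is unchanged. A hedged sketch, reusing ae and x from the example above and assuming the relative-length convention used elsewhere in SpeechBrain:
>>> length = torch.tensor([1.0, 0.75, 0.5, 0.25])  # relative lengths per batch item
>>> x_enc = ae.encode(x, length=length)
>>> x_enc.shape
torch.Size([4, 10, 16])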
- encode(x, length=None)[source]
Converts a sample from an original space (e.g. pixel or waveform) to a latent space
- Parameters:
x (torch.Tensor) – the original data representation
length (torch.Tensor) – a tensor of relative lengths (optional)
- Returns:
latent – the latent representation
- Return type:
torch.Tensor
- decode(latent)[source]
Decodes the sample from a latent representation
- Parameters:
latent (torch.Tensor) – the latent representation
- Returns:
result – the decoded sample
- Return type:
torch.Tensor
- train_sample(x, length=None, out_mask_value=None, latent_mask_value=None)[source]
Provides a data sample for training the autoencoder
- Parameters:
x (torch.Tensor) – the source data (in the sample space)
length (torch.Tensor) – the length (optional). If provided, latents and outputs will be masked
out_mask_value (float) – the mask value to use for the output (optional)
latent_mask_value (float) – the mask value to use for the latent representation (optional)
- Returns:
result – a named tuple with the following values:
- rec (torch.Tensor) – the reconstruction
- latent (torch.Tensor) – the latent space sample
- Return type:
AutoencoderOutput
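For this non-variational autoencoder, the usual objective is a plain reconstruction loss on the train_sample output. A minimal sketch, reusing ae and x from the class example above; MSE is an illustrative choice, not a requirement of the module:
>>> import torch.nn.functional as F
>>> ae_out = ae.train_sample(x)
>>> loss = F.mse_loss(ae_out.rec, x)  # reconstruct the original input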