speechbrain.nnet.diffusion module
An implementation of Denoising Diffusion
https://arxiv.org/pdf/2006.11239.pdf
Certain parts adapted from / inspired by denoising-diffusion-pytorch https://github.com/lucidrains/denoising-diffusion-pytorch
- Authors
Artem Ploujnikov 2022
Summary
Classes:

DenoisingDiffusion – An implementation of a classic Denoising Diffusion Probabilistic Model (DDPM)
Diffuser – A base diffusion implementation
GaussianNoise – Adds ordinary Gaussian noise
LatentDiffusion – A latent diffusion wrapper
LengthMaskedGaussianNoise – Gaussian noise applied to padded samples

Functions:

sample_timesteps – Returns a random sample of timesteps as a 1-D tensor
Reference
- class speechbrain.nnet.diffusion.Diffuser(model, timesteps, noise=None)[source]
Bases: Module
A base diffusion implementation
- Parameters:
model (nn.Module) – the underlying model
timesteps (int) – the number of timesteps
noise (callable|str) – the noise function/module to use. The following predefined types of noise are provided: "gaussian" (Gaussian noise applied to the whole sample) and "length_masked_gaussian" (Gaussian noise applied only to the parts of the sample that are not padding)
- distort(x, timesteps=None)[source]
Adds noise to a batch of data
- Parameters:
x (torch.Tensor) – the original data sample
timesteps (torch.Tensor) – a 1-D integer tensor with length equal to the batch size of x, where each entry is the timestep for the corresponding batch item. If omitted, timesteps are randomly sampled
- train_sample(x, timesteps=None, condition=None, **kwargs)[source]
Creates a sample for the training loop with a corresponding target
- Parameters:
x (torch.Tensor) – the original data sample
timesteps (torch.Tensor) – a 1-D integer tensor with length equal to the batch size of x, where each entry is the timestep for the corresponding batch item. If omitted, timesteps will be randomly sampled
condition (torch.Tensor) – the condition used for conditional generation; should be omitted during unconditional generation
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
pred (torch.Tensor) – the model output, i.e. the predicted noise
noise (torch.Tensor) – the noise being applied
noisy_sample (torch.Tensor) – the sample with the noise applied
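A minimal sketch of how train_sample typically feeds a training step, using the DenoisingDiffusion subclass documented below; my_model and loader are hypothetical stand-ins for a noise-prediction network and a DataLoader of clean samples, and the MSE objective is the standard DDPM choice, not something mandated by this API:

>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.diffusion import DenoisingDiffusion
>>> diff = DenoisingDiffusion(model=my_model, timesteps=50)  # my_model: hypothetical nn.Module
>>> opt = torch.optim.Adam(my_model.parameters(), lr=1e-4)
>>> for x in loader:  # loader: hypothetical DataLoader
...     pred, noise, noisy_sample = diff.train_sample(x)
...     loss = nn.functional.mse_loss(pred, noise)  # regress the added noise
...     opt.zero_grad()
...     loss.backward()
...     opt.step()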
- class speechbrain.nnet.diffusion.DenoisingDiffusion(model, timesteps=None, noise=None, beta_start=None, beta_end=None, sample_min=None, sample_max=None, show_progress=False)[source]
Bases: Diffuser
An implementation of a classic Denoising Diffusion Probabilistic Model (DDPM)
- Parameters:
model (nn.Module) – the underlying model
timesteps (int) – the number of timesteps
noise (str|nn.Module) – the type of noise being used; "gaussian" will produce standard Gaussian noise
beta_start (float) – the value of the "beta" parameter at the beginning of the process (see the paper)
beta_end (float) – the value of the "beta" parameter at the end of the process
sample_min (float) – used, together with sample_max, to clip the output
sample_max (float) – used, together with sample_min, to clip the output
show_progress (bool) – whether to show progress during inference
Example
>>> import torch
>>> from speechbrain.nnet.unet import UNetModel
>>> unet = UNetModel(
...     in_channels=1,
...     model_channels=16,
...     norm_num_groups=4,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[]
... )
>>> diff = DenoisingDiffusion(
...     model=unet,
...     timesteps=5
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> pred, noise, noisy_sample = diff.train_sample(x)
>>> pred.shape
torch.Size([4, 1, 64, 64])
>>> noise.shape
torch.Size([4, 1, 64, 64])
>>> noisy_sample.shape
torch.Size([4, 1, 64, 64])
>>> sample = diff.sample((2, 1, 64, 64))
>>> sample.shape
torch.Size([2, 1, 64, 64])
- distort(x, noise=None, timesteps=None, **kwargs)[source]
Adds noise to the sample in a forward diffusion process
- Parameters:
x (torch.Tensor) – a data sample of 2 or more dimensions, with the first dimension representing the batch
noise (torch.Tensor) – the noise to add
timesteps (torch.Tensor) – a 1-D integer tensor with length equal to the batch size of x, where each entry is the timestep for the corresponding batch item. If omitted, timesteps will be randomly sampled
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
result – a tensor of the same dimension as x
- Return type:
torch.Tensor
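For reference, the forward step of a classic DDPM (from the paper linked at the top of this page) admits a closed form, which is what lets distort jump directly to an arbitrary timestep t. With alpha_t = 1 - beta_t and bar-alpha_t the running product of the alphas:

\[
q(x_t \mid x_0) = \mathcal{N}\!\big(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t)\, I\big),
\qquad
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\quad \epsilon \sim \mathcal{N}(0, I),
\]

where beta_t runs over the schedule defined by beta_start and beta_end.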
- sample(shape, **kwargs)[source]
Generates a batch of samples of the specified shape
- Parameters:
shape (enumerable) – the shape of the sample to generate
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
result – the generated sample(s)
- Return type:
torch.Tensor
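sample starts from pure Gaussian noise of the given shape and iteratively denoises it over the configured timesteps. The ancestral sampling step of a classic DDPM (same paper), with epsilon_theta denoting the model's noise prediction, is:

\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
\left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right)
+ \sigma_t z,
\qquad z \sim \mathcal{N}(0, I),
\]

with sigma_t^2 commonly set to beta_t; when sample_min/sample_max are configured, the output is clipped to that range.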
- class speechbrain.nnet.diffusion.LatentDiffusion(autoencoder, diffusion, latent_downsample_factor=None, latent_pad_dim=1)[source]
Bases: Module
A latent diffusion wrapper. Latent diffusion is denoising diffusion applied to a latent space instead of the original data space
- Parameters:
autoencoder (speechbrain.nnet.autoencoders.Autoencoder) – an autoencoder converting the original space to a latent space
diffusion (speechbrain.nnet.diffusion.Diffuser) – a diffusion wrapper
latent_downsample_factor (int) – the factor that latent space dimensions need to be divisible by. This is useful if the underlying model of the diffusion wrapper is based on a UNet-like architecture where the inputs are progressively downsampled and upsampled by factors of two
latent_pad_dim (int|list[int]) – the dimension(s) along which the latent space will be padded
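As an illustration of what the divisibility constraint implies (this helper is hypothetical, not part of the wrapper's API), each affected latent dimension is padded up to the next multiple of the factor:

>>> def pad_amount(dim_size, factor):
...     """Hypothetical helper: padding needed to make dim_size divisible by factor."""
...     return (factor - dim_size % factor) % factor
>>> pad_amount(15, 4)  # a 15-wide latent dimension would be padded to 16
1
>>> pad_amount(16, 4)  # already divisible: no padding needed
0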
Example
>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.CNN import Conv2d
>>> from speechbrain.nnet.autoencoders import NormalizingAutoencoder
>>> from speechbrain.nnet.unet import UNetModel
>>> from speechbrain.nnet.diffusion import DenoisingDiffusion, LatentDiffusion
Set up a simple autoencoder (a real autoencoder would be a deep neural network)
>>> ae_enc = Conv2d(
...     kernel_size=3,
...     stride=4,
...     in_channels=1,
...     out_channels=1,
...     skip_transpose=True,
... )
>>> ae_dec = nn.ConvTranspose2d(
...     kernel_size=3,
...     stride=4,
...     in_channels=1,
...     out_channels=1,
...     output_padding=1
... )
>>> ae = NormalizingAutoencoder(
...     encoder=ae_enc,
...     decoder=ae_dec,
... )
Construct a diffusion model with a UNet architecture
>>> unet = UNetModel(
...     in_channels=1,
...     model_channels=16,
...     norm_num_groups=4,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[]
... )
>>> diff = DenoisingDiffusion(
...     model=unet,
...     timesteps=5
... )
>>> latent_diff = LatentDiffusion(
...     autoencoder=ae,
...     diffusion=diff,
...     latent_downsample_factor=4,
...     latent_pad_dim=2
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> latent_sample = latent_diff.train_sample_latent(x)
>>> diff_sample, ae_sample = latent_sample
>>> pred, noise, noisy_sample = diff_sample
>>> pred.shape
torch.Size([4, 1, 16, 16])
>>> noise.shape
torch.Size([4, 1, 16, 16])
>>> noisy_sample.shape
torch.Size([4, 1, 16, 16])
>>> ae_sample.latent.shape
torch.Size([4, 1, 16, 16])
Create a few samples (the shape given should be the shape of the latent space)
>>> sample = latent_diff.sample((2, 1, 16, 16))
>>> sample.shape
torch.Size([2, 1, 64, 64])
- train_sample(x, **kwargs)[source]
Creates a sample for the training loop with a corresponding target
- Parameters:
x (torch.Tensor) – the original data sample
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
pred (torch.Tensor) – the model output, i.e. the predicted noise
noise (torch.Tensor) – the noise being applied
noisy_sample (torch.Tensor) – the sample with the noise applied
- train_sample_latent(x, **kwargs)[source]
Returns a training sample together with the autoencoder output; can be used to train the diffusion model and the autoencoder jointly
- Parameters:
x (torch.Tensor) – the original data sample
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
the diffusion training sample (pred, noise, noisy_sample) together with the autoencoder output
- Return type:
tuple
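A sketch of a joint training step built on train_sample_latent, reusing latent_diff from the example above; loader is a hypothetical DataLoader, and only the diffusion term is shown (a reconstruction term from the autoencoder output could be added, but its exact fields depend on the autoencoder in use):

>>> import torch
>>> from torch import nn
>>> opt = torch.optim.Adam(latent_diff.parameters(), lr=1e-4)
>>> for x in loader:  # loader: hypothetical DataLoader of clean samples
...     (pred, noise, noisy_sample), ae_out = latent_diff.train_sample_latent(x)
...     loss = nn.functional.mse_loss(pred, noise)  # noise prediction in the latent space
...     # a reconstruction term derived from ae_out could be added here for joint training
...     opt.zero_grad()
...     loss.backward()
...     opt.step()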
- speechbrain.nnet.diffusion.sample_timesteps(x, num_timesteps)[source]
Returns a random sample of timesteps as a 1-D tensor
- Parameters:
x (torch.Tensor) – a tensor of samples of any dimension
num_timesteps (int) – the total number of timesteps
- Returns:
a random sample of timesteps, one per batch item
- Return type:
torch.Tensor
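A quick shape illustration, assuming (as the 1-D description above suggests) one timestep index per batch item:

>>> import torch
>>> from speechbrain.nnet.diffusion import sample_timesteps
>>> x = torch.randn(4, 1, 64, 64)
>>> t = sample_timesteps(x, 50)
>>> t.shape
torch.Size([4])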
- class speechbrain.nnet.diffusion.GaussianNoise(*args, **kwargs)[source]
Bases: Module
Adds ordinary Gaussian noise
- class speechbrain.nnet.diffusion.LengthMaskedGaussianNoise(length_dim=1)[source]
Bases: Module
Gaussian noise applied to padded samples. No noise is added to positions that are part of the padding
- Parameters:
length_dim (int) – the time dimension to which the lengths apply
- forward(sample, length=None, **kwargs)[source]
Creates Gaussian noise. If a tensor of lengths is provided, no noise is added to the padding positions
- Parameters:
sample (torch.Tensor) – a batch of data
length (torch.Tensor) – relative lengths
**kwargs (dict) – arguments to forward to the underlying model
- Returns:
Gaussian noise in the shape of sample
- Return type:
torch.Tensor
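A short illustration, assuming SpeechBrain's usual convention that length holds lengths relative to the longest item in the batch; the final check assumes the mask zeroes the noise at padded positions, as the class description implies:

>>> import torch
>>> from speechbrain.nnet.diffusion import LengthMaskedGaussianNoise
>>> noise_fn = LengthMaskedGaussianNoise(length_dim=1)
>>> sample = torch.randn(2, 10, 4)     # batch of 2; dimension 1 is time
>>> length = torch.tensor([1.0, 0.5])  # second item: only 5 of 10 steps are valid
>>> noise = noise_fn(sample, length=length)
>>> noise.shape
torch.Size([2, 10, 4])
>>> bool((noise[1, 5:] == 0).all())    # padded positions carry no noise
True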