Source code for speechbrain.nnet.diffusion

"""An implementation of Denoising Diffusion

https://arxiv.org/pdf/2006.11239.pdf

Certain parts adopted from / inspired by denoising-diffusion-pytorch
https://github.com/lucidrains/denoising-diffusion-pytorch

Authors
 * Artem Ploujnikov 2022
"""

from collections import namedtuple
import torch
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm
from speechbrain.utils.data_utils import unsqueeze_as
from speechbrain.dataio.dataio import length_to_mask
from speechbrain.utils import data_utils


[docs]class Diffuser(nn.Module): """A base diffusion implementation Arguments --------- model: nn.Module the underlying model timesteps: int the number of timesteps noise: callable|str the noise function/module to use The following predefined types of noise are provided "gaussian": Gaussian noise, applied to the whole sample "length_masked_gaussian": Gaussian noise applied only to the parts of the sample that is not padding """ def __init__(self, model, timesteps, noise=None): super().__init__() self.model = model self.timesteps = timesteps if noise is None: noise = "gaussian" if isinstance(noise, str): self.noise = _NOISE_FUNCTIONS[noise]() else: self.noise = noise
[docs] def distort(self, x, timesteps=None): """Adds noise to a batch of data Arguments --------- x: torch.Tensor the original data sample timesteps: torch.Tensor a 1-D integer tensor of a length equal to the number of batches in x, where each entry corresponds to the timestep number for the batch. If omitted, timesteps will be randomly sampled Returns ------- result: torch.Tensor a tensor of the same dimension as x """ raise NotImplementedError
[docs] def train_sample(self, x, timesteps=None, condition=None, **kwargs): """Creates a sample for the training loop with a corresponding target Arguments --------- x: torch.Tensor the original data sample timesteps: torch.Tensor a 1-D integer tensor of a length equal to the number of batches in x, where each entry corresponds to the timestep number for the batch. If omitted, timesteps will be randomly sampled condition: torch.tensor the condition used for conditional generation Should be omitted during unconditional generation Returns ------- pred: torch.Tensor the model output 0 prdicted noise noise: torch.Tensor the noise being applied noisy_sample the sample with the noise applied """ if timesteps is None: timesteps = sample_timesteps(x, self.timesteps) noisy_sample, noise = self.distort(x, timesteps=timesteps, **kwargs) # in case that certain models do not have any condition as input if condition is None: pred = self.model(noisy_sample, timesteps, **kwargs) else: pred = self.model(noisy_sample, timesteps, condition, **kwargs) return pred, noise, noisy_sample
[docs] def sample(self, shape, **kwargs): """Generates the number of samples indicated by the count parameter Arguments --------- shape: enumerable the shape of the sample to generate Returns ------- result: torch.Tensor the generated sample(s) """ raise NotImplementedError
[docs] def forward(self, x, timesteps=None): """Computes the forward pass, calls distort() """ return self.distort(x, timesteps)
DDPM_DEFAULT_BETA_START = 0.0001 DDPM_DEFAULT_BETA_END = 0.02 DDPM_REF_TIMESTEPS = 1000 DESC_SAMPLING = "Diffusion Sampling"
[docs]class DenoisingDiffusion(Diffuser): """An implementation of a classic Denoising Diffusion Probabilistic Model (DDPM) Arguments --------- model: nn.Module the underlying model timesteps: int the number of timesteps noise: str|nn.Module the type of noise being used "gaussian" will produce standard Gaussian noise beta_start: float the value of the "beta" parameter at the beginning at the end of the process (see the paper) beta_end: float the value of the "beta" parameter at the end of the process show_progress: bool whether to show progress during inference Example ------- >>> from speechbrain.nnet.unet import UNetModel >>> unet = UNetModel( ... in_channels=1, ... model_channels=16, ... norm_num_groups=4, ... out_channels=1, ... num_res_blocks=1, ... attention_resolutions=[] ... ) >>> diff = DenoisingDiffusion( ... model=unet, ... timesteps=5 ... ) >>> x = torch.randn(4, 1, 64, 64) >>> pred, noise, noisy_sample = diff.train_sample(x) >>> pred.shape torch.Size([4, 1, 64, 64]) >>> noise.shape torch.Size([4, 1, 64, 64]) >>> noisy_sample.shape torch.Size([4, 1, 64, 64]) >>> sample = diff.sample((2, 1, 64, 64)) >>> sample.shape torch.Size([2, 1, 64, 64]) """ def __init__( self, model, timesteps=None, noise=None, beta_start=None, beta_end=None, sample_min=None, sample_max=None, show_progress=False, ): if timesteps is None: timesteps = DDPM_REF_TIMESTEPS super().__init__(model, timesteps=timesteps, noise=noise) if beta_start is None or beta_end is None: scale = DDPM_REF_TIMESTEPS / timesteps if beta_start is None: beta_start = scale * DDPM_DEFAULT_BETA_START if beta_end is None: beta_end = scale * DDPM_DEFAULT_BETA_END self.beta_start = beta_start self.beta_end = beta_end alphas, betas = self.compute_coefficients() self.register_buffer("alphas", alphas) self.register_buffer("betas", betas) alphas_cumprod = self.alphas.cumprod(dim=0) self.register_buffer("alphas_cumprod", alphas_cumprod) signal_coefficients = torch.sqrt(alphas_cumprod) noise_coefficients = torch.sqrt(1.0 - alphas_cumprod) self.register_buffer("signal_coefficients", signal_coefficients) self.register_buffer("noise_coefficients", noise_coefficients) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) self.register_buffer("posterior_variance", posterior_variance) self.register_buffer("posterior_log_variance", posterior_variance.log()) posterior_mean_weight_start = ( betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) posterior_mean_weight_step = ( (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) ) self.register_buffer( "posterior_mean_weight_start", posterior_mean_weight_start ) self.register_buffer( "posterior_mean_weight_step", posterior_mean_weight_step ) sample_pred_model_coefficient = (1.0 / alphas_cumprod).sqrt() self.register_buffer( "sample_pred_model_coefficient", sample_pred_model_coefficient ) sample_pred_noise_coefficient = (1.0 / alphas_cumprod - 1).sqrt() self.register_buffer( "sample_pred_noise_coefficient", sample_pred_noise_coefficient ) self.sample_min = sample_min self.sample_max = sample_max self.show_progress = show_progress
[docs] def compute_coefficients(self): """Computes diffusion coefficients (alphas and betas)""" betas = torch.linspace(self.beta_start, self.beta_end, self.timesteps) alphas = 1.0 - betas return alphas, betas
[docs] def distort(self, x, noise=None, timesteps=None, **kwargs): """Adds noise to the sample, in a forward diffusion process, Arguments --------- x: torch.Tensor a data sample of 2 or more dimensions, with the first dimension representing the batch noise: torch.Tensor the noise to add timesteps: torch.Tensor a 1-D integer tensor of a length equal to the number of batches in x, where each entry corresponds to the timestep number for the batch. If omitted, timesteps will be randomly sampled Returns ------- result: torch.Tensor a tensor of the same dimension as x """ if timesteps is None: timesteps = sample_timesteps(x, self.timesteps) if noise is None: noise = self.noise(x, **kwargs) signal_coefficients = self.signal_coefficients[timesteps] noise_coefficients = self.noise_coefficients[timesteps] noisy_sample = ( unsqueeze_as(signal_coefficients, x) * x + unsqueeze_as(noise_coefficients, noise) * noise ) return noisy_sample, noise
[docs] @torch.no_grad() def sample(self, shape, **kwargs): """Generates the number of samples indicated by the count parameter Arguments --------- shape: enumerable the shape of the sample to generate Returns ------- result: torch.Tensor the generated sample(s) """ sample = self.noise(torch.zeros(*shape, device=self.alphas.device)) steps = reversed(range(self.timesteps)) if self.show_progress: steps = tqdm(steps, desc=DESC_SAMPLING, total=self.timesteps) for timestep_number in steps: timestep = ( torch.ones( shape[0], dtype=torch.long, device=self.alphas.device ) * timestep_number ) sample = self.sample_step(sample, timestep, **kwargs) return sample
[docs] @torch.no_grad() def sample_step(self, sample, timestep, **kwargs): """Processes a single timestep for the sampling process Arguments --------- sample: torch.Tensor the sample for the following timestep timestep: int the timestep number Arguments --------- predicted_sample: torch.Tensor the predicted sample (denoised by one step`) """ model_out = self.model(sample, timestep, **kwargs) noise = self.noise(sample) sample_start = ( unsqueeze_as(self.sample_pred_model_coefficient[timestep], sample) * sample - unsqueeze_as( self.sample_pred_noise_coefficient[timestep], model_out ) * model_out ) weight_start = unsqueeze_as( self.posterior_mean_weight_start[timestep], sample_start ) weight_step = unsqueeze_as( self.posterior_mean_weight_step[timestep], sample ) mean = weight_start * sample_start + weight_step * sample log_variance = unsqueeze_as( self.posterior_log_variance[timestep], noise ) predicted_sample = mean + (0.5 * log_variance).exp() * noise if self.sample_min is not None or self.sample_max is not None: predicted_sample.clip_(min=self.sample_min, max=self.sample_max) return predicted_sample
[docs]class LatentDiffusion(nn.Module): """A latent diffusion wrapper. Latent diffusion is denoising diffusion applied to a latent space instead of the original data space Arguments --------- autoencoder: speechbrain.nnet.autoencoders.Autoencoder An autoencoder converting the original space to a latent space diffusion: speechbrian.nnet.diffusion.Diffuser A diffusion wrapper latent_downsample_factor: int The factor that latent space dimensions need to be divisible by. This is useful if the underlying model for the diffusion wrapper is based on a UNet-like architecture where the inputs are progressively downsampled and upsampled by factors of two latent_pad_dims: int|list[int] the dimension(s) along which the latent space will be padded Example ------- >>> import torch >>> from torch import nn >>> from speechbrain.nnet.CNN import Conv2d >>> from speechbrain.nnet.autoencoders import NormalizingAutoencoder >>> from speechbrain.nnet.unet import UNetModel Set up a simple autoencoder (a real autoencoder would be a deep neural network) >>> ae_enc = Conv2d( ... kernel_size=3, ... stride=4, ... in_channels=1, ... out_channels=1, ... skip_transpose=True, ... ) >>> ae_dec = nn.ConvTranspose2d( ... kernel_size=3, ... stride=4, ... in_channels=1, ... out_channels=1, ... output_padding=1 ... ) >>> ae = NormalizingAutoencoder( ... encoder=ae_enc, ... decoder=ae_dec, ... ) Construct a diffusion model with a UNet architecture >>> unet = UNetModel( ... in_channels=1, ... model_channels=16, ... norm_num_groups=4, ... out_channels=1, ... num_res_blocks=1, ... attention_resolutions=[] ... ) >>> diff = DenoisingDiffusion( ... model=unet, ... timesteps=5 ... ) >>> latent_diff = LatentDiffusion( ... autoencoder=ae, ... diffusion=diff, ... latent_downsample_factor=4, ... latent_pad_dim=2 ... ) >>> x = torch.randn(4, 1, 64, 64) >>> latent_sample = latent_diff.train_sample_latent(x) >>> diff_sample, ae_sample = latent_sample >>> pred, noise, noisy_sample = diff_sample >>> pred.shape torch.Size([4, 1, 16, 16]) >>> noise.shape torch.Size([4, 1, 16, 16]) >>> noisy_sample.shape torch.Size([4, 1, 16, 16]) >>> ae_sample.latent.shape torch.Size([4, 1, 16, 16]) Create a few samples (the shape given should be the shape of the latent space) >>> sample = latent_diff.sample((2, 1, 16, 16)) >>> sample.shape torch.Size([2, 1, 64, 64]) """ def __init__( self, autoencoder, diffusion, latent_downsample_factor=None, latent_pad_dim=1, ): super().__init__() self.autencoder = autoencoder self.diffusion = diffusion self.latent_downsample_factor = latent_downsample_factor if isinstance(latent_pad_dim, int): latent_pad_dim = [latent_pad_dim] self.latent_pad_dim = latent_pad_dim
[docs] def train_sample(self, x, **kwargs): """Creates a sample for the training loop with a corresponding target Arguments --------- x: torch.Tensor the original data sample timesteps: torch.Tensor a 1-D integer tensor of a length equal to the number of batches in x, where each entry corresponds to the timestep number for the batch. If omitted, timesteps will be randomly sampled Returns ------- pred: torch.Tensor the model output 0 prdicted noise noise: torch.Tensor the noise being applied noisy_sample the sample with the noise applied """ latent = self.autoencoder.encode(x) latent = self._pad_latent(latent) return self.diffusion.train_sample(latent, **kwargs)
def _pad_latent(self, latent): """Pads the latent space to the desired dimension Arguments --------- latent: torch.Tensor the latent representation Returns ------- result: torch.Tensor the latent representation, with padding""" # TODO: Check whether masking will need to be adjusted if ( self.latent_downsample_factor is not None and self.latent_downsample_factor > 1 ): for dim in self.latent_pad_dim: latent, _ = data_utils.pad_divisible( latent, factor=self.latent_downsample_factor, len_dim=dim ) return latent
[docs] def train_sample_latent(self, x, **kwargs): """Returns a train sample with autoencoder output - can be used to jointly training the diffusion model and the autoencoder Arguments --------- x: torch.Tensor the original data sample """ # TODO: Make this generic length = kwargs.get("length") out_mask_value = kwargs.get("out_mask_value") latent_mask_value = kwargs.get("latent_mask_value") autoencoder_out = self.autencoder.train_sample( x, length=length, out_mask_value=out_mask_value, latent_mask_value=latent_mask_value, ) latent = self._pad_latent(autoencoder_out.latent) diffusion_train_sample = self.diffusion.train_sample(latent, **kwargs) return LatentDiffusionTrainSample( diffusion=diffusion_train_sample, autoencoder=autoencoder_out )
[docs] def distort(self, x): """Adds noise to the sample, in a forward diffusion process, Arguments --------- x: torch.Tensor a data sample of 2 or more dimensions, with the first dimension representing the batch noise: torch.Tensor the noise to add timesteps: torch.Tensor a 1-D integer tensor of a length equal to the number of batches in x, where each entry corresponds to the timestep number for the batch. If omitted, timesteps will be randomly sampled Returns ------- result: torch.Tensor a tensor of the same dimension as x """ latent = self.autencoder.encode(x) return self.diffusion.distort(latent)
[docs] def sample(self, shape): """Obtains a sample out of the diffusion model Arguments --------- shape: torch.Tensor Returns ------- sample: torch.Tensor the sample of the specified shape """ # TODO: Auto-compute the latent shape latent = self.diffusion.sample(shape) latent = self._pad_latent(latent) return self.autencoder.decode(latent)
[docs]def sample_timesteps(x, num_timesteps): """Returns a random sample of timesteps as a 1-D tensor (one dimension only) Arguments --------- x: torch.Tensor a tensor of samples of any dimension num_timesteps: int the total number of timesteps""" return torch.randint(num_timesteps, (x.size(0),), device=x.device)
[docs]class GaussianNoise(nn.Module): """Adds ordinary Gaussian noise"""
[docs] def forward(self, sample, **kwargs): """Forward pass Arguments --------- sample: the original sample """ return torch.randn_like(sample)
[docs]class LengthMaskedGaussianNoise(nn.Module): """Gaussian noise applied to padded samples. No noise is added to positions that are part of padding Arguments --------- length_dim: int the """ def __init__(self, length_dim=1): super().__init__() self.length_dim = length_dim
[docs] def forward(self, sample, length=None, **kwargs): """Creates Gaussian noise. If a tensor of lengths is provided, no noise is added to the padding positions. sample: torch.Tensor a batch of data length: torch.Tensor relative lengths """ noise = torch.randn_like(sample) if length is not None: max_len = sample.size(self.length_dim) mask = length_to_mask(length * max_len, max_len).bool() mask_shape = self._compute_mask_shape(noise, max_len) mask = mask.view(mask_shape) noise.masked_fill_(~mask, 0.0) return noise
def _compute_mask_shape(self, noise, max_len): return ( (noise.shape[0],) + ((1,) * (self.length_dim - 1)) # Between the batch and len_dim + (max_len,) + ((1,) * (noise.dim() - 3)) # Unsqueeze at the end )
_NOISE_FUNCTIONS = { "gaussian": GaussianNoise, "length_masked_gaussian": LengthMaskedGaussianNoise, } DiffusionTrainSample = namedtuple( "DiffusionTrainSample", ["pred", "noise", "noisy_sample"] ) LatentDiffusionTrainSample = namedtuple( "LatentDiffusionTrainSample", ["diffusion", "autoencoder"] )