"""
Neural network modules for DIFFWAVE:
A VERSATILE DIFFUSION MODEL FOR AUDIO SYNTHESIS
For more details: https://arxiv.org/pdf/2009.09761.pdf
Authors
* Yingzhi WANG 2022
"""
# Much of this code is adapted from the LMNT implementation below,
# with modifications and enhancements
# https://github.com/lmnt-com/diffwave/blob/master/src/diffwave/model.py
# *****************************************************************************
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms
from speechbrain.nnet import linear
from speechbrain.nnet.CNN import Conv1d
from speechbrain.nnet.diffusion import DenoisingDiffusion
Linear = linear.Linear
ConvTranspose2d = nn.ConvTranspose2d
@torch.jit.script
def silu(x):
"""sigmoid linear unit activation function"""
return x * torch.sigmoid(x)
def diffwave_mel_spectogram(
sample_rate,
hop_length,
win_length,
n_fft,
n_mels,
f_min,
f_max,
power,
normalized,
norm,
mel_scale,
audio,
):
"""calculates MelSpectrogram for a raw audio signal
and preprocesses it for diffwave training
Arguments
---------
sample_rate : int
Sample rate of audio signal.
hop_length : int
Length of hop between STFT windows.
win_length : int
Window size.
n_fft : int
Size of FFT.
n_mels : int
Number of mel filterbanks.
f_min : float
Minimum frequency.
f_max : float
Maximum frequency.
power : float
Exponent for the magnitude spectrogram.
normalized : bool
Whether to normalize by magnitude after stft.
norm : str or None
If "slaney", divide the triangular mel weights by the width of the mel band
mel_scale : str
Scale to use: "htk" or "slaney".
    audio : torch.Tensor
        input audio signal
    Returns
    -------
    mel : torch.Tensor
        the normalized log-mel spectrogram [bs, n_mels, frames]
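    Example
    -------
    A minimal sketch with typical 22.05 kHz vocoder settings (the specific
    parameter values below are illustrative, not prescriptive):
    >>> from speechbrain.lobes.models.DiffWave import diffwave_mel_spectogram
    >>> audio = torch.randn(1, 22050)
    >>> mel = diffwave_mel_spectogram(
    ...     sample_rate=22050, hop_length=256, win_length=1024, n_fft=1024,
    ...     n_mels=80, f_min=0.0, f_max=8000.0, power=1.0, normalized=False,
    ...     norm="slaney", mel_scale="slaney", audio=audio,
    ... )
    >>> mel.shape
    torch.Size([1, 80, 87])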
"""
audio_to_mel = transforms.MelSpectrogram(
sample_rate=sample_rate,
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
power=power,
normalized=normalized,
norm=norm,
mel_scale=mel_scale,
).to(audio.device)
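    # compute the mel spectrogram of the clamped waveform, convert it to
    # decibels, and compress the dynamic range into [0, 1]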
mel = audio_to_mel(torch.clamp(audio, -1.0, 1.0))
mel = 20 * torch.log10(torch.clamp(mel, min=1e-5)) - 20
mel = torch.clamp((mel + 100) / 100, 0.0, 1.0)
return mel
class DiffusionEmbedding(nn.Module):
"""Embeds the diffusion step into an input vector of DiffWave
Arguments
---------
max_steps: int
total diffusion steps
Example
-------
>>> from speechbrain.lobes.models.DiffWave import DiffusionEmbedding
>>> diffusion_embedding = DiffusionEmbedding(max_steps=50)
>>> time_step = torch.randint(50, (1,))
>>> step_embedding = diffusion_embedding(time_step)
>>> step_embedding.shape
torch.Size([1, 512])
"""
def __init__(self, max_steps):
super().__init__()
self.register_buffer(
"embedding", self._build_embedding(max_steps), persistent=False
)
self.projection1 = Linear(input_size=128, n_neurons=512)
self.projection2 = Linear(input_size=512, n_neurons=512)
def forward(self, diffusion_step):
"""forward function of diffusion step embedding
Arguments
---------
diffusion_step: torch.Tensor
which step of diffusion to execute
Returns
-------
        x : torch.Tensor
            the diffusion step embedding [bs, 512]
"""
if diffusion_step.dtype in [torch.int32, torch.int64]:
x = self.embedding[diffusion_step]
else:
x = self._lerp_embedding(diffusion_step)
x = self.projection1(x)
x = silu(x)
x = self.projection2(x)
x = silu(x)
return x
def _lerp_embedding(self, t):
"""Deals with the cases where diffusion_step is not int
Arguments
---------
t: torch.Tensor
which step of diffusion to execute
Returns
-------
embedding : torch.Tensor
"""
low_idx = torch.floor(t).long()
high_idx = torch.ceil(t).long()
low = self.embedding[low_idx]
high = self.embedding[high_idx]
return low + (high - low) * (t - low_idx)
def _build_embedding(self, max_steps):
"""Build embeddings in a designed way
Arguments
---------
max_steps: int
total diffusion steps
Returns
-------
table: torch.Tensor
"""
steps = torch.arange(max_steps).unsqueeze(1) # [T,1]
dims = torch.arange(64).unsqueeze(0) # [1,64]
table = steps * 10.0 ** (dims * 4.0 / 63.0) # [T,64]
table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
return table
class SpectrogramUpsampler(nn.Module):
"""Upsampler for spectrograms with Transposed Conv
Only the upsampling is done here, the layer-specific Conv can be found
in residual block to map the mel bands into 2× residual channels
Example
-------
>>> from speechbrain.lobes.models.DiffWave import SpectrogramUpsampler
>>> spec_upsampler = SpectrogramUpsampler()
>>> mel_input = torch.rand(3, 80, 100)
>>> upsampled_mel = spec_upsampler(mel_input)
>>> upsampled_mel.shape
torch.Size([3, 80, 25600])
"""
def __init__(self):
super().__init__()
self.conv1 = ConvTranspose2d(
1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
)
self.conv2 = ConvTranspose2d(
1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
)
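        # the two stride-16 transposed convolutions combine for 16 * 16 = 256x upsampling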
def forward(self, x):
"""Upsamples spectrograms 256 times to match the length of audios
Hop length should be 256 when extracting mel spectrograms
Arguments
---------
x: torch.Tensor
input mel spectrogram [bs, 80, mel_len]
Returns
-------
upsampled spectrogram [bs, 80, mel_len*256]
"""
x = torch.unsqueeze(x, 1)
x = self.conv1(x)
x = F.leaky_relu(x, 0.4)
x = self.conv2(x)
x = F.leaky_relu(x, 0.4)
x = torch.squeeze(x, 1)
return x
class ResidualBlock(nn.Module):
"""
Residual Block with dilated convolution
Arguments
---------
n_mels: int
input mel channels of conv1x1 for conditional vocoding task
residual_channels: int
channels of audio convolution
    dilation: int
        dilation of the audio convolution in this block
    uncond: bool
        True for unconditional generation, False for conditional (default False)
Example
-------
>>> from speechbrain.lobes.models.DiffWave import ResidualBlock
>>> res_block = ResidualBlock(n_mels=80, residual_channels=64, dilation=3)
>>> noisy_audio = torch.randn(1, 1, 22050)
>>> timestep_embedding = torch.rand(1, 512)
>>> upsampled_mel = torch.rand(1, 80, 22050)
>>> output = res_block(noisy_audio, timestep_embedding, upsampled_mel)
>>> output[0].shape
torch.Size([1, 64, 22050])
"""
def __init__(self, n_mels, residual_channels, dilation, uncond=False):
super().__init__()
self.dilated_conv = Conv1d(
in_channels=residual_channels,
out_channels=2 * residual_channels,
kernel_size=3,
dilation=dilation,
skip_transpose=True,
padding="same",
conv_init="kaiming",
)
self.diffusion_projection = Linear(
input_size=512, n_neurons=residual_channels
)
# conditional model
if not uncond:
self.conditioner_projection = Conv1d(
in_channels=n_mels,
out_channels=2 * residual_channels,
kernel_size=1,
skip_transpose=True,
padding="same",
conv_init="kaiming",
)
# unconditional model
else:
self.conditioner_projection = None
self.output_projection = Conv1d(
in_channels=residual_channels,
out_channels=2 * residual_channels,
kernel_size=1,
skip_transpose=True,
padding="same",
conv_init="kaiming",
)
def forward(self, x, diffusion_step, conditioner=None):
"""
forward function of Residual Block
Arguments
---------
x: torch.Tensor
input sample [bs, 1, time]
diffusion_step: torch.Tensor
the embedding of which step of diffusion to execute
conditioner: torch.Tensor
the condition used for conditional generation
Returns
-------
residual output [bs, residual_channels, time]
a skip of residual branch [bs, residual_channels, time]
"""
assert (
conditioner is None and self.conditioner_projection is None
) or (
conditioner is not None and self.conditioner_projection is not None
)
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
y = x + diffusion_step
        if self.conditioner_projection is None:  # using an unconditional model
y = self.dilated_conv(y)
else:
conditioner = self.conditioner_projection(conditioner)
            # for inference, make sure they have the same length:
            # conditioner = conditioner[:, :, : y.shape[-1]]
y = self.dilated_conv(y) + conditioner
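        # WaveNet-style gated activation unit: split channels into gate and filter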
gate, filter = torch.chunk(y, 2, dim=1)
y = torch.sigmoid(gate) * torch.tanh(filter)
y = self.output_projection(y)
residual, skip = torch.chunk(y, 2, dim=1)
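        # scale the residual sum by 1/sqrt(2) to keep the activation variance stable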
return (x + residual) / sqrt(2.0), skip
class DiffWave(nn.Module):
"""
DiffWave Model with dilated residual blocks
Arguments
---------
input_channels: int
input mel channels of conv1x1 for conditional vocoding task
residual_layers: int
number of residual blocks
residual_channels: int
channels of audio convolution
    dilation_cycle_length: int
        length of the dilation cycle; layer i uses dilation 2 ** (i % dilation_cycle_length)
total_steps: int
total steps of diffusion
    unconditional: bool
        True for unconditional generation, False for conditional (default False)
Example
-------
>>> from speechbrain.lobes.models.DiffWave import DiffWave
>>> diffwave = DiffWave(
... input_channels=80,
... residual_layers=30,
... residual_channels=64,
... dilation_cycle_length=10,
... total_steps=50,
... )
>>> noisy_audio = torch.randn(1, 1, 25600)
>>> timestep = torch.randint(50, (1,))
>>> input_mel = torch.rand(1, 80, 100)
>>> predicted_noise = diffwave(noisy_audio, timestep, input_mel)
>>> predicted_noise.shape
torch.Size([1, 1, 25600])
"""
def __init__(
self,
input_channels,
residual_layers,
residual_channels,
dilation_cycle_length,
total_steps,
unconditional=False,
):
super().__init__()
self.input_channels = input_channels
self.residual_layers = residual_layers
self.residual_channels = residual_channels
self.dilation_cycle_length = dilation_cycle_length
self.unconditional = unconditional
self.total_steps = total_steps
self.input_projection = Conv1d(
in_channels=1,
out_channels=self.residual_channels,
kernel_size=1,
skip_transpose=True,
padding="same",
conv_init="kaiming",
)
self.diffusion_embedding = DiffusionEmbedding(self.total_steps)
if self.unconditional: # use unconditional model
self.spectrogram_upsampler = None
else:
self.spectrogram_upsampler = SpectrogramUpsampler()
self.residual_layers = nn.ModuleList(
[
ResidualBlock(
self.input_channels,
self.residual_channels,
2 ** (i % self.dilation_cycle_length),
uncond=self.unconditional,
)
for i in range(self.residual_layers)
]
)
self.skip_projection = Conv1d(
in_channels=self.residual_channels,
out_channels=self.residual_channels,
kernel_size=1,
skip_transpose=True,
padding="same",
conv_init="kaiming",
)
self.output_projection = Conv1d(
in_channels=self.residual_channels,
out_channels=1,
kernel_size=1,
skip_transpose=True,
padding="same",
conv_init="zero",
)
def forward(self, audio, diffusion_step, spectrogram=None, length=None):
"""
DiffWave forward function
Arguments
---------
audio: torch.Tensor
input gaussian sample [bs, 1, time]
diffusion_step: torch.Tensor
which timestep of diffusion to execute [bs, 1]
spectrogram: torch.Tensor
spectrogram data [bs, 80, mel_len]
length: torch.Tensor
sample lengths - not used - provided for compatibility only
Returns
-------
predicted noise [bs, 1, time]
"""
assert (spectrogram is None and self.spectrogram_upsampler is None) or (
spectrogram is not None and self.spectrogram_upsampler is not None
)
x = self.input_projection(audio)
x = F.relu(x)
diffusion_step = self.diffusion_embedding(diffusion_step)
if self.spectrogram_upsampler: # use conditional model
spectrogram = self.spectrogram_upsampler(spectrogram)
skip = None
for layer in self.residual_layers:
x, skip_connection = layer(x, diffusion_step, spectrogram)
skip = skip_connection if skip is None else skip_connection + skip
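        # normalize the accumulated skip connections by sqrt(num_layers)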
x = skip / sqrt(len(self.residual_layers))
x = self.skip_projection(x)
x = F.relu(x)
x = self.output_projection(x)
return x
def diffusion_forward(
self,
x,
timesteps,
cond_emb=None,
length=None,
out_mask_value=None, # unused for diffwave
latent_mask_value=None, # unused for diffwave
):
"""Forward function suitable for wrapping by diffusion.
For this model, `out_mask_value`/`latent_mask_value` are unused
and discarded.
See :meth:`~DiffWave.forward` for details."""
return self(x, timesteps, spectrogram=cond_emb, length=length)
class DiffWaveDiffusion(DenoisingDiffusion):
"""An enhanced diffusion implementation with DiffWave-specific inference
Arguments
---------
model: nn.Module
the underlying model
timesteps: int
the total number of timesteps
noise: str|nn.Module
the type of noise being used
"gaussian" will produce standard Gaussian noise
beta_start: float
the value of the "beta" parameter at the beginning of the process
(see DiffWave paper)
beta_end: float
the value of the "beta" parameter at the end of the process
    sample_min: float
        lower bound used to clip the output
    sample_max: float
        upper bound used to clip the output
show_progress: bool
whether to show progress during inference
Example
-------
>>> from speechbrain.lobes.models.DiffWave import DiffWave
>>> diffwave = DiffWave(
... input_channels=80,
... residual_layers=30,
... residual_channels=64,
... dilation_cycle_length=10,
... total_steps=50,
... )
>>> from speechbrain.lobes.models.DiffWave import DiffWaveDiffusion
>>> from speechbrain.nnet.diffusion import GaussianNoise
>>> diffusion = DiffWaveDiffusion(
... model=diffwave,
... beta_start=0.0001,
... beta_end=0.05,
... timesteps=50,
... noise=GaussianNoise,
... )
>>> input_mel = torch.rand(1, 80, 100)
>>> output = diffusion.inference(
... unconditional=False,
... scale=256,
... condition=input_mel,
... fast_sampling=True,
... fast_sampling_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],
... )
>>> output.shape
torch.Size([1, 25600])
"""
def __init__(
self,
model,
timesteps=None,
noise=None,
beta_start=None,
beta_end=None,
sample_min=None,
sample_max=None,
show_progress=False,
):
super().__init__(
model,
timesteps,
noise,
beta_start,
beta_end,
sample_min,
sample_max,
show_progress,
)
@torch.no_grad()
def inference(
self,
unconditional,
scale,
condition=None,
fast_sampling=False,
fast_sampling_noise_schedule=None,
device=None,
):
"""Processes the inference for diffwave
One inference function for all the locally/globally conditional
generation and unconditional generation tasks
Arguments
---------
unconditional: bool
do unconditional generation if True, else do conditional generation
        scale: int
            scale determining the output waveform length
            for conditional generation, the output length is scale * condition.shape[-1];
            e.g. if the condition is a spectrogram [bs, n_mel, time], scale should be the hop length
            for unconditional generation, scale should be the desired audio length
condition: torch.Tensor
input spectrogram for vocoding or other conditions for other
conditional generation, should be None for unconditional generation
fast_sampling: bool
whether to do fast sampling
fast_sampling_noise_schedule: list
the noise schedules used for fast sampling
device: str|torch.device
inference device
Returns
-------
        predicted_sample: torch.Tensor
            the predicted audio [bs, t]
"""
if device is None:
device = torch.device("cuda")
        # either conditional or unconditional generation
if unconditional:
assert condition is None
else:
assert condition is not None
device = condition.device
# must define fast_sampling_noise_schedule during fast sampling
if fast_sampling:
assert fast_sampling_noise_schedule is not None
if fast_sampling and fast_sampling_noise_schedule is not None:
inference_noise_schedule = fast_sampling_noise_schedule
inference_alphas = 1 - torch.tensor(inference_noise_schedule)
inference_alpha_cum = inference_alphas.cumprod(dim=0)
else:
inference_noise_schedule = self.betas
inference_alphas = self.alphas
inference_alpha_cum = self.alphas_cumprod
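        # Align each inference noise level with a fractional timestep
        # (t + twiddle) of the training schedule by matching cumulative
        # alpha values, as in the DiffWave fast sampling procedure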
inference_steps = []
for s in range(len(inference_noise_schedule)):
for t in range(self.timesteps - 1):
if (
self.alphas_cumprod[t + 1]
<= inference_alpha_cum[s]
<= self.alphas_cumprod[t]
):
twiddle = (
self.alphas_cumprod[t] ** 0.5
- inference_alpha_cum[s] ** 0.5
) / (
self.alphas_cumprod[t] ** 0.5
- self.alphas_cumprod[t + 1] ** 0.5
)
inference_steps.append(t + twiddle)
break
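        # start from standard Gaussian noise at the target output length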
if not unconditional:
if (
len(condition.shape) == 2
): # Expand rank 2 tensors by adding a batch dimension.
condition = condition.unsqueeze(0)
audio = torch.randn(
condition.shape[0], scale * condition.shape[-1], device=device
)
else:
audio = torch.randn(1, scale, device=device)
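        # Ancestral sampling (DDPM reverse process): each step computes
        #   x_{n-1} = (x_n - c2 * predicted_noise) / sqrt(alpha_n) + sigma_n * z
        # with c2 = beta_n / sqrt(1 - alpha_cum_n) and fresh Gaussian noise z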
for n in range(len(inference_alphas) - 1, -1, -1):
c1 = 1 / inference_alphas[n] ** 0.5
c2 = (
inference_noise_schedule[n]
/ (1 - inference_alpha_cum[n]) ** 0.5
)
# predict noise
noise_pred = self.model(
audio,
torch.tensor([inference_steps[n]], device=device),
condition,
).squeeze(1)
# mean
audio = c1 * (audio - c2 * noise_pred)
# add variance
if n > 0:
noise = torch.randn_like(audio)
sigma = (
(1.0 - inference_alpha_cum[n - 1])
/ (1.0 - inference_alpha_cum[n])
* inference_noise_schedule[n]
) ** 0.5
audio += sigma * noise
audio = torch.clamp(audio, -1.0, 1.0)
return audio