"""Wide ResNet for Speech Enhancement.
Author
* Peter Plantinga 2022
"""
import torch
import speechbrain as sb
from speechbrain.processing.features import ISTFT, STFT, spectral_magnitude
class EnhanceResnet(torch.nn.Module):
"""Model for enhancement based on Wide ResNet.
Full model description at: https://arxiv.org/pdf/2112.06068.pdf
Arguments
---------
n_fft : int
Number of points in the fourier transform, see ``speechbrain.processing.features.STFT``
win_length : int
Length of stft window in ms, see ``speechbrain.processing.features.STFT``
hop_length : int
Time between windows in ms, see ``speechbrain.processing.features.STFT``
sample_rate : int
Number of samples per second of input audio.
channel_counts : list of ints
Number of output channels in each CNN block. Determines number of blocks.
dense_count : int
Number of dense layers.
dense_nodes : int
Number of nodes in the dense layers.
activation : function
Function to apply before convolution layers.
normalization : class
Name of class to use for constructing norm layers.
dropout : float
Portion of layer outputs to drop during training (between 0 and 1).
mask_weight : float
Amount of weight to give mask. 0 - no masking, 1 - full masking.
Example
-------
>>> inputs = torch.rand([10, 16000])
>>> model = EnhanceResnet()
>>> outputs, feats = model(inputs)
>>> outputs.shape
torch.Size([10, 15872])
>>> feats.shape
torch.Size([10, 63, 257])
"""
def __init__(
self,
n_fft=512,
win_length=32,
hop_length=16,
sample_rate=16000,
channel_counts=[128, 128, 256, 256, 512, 512],
dense_count=2,
dense_nodes=1024,
activation=torch.nn.GELU(),
normalization=sb.nnet.normalization.BatchNorm2d,
dropout=0.1,
mask_weight=0.99,
):
super().__init__()
self.mask_weight = mask_weight
# First, convert time-domain to log spectral magnitude inputs
self.stft = STFT(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
sample_rate=sample_rate,
)
# CNN takes log spectral mag inputs
self.CNN = sb.nnet.containers.Sequential(
input_shape=[None, None, n_fft // 2 + 1]
)
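        # Inputs are [batch, time, features]; each ConvBlock halves the
        # feature dimension with a stride-2 convolution (see ConvBlock)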
for channel_count in channel_counts:
self.CNN.append(
ConvBlock,
channels=channel_count,
activation=activation,
normalization=normalization,
dropout=dropout,
)
# Fully connected layers
self.DNN = sb.nnet.containers.Sequential(
input_shape=self.CNN.get_output_shape()
)
for _ in range(dense_count):
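            # combine_dims=True flattens 4D inputs (features x channels)
            # into a single dimension for the linear layer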
self.DNN.append(
sb.nnet.linear.Linear,
n_neurons=dense_nodes,
combine_dims=True,
)
self.DNN.append(activation)
self.DNN.append(sb.nnet.normalization.LayerNorm)
self.DNN.append(torch.nn.Dropout(p=dropout))
# Output layer produces real mask that is applied to complex inputs
self.DNN.append(sb.nnet.linear.Linear, n_neurons=n_fft // 2 + 1)
# Convert back to time domain
self.istft = ISTFT(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
sample_rate=sample_rate,
)
def forward(self, x):
"""Processes the input tensor and outputs the enhanced speech."""
# Generate features
noisy_spec = self.stft(x)
log_mag = self.extract_feats(noisy_spec)
# Generate mask
mask = self.DNN(self.CNN(log_mag))
mask = mask.clamp(min=0, max=1).unsqueeze(-1)
# Apply mask
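        # The blend below is a convex combination: with weight w = mask_weight,
        # output = w * mask * spec + (1 - w) * spec, so part of the noisy
        # signal passes through unchanged whenever w < 1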
masked_spec = self.mask_weight * mask * noisy_spec
masked_spec += (1 - self.mask_weight) * noisy_spec
# Extract feats for loss computation
enhanced_features = self.extract_feats(masked_spec)
# Return resynthesized waveform
return self.istft(masked_spec), enhanced_features

    def extract_feats(self, x):
        """Takes the stft output and produces log spectral magnitude features."""
        return torch.log1p(spectral_magnitude(x, power=0.5))
class ConvBlock(torch.nn.Module):
"""Convolution block, including squeeze-and-excitation.
Arguments
---------
input_shape : tuple of ints
The expected size of the inputs.
channels : int
Number of output channels.
activation : function
Function applied before each block.
normalization : class
Name of a class to use for constructing norm layers.
dropout : float
Portion of block outputs to drop during training.
Example
-------
>>> inputs = torch.rand([10, 20, 30, 128])
>>> block = ConvBlock(input_shape=inputs.shape, channels=256)
>>> outputs = block(inputs)
>>> outputs.shape
torch.Size([10, 20, 15, 256])
"""
def __init__(
self,
input_shape,
channels,
activation=torch.nn.GELU(),
normalization=sb.nnet.normalization.LayerNorm,
dropout=0.1,
):
super().__init__()
self.activation = activation
self.downsample = sb.nnet.CNN.Conv2d(
input_shape=input_shape,
out_channels=channels,
kernel_size=3,
stride=(2, 1),
)
self.conv1 = sb.nnet.CNN.Conv2d(
in_channels=channels, out_channels=channels, kernel_size=3
)
self.norm1 = normalization(input_size=channels)
self.conv2 = sb.nnet.CNN.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
)
self.norm2 = normalization(input_size=channels)
self.dropout = sb.nnet.dropout.Dropout2d(drop_rate=dropout)
self.se_block = SEblock(input_size=channels)
def forward(self, x):
"""Processes the input tensor with a convolutional block."""
x = self.downsample(x)
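        # Pre-activation residual branch: activation, normalization, and
        # dropout are applied before each convolution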
residual = self.activation(x)
residual = self.norm1(residual)
residual = self.dropout(residual)
residual = self.conv1(residual)
residual = self.activation(residual)
residual = self.norm2(residual)
residual = self.dropout(residual)
residual = self.conv2(residual)
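        # Gate the residual channels with squeeze-and-excitation weights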
residual *= self.se_block(residual)
return x + residual
class SEblock(torch.nn.Module):
"""Squeeze-and-excitation block.
Defined: https://arxiv.org/abs/1709.01507
Arguments
---------
input_size : tuple of ints
Expected size of the input tensor
Example
-------
>>> inputs = torch.rand([10, 20, 30, 256])
>>> se_block = SEblock(input_size=inputs.shape[-1])
>>> outputs = se_block(inputs)
>>> outputs.shape
torch.Size([10, 1, 1, 256])
"""
def __init__(self, input_size):
super().__init__()
self.linear1 = sb.nnet.linear.Linear(
input_size=input_size, n_neurons=input_size
)
self.linear2 = sb.nnet.linear.Linear(
input_size=input_size, n_neurons=input_size
)
def forward(self, x):
"""Processes the input tensor with a squeeze-and-excite block."""
        # torch.mean here causes an in-place autograd error, so the
        # average is computed manually instead:
        # x = torch.mean(x, dim=(1, 2), keepdim=True)
count = x.size(1) * x.size(2)
x = torch.sum(x, dim=(1, 2), keepdim=True) / count
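        # Excitation: two linear layers produce per-channel gates in (0, 1)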
x = self.linear1(x)
x = torch.nn.functional.relu(x)
x = self.linear2(x)
return torch.sigmoid(x)
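
if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): random tensors
    # stand in for a real noisy/clean pair, and an L1 distance on the
    # returned log-magnitude features is just one plausible training
    # objective for the mask, not necessarily the one used in the paper.
    noisy = torch.rand([4, 16000])
    clean = torch.rand([4, 16000])
    model = EnhanceResnet()
    enhanced, enhanced_feats = model(noisy)
    # Clean reference features computed with the model's own STFT front-end
    clean_feats = model.extract_feats(model.stft(clean))
    loss = torch.nn.functional.l1_loss(enhanced_feats, clean_feats)
    print(enhanced.shape, loss.item())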