""" Implementation of a popular speech separation model.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import speechbrain as sb
from speechbrain.processing.signal_processing import overlap_and_add
EPS = 1e-8
class Encoder(nn.Module):
"""This class learns the adaptive frontend for the ConvTasnet model.
Arguments
---------
L : int
The filter kernel size. Needs to be an odd number.
N : int
Number of dimensions at the output of the adaptive front end.
Example
-------
>>> inp = torch.rand(10, 100)
>>> encoder = Encoder(11, 20)
>>> h = encoder(inp)
>>> h.shape
torch.Size([10, 20, 20])
"""
def __init__(self, L, N):
super().__init__()
# 50% overlap
self.conv1d_U = sb.nnet.CNN.Conv1d(
in_channels=1,
out_channels=N,
kernel_size=L,
stride=L // 2,
bias=False,
)
def forward(self, mixture):
"""
Arguments
---------
mixture : torch.Tensor
            Tensor shape is [M, T], where M is the batch size and T is the number of samples.
Returns
-------
mixture_w : torch.Tensor
            Tensor shape is [M, K, N], where K is the number of encoded frames:
            with the default "same" padding and 50% overlap, K = ceil(T / (L // 2)),
            i.e. roughly 2T / L.
"""
mixture = torch.unsqueeze(mixture, -1) # [M, T, 1]
conv_out = self.conv1d_U(mixture)
mixture_w = F.relu(conv_out) # [M, K, N]
return mixture_w
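# Worked example for the Encoder doctest above (assuming the default "same"
# padding of sb.nnet.CNN.Conv1d): T = 100 and L = 11 give a hop of L // 2 = 5,
# hence K = ceil(100 / 5) = 20 frames, each projected to N = 20 channels.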
class Decoder(nn.Module):
"""This class implements the decoder for the ConvTasnet.
The separated source embeddings are fed to the decoder to reconstruct
the estimated sources in the time domain.
Arguments
---------
    L : int
        The length of each decoder basis signal (i.e. the filter kernel size).
    N : int
        Input size (the number of basis signals).
Example
-------
>>> L, C, N = 8, 2, 8
>>> mixture_w = torch.randn(10, 100, N)
>>> est_mask = torch.randn(10, 100, C, N)
>>> Decoder = Decoder(L, N)
>>> mixture_hat = Decoder(mixture_w, est_mask)
>>> mixture_hat.shape
torch.Size([10, 404, 2])
"""
def __init__(self, L, N):
super().__init__()
# Hyper-parameter
self.L = L
# Components
self.basis_signals = sb.nnet.linear.Linear(
input_size=N, n_neurons=L, bias=False
)
def forward(self, mixture_w, est_mask):
"""
Arguments
---------
mixture_w : torch.Tensor
Tensor shape is [M, K, N].
est_mask : torch.Tensor
Tensor shape is [M, K, C, N].
Returns
-------
est_source : torch.Tensor
Tensor shape is [M, T, C].
"""
# D = W * M
source_w = (
torch.unsqueeze(mixture_w, 2).repeat(1, 1, est_mask.size(2), 1)
* est_mask
) # [M, K, C, N]
source_w = source_w.permute(0, 2, 1, 3) # [M, C, K, N]
# S = DV
est_source = self.basis_signals(source_w) # [M, C, K, L]
est_source = overlap_and_add(est_source, self.L // 2) # M x C x T
return est_source.permute(0, 2, 1) # M x T x C
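# Worked example for the Decoder doctest above: with K = 100 frames, L = 8 and a
# hop of L // 2 = 4, overlap_and_add reconstructs T = (K - 1) * 4 + L = 404
# samples per source, which is the [10, 404, 2] output shape shown there.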
class TemporalBlocksSequential(sb.nnet.containers.Sequential):
"""
    A wrapper that replicates the TemporalBlock layer into a stack of R * X blocks.
Arguments
---------
input_shape : tuple
Expected shape of the input.
H : int
The number of intermediate channels.
P : int
The kernel size in the convolutions.
R : int
The number of times to replicate the multilayer Temporal Blocks.
X : int
The number of layers of Temporal Blocks with different dilations.
norm_type : str
The type of normalization, in ['gLN', 'cLN'].
causal : bool
To use causal or non-causal convolutions, in [True, False].
Example
-------
>>> x = torch.randn(14, 100, 10)
>>> H, P, R, X = 10, 5, 2, 3
>>> TemporalBlocks = TemporalBlocksSequential(
... x.shape, H, P, R, X, 'gLN', False
... )
>>> y = TemporalBlocks(x)
>>> y.shape
torch.Size([14, 100, 10])
"""
def __init__(self, input_shape, H, P, R, X, norm_type, causal):
super().__init__(input_shape=input_shape)
for r in range(R):
for x in range(X):
dilation = 2**x
self.append(
TemporalBlock,
out_channels=H,
kernel_size=P,
stride=1,
padding="same",
dilation=dilation,
norm_type=norm_type,
causal=causal,
layer_name=f"temporalblock_{r}_{x}",
)
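# The inner loop above stacks X blocks with dilations 1, 2, 4, ..., 2**(X - 1),
# and the outer loop repeats that stack R times; the dilated convolutions alone
# therefore give each frame a receptive field of roughly
# R * (P - 1) * (2**X - 1) + 1 frames (a back-of-the-envelope estimate, not a
# quantity the code computes).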
class MaskNet(nn.Module):
"""
Arguments
---------
N : int
Number of filters in autoencoder.
B : int
Number of channels in bottleneck 1 × 1-conv block.
H : int
Number of channels in convolutional blocks.
P : int
Kernel size in convolutional blocks.
X : int
Number of convolutional blocks in each repeat.
R : int
Number of repeats.
C : int
Number of speakers.
norm_type : str
One of BN, gLN, cLN.
causal : bool
Causal or non-causal.
mask_nonlinear : str
        The non-linear function used to generate the mask, one of ['softmax', 'relu'].
Example
-------
>>> N, B, H, P, X, R, C = 11, 12, 2, 5, 3, 1, 2
>>> MaskNet = MaskNet(N, B, H, P, X, R, C)
>>> mixture_w = torch.randn(10, 11, 100)
>>> est_mask = MaskNet(mixture_w)
>>> est_mask.shape
torch.Size([2, 10, 11, 100])
"""
def __init__(
self,
N,
B,
H,
P,
X,
R,
C,
norm_type="gLN",
causal=False,
mask_nonlinear="relu",
):
super().__init__()
# Hyper-parameter
self.C = C
self.mask_nonlinear = mask_nonlinear
# Components
# [M, K, N] -> [M, K, N]
self.layer_norm = ChannelwiseLayerNorm(N)
# [M, K, N] -> [M, K, B]
self.bottleneck_conv1x1 = sb.nnet.CNN.Conv1d(
in_channels=N,
out_channels=B,
kernel_size=1,
bias=False,
)
# [M, K, B] -> [M, K, B]
in_shape = (None, None, B)
self.temporal_conv_net = TemporalBlocksSequential(
in_shape, H, P, R, X, norm_type, causal
)
# [M, K, B] -> [M, K, C*N]
self.mask_conv1x1 = sb.nnet.CNN.Conv1d(
in_channels=B, out_channels=C * N, kernel_size=1, bias=False
)
def forward(self, mixture_w):
"""Keep this API same with TasNet.
Arguments
---------
        mixture_w : torch.Tensor
            Tensor shape is [M, N, K], where M is the batch size.
Returns
-------
        est_mask : torch.Tensor
            Tensor shape is [C, M, N, K], where C is the number of speakers.
"""
mixture_w = mixture_w.permute(0, 2, 1)
M, K, N = mixture_w.size()
y = self.layer_norm(mixture_w)
y = self.bottleneck_conv1x1(y)
y = self.temporal_conv_net(y)
score = self.mask_conv1x1(y)
# score = self.network(mixture_w) # [M, K, N] -> [M, K, C*N]
score = score.contiguous().reshape(
M, K, self.C, N
) # [M, K, C*N] -> [M, K, C, N]
# [M, K, C, N] -> [C, M, N, K]
score = score.permute(2, 0, 3, 1)
if self.mask_nonlinear == "softmax":
est_mask = F.softmax(score, dim=2)
elif self.mask_nonlinear == "relu":
est_mask = F.relu(score)
else:
raise ValueError("Unsupported mask non-linear function")
return est_mask
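# Note: MaskNet returns masks shaped [C, M, N, K] (see the doctest above), while
# Decoder.forward expects [M, K, C, N]. One straightforward way to bridge the two
# (an illustrative pairing, not necessarily how every recipe wires it) is:
#     est_mask = masknet(mixture_w.permute(0, 2, 1)).permute(1, 3, 0, 2)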
class TemporalBlock(torch.nn.Module):
"""The conv1d compound layers used in Masknet.
Arguments
---------
input_shape : tuple
The expected shape of the input.
out_channels : int
The number of intermediate channels.
kernel_size : int
The kernel size in the convolutions.
stride : int
Convolution stride in convolutional layers.
padding : str
The type of padding in the convolutional layers,
(same, valid, causal). If "valid", no padding is performed.
dilation : int
Amount of dilation in convolutional layers.
norm_type : str
The type of normalization, in ['gLN', 'cLN'].
causal : bool
To use causal or non-causal convolutions, in [True, False].
Example
-------
>>> x = torch.randn(14, 100, 10)
>>> TemporalBlock = TemporalBlock(x.shape, 10, 11, 1, 'same', 1)
>>> y = TemporalBlock(x)
>>> y.shape
torch.Size([14, 100, 10])
"""
def __init__(
self,
input_shape,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False,
):
super().__init__()
M, K, B = input_shape
self.layers = sb.nnet.containers.Sequential(input_shape=input_shape)
# [M, K, B] -> [M, K, H]
self.layers.append(
sb.nnet.CNN.Conv1d,
out_channels=out_channels,
kernel_size=1,
bias=False,
layer_name="conv",
)
self.layers.append(nn.PReLU(), layer_name="act")
self.layers.append(
choose_norm(norm_type, out_channels), layer_name="norm"
)
# [M, K, H] -> [M, K, B]
self.layers.append(
DepthwiseSeparableConv,
out_channels=B,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
norm_type=norm_type,
causal=causal,
layer_name="DSconv",
)
def forward(self, x):
"""
Arguments
---------
x : torch.Tensor
Tensor shape is [M, K, B].
Returns
-------
x : torch.Tensor
Tensor shape is [M, K, B].
"""
residual = x
x = self.layers(x)
return x + residual
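# The residual connection above requires the block's input and output channel
# counts to match; that is why __init__ reads B from input_shape and ties the
# DepthwiseSeparableConv out_channels back to B.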
class DepthwiseSeparableConv(sb.nnet.containers.Sequential):
"""Building block for the Temporal Blocks of Masknet in ConvTasNet.
Arguments
---------
input_shape : tuple
Expected shape of the input.
out_channels : int
Number of output channels.
kernel_size : int
The kernel size in the convolutions.
stride : int
Convolution stride in convolutional layers.
padding : str
The type of padding in the convolutional layers,
(same, valid, causal). If "valid", no padding is performed.
dilation : int
Amount of dilation in convolutional layers.
norm_type : str
The type of normalization, in ['gLN', 'cLN'].
causal : bool
To use causal or non-causal convolutions, in [True, False].
Example
-------
>>> x = torch.randn(14, 100, 10)
>>> DSconv = DepthwiseSeparableConv(x.shape, 10, 11, 1, 'same', 1)
>>> y = DSconv(x)
>>> y.shape
torch.Size([14, 100, 10])
"""
def __init__(
self,
input_shape,
out_channels,
kernel_size,
stride,
padding,
dilation,
norm_type="gLN",
causal=False,
):
super().__init__(input_shape=input_shape)
batchsize, time, in_channels = input_shape
# [M, K, H] -> [M, K, H]
if causal:
paddingval = dilation * (kernel_size - 1)
padding = "causal"
default_padding = "same"
else:
default_padding = 0
self.append(
sb.nnet.CNN.Conv1d,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False,
layer_name="conv_0",
default_padding=default_padding,
)
if causal:
self.append(Chomp1d(paddingval), layer_name="chomp")
self.append(nn.PReLU(), layer_name="act")
        self.append(choose_norm(norm_type, in_channels), layer_name="norm")
# [M, K, H] -> [M, K, B]
self.append(
sb.nnet.CNN.Conv1d,
out_channels=out_channels,
kernel_size=1,
bias=False,
layer_name="conv_1",
)
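# The depthwise stage (groups=in_channels) applies one filter per channel and the
# 1x1 pointwise stage mixes channels, costing roughly
# in_channels * kernel_size + in_channels * out_channels weights instead of the
# in_channels * out_channels * kernel_size of a standard convolution.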
class Chomp1d(nn.Module):
"""This class cuts out a portion of the signal from the end.
It is written as a class to be able to incorporate it inside a sequential
wrapper.
Arguments
---------
chomp_size : int
The size of the portion to discard (in samples).
Example
-------
>>> x = torch.randn(10, 110, 5)
>>> chomp = Chomp1d(10)
>>> x_chomped = chomp(x)
>>> x_chomped.shape
torch.Size([10, 100, 5])
"""
def __init__(self, chomp_size):
super().__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
Arguments
---------
x : torch.Tensor
Tensor shape is [M, Kpad, H].
Returns
-------
x : torch.Tensor
Tensor shape is [M, K, H].
"""
return x[:, : -self.chomp_size, :].contiguous()
def choose_norm(norm_type, channel_size):
"""This function returns the chosen normalization type.
Arguments
---------
norm_type : str
One of ['gLN', 'cLN', 'batchnorm'].
channel_size : int
Number of channels.
Returns
-------
Constructed layer of the chosen type
Example
-------
>>> choose_norm('gLN', 10)
GlobalLayerNorm()
"""
if norm_type == "gLN":
return GlobalLayerNorm(channel_size)
elif norm_type == "cLN":
return ChannelwiseLayerNorm(channel_size)
else:
return nn.BatchNorm1d(channel_size)
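# Caveat (not exercised by this file's doctests): the fallback branch returns
# nn.BatchNorm1d, which normalizes the channel-first layout [M, N, K], whereas
# GlobalLayerNorm and ChannelwiseLayerNorm are written for the channel-last
# [M, K, N] tensors used throughout this module.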
class ChannelwiseLayerNorm(nn.Module):
"""Channel-wise Layer Normalization (cLN).
Arguments
---------
channel_size : int
Number of channels in the normalization dimension (the third dimension).
Example
-------
>>> x = torch.randn(2, 3, 3)
>>> norm_func = ChannelwiseLayerNorm(3)
>>> x_normalized = norm_func(x)
    >>> x_normalized.shape
torch.Size([2, 3, 3])
"""
def __init__(self, channel_size):
super().__init__()
self.gamma = nn.Parameter(torch.Tensor(1, 1, channel_size)) # [1, 1, N]
self.beta = nn.Parameter(torch.Tensor(1, 1, channel_size)) # [1, 1, N]
self.reset_parameters()
def reset_parameters(self):
"""Resets the parameters."""
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
        Arguments
        ---------
        y : torch.Tensor
            Tensor shape [M, K, N]. M is batch size, N is channel size, K is length.
        Returns
        -------
        cLN_y : torch.Tensor
            Tensor shape [M, K, N].
"""
mean = torch.mean(y, dim=2, keepdim=True) # [M, K, 1]
var = torch.var(y, dim=2, keepdim=True, unbiased=False) # [M, K, 1]
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return cLN_y
class GlobalLayerNorm(nn.Module):
"""Global Layer Normalization (gLN).
Arguments
---------
channel_size : int
Number of channels in the third dimension.
Example
-------
>>> x = torch.randn(2, 3, 3)
>>> norm_func = GlobalLayerNorm(3)
>>> x_normalized = norm_func(x)
    >>> x_normalized.shape
torch.Size([2, 3, 3])
"""
def __init__(self, channel_size):
super().__init__()
self.gamma = nn.Parameter(torch.Tensor(1, 1, channel_size)) # [1, 1, N]
self.beta = nn.Parameter(torch.Tensor(1, 1, channel_size)) # [1, 1, N]
self.reset_parameters()
def reset_parameters(self):
"""Resets the parameters."""
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, y):
"""
Arguments
---------
y : torch.Tensor
Tensor shape [M, K, N]. M is batch size, N is channel size, and K is length.
Returns
-------
gLN_y : torch.Tensor
            Tensor shape [M, K, N].
"""
mean = y.mean(dim=1, keepdim=True).mean(
dim=2, keepdim=True
) # [M, 1, 1]
var = (
(torch.pow(y - mean, 2))
.mean(dim=1, keepdim=True)
.mean(dim=2, keepdim=True)
)
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
return gLN_y
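# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the original module): it wires
# Encoder -> MaskNet -> Decoder together on random data purely to illustrate
# the tensor shapes. The permutes that adapt the Encoder output to the MaskNet
# input and the MaskNet output ([C, M, N, K]) to the Decoder input
# ([M, K, C, N]) are illustrative assumptions, not the reference recipe.
if __name__ == "__main__":
    L, N, B, H, P, X, R, C = 16, 64, 32, 64, 3, 4, 2, 2
    encoder = Encoder(L, N)
    masknet = MaskNet(N, B, H, P, X, R, C)
    decoder = Decoder(L, N)

    mixture = torch.rand(4, 8000)  # [M, T]: e.g. 4 half-second mixtures at 16 kHz
    mixture_w = encoder(mixture)  # [M, K, N]
    est_mask = masknet(mixture_w.permute(0, 2, 1))  # MaskNet expects [M, N, K]
    est_mask = est_mask.permute(1, 3, 0, 2)  # [C, M, N, K] -> [M, K, C, N]
    est_source = decoder(mixture_w, est_mask)  # [M, T', C]
    print(mixture_w.shape, est_mask.shape, est_source.shape)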