"""Library for the Resource-Efficient Sepformer.
Authors
* Cem Subakan 2022
"""
import copy
import torch
import torch.nn as nn
import speechbrain.nnet.RNN as SBRNN
from speechbrain.lobes.models.dual_path import select_norm
from speechbrain.lobes.models.transformer.Transformer import (
PositionalEncoding,
TransformerEncoder,
get_lookahead_mask,
)
EPS = torch.finfo(torch.get_default_dtype()).eps
class MemLSTM(nn.Module):
"""the Mem-LSTM of SkiM --
Note: This is taken from the SkiM implementation in ESPNet toolkit and modified for compatibility with SpeechBrain.
Arguments
---------
hidden_size: int
Dimension of the hidden state.
dropout: float
dropout ratio. Default is 0.
bidirectional: bool
Whether the LSTM layers are bidirectional.
Default is False.
mem_type: str
'hc', 'h', 'c', or 'id'
This controls whether the hidden (or cell) state of
SegLSTM will be processed by MemLSTM.
In 'id' mode, both the hidden and cell states will
be identically returned.
norm_type: str
One of 'gln' or 'cln'.
This selects the type of normalization;
'cln' is for the causal implementation.
Example
-------
>>> x = (torch.randn(1, 5, 64), torch.randn(1, 5, 64))
>>> block = MemLSTM(64)
>>> x = block(x, 5)
>>> x[0].shape
torch.Size([1, 5, 64])
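
A minimal additional sketch, reusing the toy states above: in 'id'
mode the states are passed through unchanged, apart from the causal
shift applied in the unidirectional setting.

>>> block_id = MemLSTM(64, mem_type="id")
>>> y = block_id(x, 5)
>>> y[0].shape
torch.Size([1, 5, 64])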
"""
def __init__(
self,
hidden_size,
dropout=0.0,
bidirectional=False,
mem_type="hc",
norm_type="cln",
):
super().__init__()
self.hidden_size = hidden_size
self.bidirectional = bidirectional
self.input_size = (int(bidirectional) + 1) * hidden_size
self.mem_type = mem_type
assert mem_type in [
"hc",
"h",
"c",
"id",
], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}"
if mem_type in ["hc", "h"]:
self.h_net = SBRNNBlock(
input_size=self.input_size,
hidden_channels=self.hidden_size,
num_layers=1,
outsize=self.input_size,
rnn_type="LSTM",
dropout=dropout,
bidirectional=bidirectional,
)
self.h_norm = select_norm(
norm=norm_type, dim=self.input_size, shape=3, eps=EPS
)
if mem_type in ["hc", "c"]:
self.c_net = SBRNNBlock(
input_size=self.input_size,
hidden_channels=self.hidden_size,
num_layers=1,
outsize=self.input_size,
rnn_type="LSTM",
dropout=dropout,
bidirectional=bidirectional,
)
self.c_norm = select_norm(
norm=norm_type, dim=self.input_size, shape=3, eps=EPS
)
def forward(self, hc, S):
"""The forward function for the memory RNN
Arguments
---------
hc : tuple
(h, c), tuple of hidden and cell states from SegLSTM
shape of h and c: (d, B*S, H)
where d is the number of directions
B is the batchsize
S is the number of chunks
H is the latent dimensionality
S : int
S is the number of chunks
Returns
-------
ret_val : tuple
The processed (h, c) states returned by the memory RNN.
"""
if self.mem_type == "id":
ret_val = hc
else:
h, c = hc
d, BS, H = h.shape
B = BS // S
h = h.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH
c = c.transpose(1, 0).contiguous().view(B, S, d * H) # B, S, dH
if self.mem_type == "hc":
h = h + self.h_norm(self.h_net(h).permute(0, 2, 1)).permute(
0, 2, 1
)
c = c + self.c_norm(self.c_net(c).permute(0, 2, 1)).permute(
0, 2, 1
)
elif self.mem_type == "h":
h = h + self.h_norm(self.h_net(h).permute(0, 2, 1)).permute(
0, 2, 1
)
c = torch.zeros_like(c)
elif self.mem_type == "c":
h = torch.zeros_like(h)
c = c + self.c_norm(self.c_net(c).permute(0, 2, 1)).permute(
0, 2, 1
)
h = h.view(B * S, d, H).transpose(1, 0).contiguous()
c = c.view(B * S, d, H).transpose(1, 0).contiguous()
ret_val = (h, c)
if not self.bidirectional:
# for causal setup
causal_ret_val = []
for x in ret_val:
x_ = torch.zeros_like(x)
x_[:, 1:, :] = x[:, :-1, :]
causal_ret_val.append(x_)
ret_val = tuple(causal_ret_val)
return ret_val
class SegLSTM(nn.Module):
"""the Segment-LSTM of SkiM
Note: This is taken from the SkiM implementation in ESPNet toolkit and modified for compatibility with SpeechBrain.
Arguments
---------
input_size: int
Dimension of the input feature.
The input should have shape (batch, seq_len, input_size).
hidden_size: int
Dimension of the hidden state.
dropout: float
Dropout ratio. Default is 0.
bidirectional: bool
Whether the LSTM layers are bidirectional.
Default is False.
norm_type: str
One of 'gln' or 'cln'.
This selects the type of normalization;
'cln' is for the causal implementation.
Example
-------
>>> x = torch.randn(3, 20, 64)
>>> hc = None
>>> seglstm = SegLSTM(64, 64)
>>> y = seglstm(x, hc)
>>> y[0].shape
torch.Size([3, 20, 64])
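
A minimal follow-up sketch with the same toy tensors: the returned
states can be fed back in to process the next chunk.

>>> out, hc = seglstm(x, hc)
>>> out2, hc2 = seglstm(torch.randn(3, 20, 64), hc)
>>> out2.shape
torch.Size([3, 20, 64])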
"""
def __init__(
self,
input_size,
hidden_size,
dropout=0.0,
bidirectional=False,
norm_type="cLN",
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_direction = int(bidirectional) + 1
self.lstm = nn.LSTM(
input_size,
hidden_size,
1,
batch_first=True,
bidirectional=bidirectional,
)
self.dropout = nn.Dropout(p=dropout)
self.proj = nn.Linear(hidden_size * self.num_direction, input_size)
self.norm = select_norm(
norm=norm_type, dim=input_size, shape=3, eps=EPS
)
def forward(self, input, hc):
"""The forward function of the Segment LSTM
Arguments
---------
input : torch.Tensor
shape [B*S, T, H]
where B is the batchsize
S is the number of chunks
T is the chunk size
H is the latent dimensionality
hc : tuple
tuple of hidden and cell states from SegLSTM
shape of h and c: (d, B*S, H)
where d is the number of directions
B is the batchsize
S is the number of chunks
H is the latent dimensionality
Returns
-------
output : torch.Tensor
Output of the Segment LSTM, with the same shape as the input.
(h, c) : tuple
The updated hidden and cell states, with the same shapes as the hc input.
"""
B, T, H = input.shape
if hc is None:
# For the first SkiM block, h and c are not yet available
d = self.num_direction
h = torch.zeros(d, B, self.hidden_size).to(input.device)
c = torch.zeros(d, B, self.hidden_size).to(input.device)
else:
h, c = hc
output, (h, c) = self.lstm(input, (h, c))
output = self.dropout(output)
output = self.proj(output.contiguous().view(-1, output.shape[2])).view(
input.shape
)
output_norm = self.norm(output.permute(0, 2, 1)).permute(0, 2, 1)
output = input + output_norm
return output, (h, c)
class SBRNNBlock(nn.Module):
"""RNNBlock with output layer.
Arguments
---------
input_size : int
Dimensionality of the input features.
hidden_channels : int
Dimensionality of the latent layer of the RNN.
num_layers : int
Number of RNN layers.
outsize : int
Number of dimensions at the output of the linear layer.
rnn_type : str
Type of the RNN cell.
dropout : float
Dropout rate.
bidirectional : bool
If True, bidirectional.
Example
-------
>>> x = torch.randn(10, 100, 64)
>>> rnn = SBRNNBlock(64, 100, 1, 128, bidirectional=True)
>>> x = rnn(x)
>>> x.shape
torch.Size([10, 100, 128])
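
A minimal unidirectional variant with the same toy sizes (a sketch,
not a recipe default):

>>> x_uni = torch.randn(10, 100, 64)
>>> rnn_uni = SBRNNBlock(64, 100, 1, 128, bidirectional=False)
>>> rnn_uni(x_uni).shape
torch.Size([10, 100, 128])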
"""
def __init__(
self,
input_size,
hidden_channels,
num_layers,
outsize,
rnn_type="LSTM",
dropout=0,
bidirectional=True,
):
super().__init__()
self.mdl = getattr(SBRNN, rnn_type)(
hidden_channels,
input_size=input_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
)
rnn_outsize = 2 * hidden_channels if bidirectional else hidden_channels
self.out = nn.Linear(rnn_outsize, outsize)
def forward(self, x):
"""Returns the transformed output.
Arguments
---------
x : torch.Tensor
Input tensor of shape [B, L, N],
where B = batch size,
L = number of time points,
N = number of filters.
Returns
-------
out : torch.Tensor
The transformed output.
"""
rnn_out = self.mdl(x)[0]
out = self.out(rnn_out)
return out
class ResourceEfficientSeparationPipeline(nn.Module):
"""Resource Efficient Separation Pipeline Used for RE-SepFormer and SkiM
Note: This implementation is a generalization of the ESPNET implementation of SkiM
Arguments
---------
input_size: int
Dimension of the input feature.
Input shape should be (batch, length, input_size)
hidden_size: int
Dimension of the hidden state.
output_size: int
Dimension of the output size.
dropout: float
Dropout ratio. Default is 0.
num_blocks: int
Number of basic SkiM blocks
segment_size: int
Segmentation size for splitting long features
bidirectional: bool
Whether the RNN layers are bidirectional.
mem_type: str
One of 'hc', 'h', 'c', 'id', 'av' or None.
This controls whether the hidden (or cell) state of SegLSTM
will be processed by MemLSTM.
In 'av' mode, the memory is an additive summary obtained by
averaging over the time dimension of each segment.
In 'id' mode, both the hidden and cell states will
be identically returned.
When mem_type is None, the MemLSTM will be removed.
norm_type: str
One of 'gln' or 'cln'.
'cln' is for the causal implementation.
seg_model: class
The model that processes the within segment elements
mem_model: class
The memory model that ensures continuity between the segments
Example
-------
>>> x = torch.randn(10, 100, 64)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> resepf_pipeline = ResourceEfficientSeparationPipeline(64, 64, 128, seg_model=seg_mdl, mem_model=mem_mdl)
>>> out = resepf_pipeline.forward(x)
>>> out.shape
torch.Size([10, 100, 128])
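
A minimal SkiM-style sketch (an illustrative configuration, not a
recipe default), using an LSTM segment model with an LSTM memory model:

>>> seg_lstm = SegLSTM(64, 64)
>>> mem_lstm = MemLSTM(64)
>>> skim_pipeline = ResourceEfficientSeparationPipeline(64, 64, 128, mem_type='hc', seg_model=seg_lstm, mem_model=mem_lstm)
>>> skim_pipeline.forward(x).shape
torch.Size([10, 100, 128])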
"""
def __init__(
self,
input_size,
hidden_size,
output_size,
dropout=0.0,
num_blocks=2,
segment_size=20,
bidirectional=True,
mem_type="av",
norm_type="gln",
seg_model=None,
mem_model=None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.segment_size = segment_size
self.dropout = dropout
self.num_blocks = num_blocks
self.mem_type = mem_type
self.norm_type = norm_type
assert mem_type in [
"hc",
"h",
"c",
"id",
"av",
None,
], f"only support 'hc', 'h', 'c', 'id', 'av' and None, current type: {mem_type}"
self.seg_model = nn.ModuleList([])
for i in range(num_blocks):
self.seg_model.append(copy.deepcopy(seg_model))
if self.mem_type is not None:
self.mem_model = nn.ModuleList([])
for i in range(num_blocks - 1):
self.mem_model.append(copy.deepcopy(mem_model))
self.output_fc = nn.Sequential(
nn.PReLU(), nn.Conv1d(input_size, output_size, 1)
)
def forward(self, input):
"""The forward function of the ResourceEfficientSeparationPipeline
This takes in a tensor of size [B, (S*K), D]
Arguments
---------
input : torch.Tensor
Tensor shape [B, (S*K), D],
where B = batch size,
S = number of chunks,
K = chunk size,
D = number of features.
Returns
-------
output : torch.Tensor
The processed output of shape [B, T, output_size].
"""
B, T, D = input.shape
input, rest = self._padfeature(input=input)
input = input.view(B, -1, self.segment_size, D) # B, S, K, D
B, S, K, D = input.shape
assert K == self.segment_size
output = input.reshape(B * S, K, D) # BS, K, D
if self.mem_type == "av":
hc = torch.zeros(
output.shape[0], 1, output.shape[-1], device=output.device
)
else:
hc = None
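# Note: the memory state `hc` links consecutive blocks. With a
# transformer segment model ('av' mode) it is an additive summary
# averaged over each segment, while with a SegLSTM segment model it
# is the (h, c) state pair refined by the memory model between blocks.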
for i in range(self.num_blocks):
seg_model_type = type(self.seg_model[0]).__name__
if seg_model_type == "SBTransformerBlock_wnormandskip":
output = self.seg_model[i](output + hc) # BS, K, D
elif seg_model_type == "SegLSTM":
output, hc = self.seg_model[i](output, hc) # BS, K, D
else:
raise ValueError("Unsupported segment model class")
if i < (self.num_blocks - 1):
if self.mem_type == "av":
hc = output.mean(1).unsqueeze(0)
hc = self.mem_model[i](hc).permute(1, 0, 2)
else:
hc = self.mem_model[i](hc, S)
output = output.reshape(B, S * K, D)[:, :T, :] # B, T, D
output = self.output_fc(output.transpose(1, 2)).transpose(1, 2)
return output
def _padfeature(self, input):
"""
Arguments
---------
input : torch.Tensor
Tensor of size [B, T, D],
where B is the batch size,
T is the sequence length,
D is the feature dimensionality.
Returns
-------
input : torch.Tensor
Padded input
rest : int
Number of padded time frames.
"""
B, T, D = input.shape
rest = self.segment_size - T % self.segment_size
if rest > 0:
input = torch.nn.functional.pad(input, (0, 0, 0, rest))
return input, rest
class ResourceEfficientSeparator(nn.Module):
"""Resource Efficient Source Separator
This is the class that implements RE-SepFormer
Arguments
---------
input_dim: int
Input feature dimension
causal: bool
Whether the system is causal.
num_spk: int
Number of target speakers.
nonlinear: str
The nonlinear function for mask estimation;
choose from 'relu', 'tanh', 'sigmoid'.
layer: int
Number of blocks in the separation pipeline. Default is 3.
unit: int
Dimensionality of the hidden state.
segment_size: int
Chunk size for splitting long features.
dropout: float
Dropout ratio. Default is 0.
mem_type: str
'hc', 'h', 'c', 'id', 'av' or None.
This controls whether a memory representation will be used to ensure continuity between segments.
In 'av' mode, the summary state is calculated by simply averaging over the time dimension of each segment.
In 'id' mode, both the hidden and cell states
will be identically returned.
When mem_type is None, the memory model will be removed.
seg_model: class
The model that processes the within segment elements
mem_model: class
The memory model that ensures continuity between the segments
Example
-------
>>> x = torch.randn(10, 64, 100)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> resepformer = ResourceEfficientSeparator(64, num_spk=3, mem_type='av', seg_model=seg_mdl, mem_model=mem_mdl)
>>> out = resepformer.forward(x)
>>> out.shape
torch.Size([3, 10, 64, 100])
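
A minimal causal, LSTM-based sketch in the spirit of SkiM (an
illustrative configuration, not a recipe default):

>>> seg_lstm = SegLSTM(64, 64)
>>> mem_lstm = MemLSTM(64)
>>> skim_sep = ResourceEfficientSeparator(64, causal=True, num_spk=2, mem_type='hc', seg_model=seg_lstm, mem_model=mem_lstm)
>>> skim_sep.forward(x).shape
torch.Size([2, 10, 64, 100])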
"""
def __init__(
self,
input_dim: int,
causal: bool = True,
num_spk: int = 2,
nonlinear: str = "relu",
layer: int = 3,
unit: int = 512,
segment_size: int = 20,
dropout: float = 0.0,
mem_type: str = "hc",
seg_model=None,
mem_model=None,
):
super().__init__()
self.num_spk = num_spk
self.segment_size = segment_size
if mem_type not in ("hc", "h", "c", "id", "av", None):
raise ValueError("Not supporting mem_type={}".format(mem_type))
self.model = ResourceEfficientSeparationPipeline(
input_size=input_dim,
hidden_size=unit,
output_size=input_dim * num_spk,
dropout=dropout,
num_blocks=layer,
bidirectional=(not causal),
norm_type="cln" if causal else "gln",
segment_size=segment_size,
mem_type=mem_type,
seg_model=seg_model,
mem_model=mem_model,
)
if nonlinear not in ("sigmoid", "relu", "tanh"):
raise ValueError("Not supporting nonlinear={}".format(nonlinear))
self.nonlinear = {
"sigmoid": torch.nn.Sigmoid(),
"relu": torch.nn.ReLU(),
"tanh": torch.nn.Tanh(),
}[nonlinear]
def forward(self, inpt: torch.Tensor):
"""Forward
Arguments
---------
inpt : torch.Tensor
Encoded feature of shape [B, N, T]
(permuted internally to [B, T, N]).
Returns
-------
mask_tensor : torch.Tensor
Estimated masks of shape [num_spk, B, N, T].
"""
inpt = inpt.permute(0, 2, 1)
B, T, N = inpt.shape
processed = self.model(inpt) # B,T, N
processed = processed.reshape(B, T, N, self.num_spk)
masks = self.nonlinear(processed).unbind(dim=3)
mask_tensor = torch.stack([m.permute(0, 2, 1) for m in masks])
return mask_tensor