speechbrain.lobes.models.ECAPA_TDNN module

A popular speaker recognition and diarization model.

Authors
  • Hwidong Na 2020

Summary

Classes:

AttentiveStatisticsPooling

This class implements an attentive statistics pooling layer for each channel.

BatchNorm1d

Classifier

This class implements cosine similarity on top of features.

Conv1d

ECAPA_TDNN

An implementation of the ECAPA-TDNN speaker embedding model.

Res2NetBlock

An implementation of Res2NetBlock with dilation.

SEBlock

An implementation of a squeeze-and-excitation block.

SERes2NetBlock

An implementation of the building block used in ECAPA-TDNN, i.e., TDNN-Res2Net-TDNN-SEBlock.

TDNNBlock

An implementation of a TDNN (Time Delay Neural Network) block.

Reference

class speechbrain.lobes.models.ECAPA_TDNN.Conv1d(*args, **kwargs)[source]

Bases: speechbrain.nnet.CNN.Conv1d

training: bool
class speechbrain.lobes.models.ECAPA_TDNN.BatchNorm1d(*args, **kwargs)[source]

Bases: speechbrain.nnet.normalization.BatchNorm1d

training: bool
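
Both classes above are thin wrappers around the corresponding SpeechBrain layers. A minimal sketch of the presumed intent, assuming the wrappers only flip the default tensor layout so the blocks in this module can operate directly on [batch, channels, time] rather than the base layers' [batch, time, channels]:

from speechbrain.nnet.CNN import Conv1d as _Conv1d
from speechbrain.nnet.normalization import BatchNorm1d as _BatchNorm1d

class Conv1d(_Conv1d):
    # Assumed: skip the internal transpose, accept [batch, channels, time]
    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)

class BatchNorm1d(_BatchNorm1d):
    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)

This layout convention is also why the doctests below transpose their inputs before and after each block.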
class speechbrain.lobes.models.ECAPA_TDNN.TDNNBlock(in_channels, out_channels, kernel_size, dilation, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]

Bases: torch.nn.modules.module.Module

An implementation of a TDNN (Time Delay Neural Network) block.

Parameters
  • in_channels (int) – The number of input channels.

  • out_channels (int) – The number of output channels.

  • kernel_size (int) – The kernel size of the TDNN block.

  • dilation (int) – The dilation of the TDNN block.

  • activation (torch class) – A class for constructing the activation layers.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import TDNNBlock
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
forward(x)[source]
training: bool
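
Conceptually, a TDNN block of this kind is a dilated 1-D convolution followed by an activation and batch normalization. A minimal PyTorch sketch of that pattern (illustrative only, not the SpeechBrain implementation; the padding choice is an assumption made so the time dimension is preserved):

import torch
import torch.nn as nn

class MiniTDNNBlock(nn.Module):
    """Sketch: dilated Conv1d -> activation -> BatchNorm1d."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        # "same"-style padding so the time dimension is preserved
        padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=padding)
        self.activation = nn.ReLU()
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):  # x: [batch, channels, time]
        return self.norm(self.activation(self.conv(x)))

x = torch.rand(8, 64, 120)  # [batch, channels, time]
assert MiniTDNNBlock(64, 64, kernel_size=3, dilation=1)(x).shape == (8, 64, 120)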
class speechbrain.lobes.models.ECAPA_TDNN.Res2NetBlock(in_channels, out_channels, scale=8, dilation=1)[source]

Bases: torch.nn.modules.module.Module

An implementation of Res2NetBlock with dilation.

Parameters
  • in_channels (int) – The number of channels expected in the input.

  • out_channels (int) – The number of output channels.

  • scale (int) – The scale of the Res2Net block.

  • dilation (int) – The dilation of the Res2Net block.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import Res2NetBlock
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
forward(x)[source]
training: bool
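
The Res2Net idea is to split the channels into `scale` groups and process them with small convolutions joined by hierarchical residual links, which enlarges the receptive field within a single block. A minimal sketch of the idea (illustrative; the SpeechBrain block uses TDNN sub-blocks rather than bare convolutions):

import torch
import torch.nn as nn

class MiniRes2Net(nn.Module):
    """Sketch: split channels into `scale` groups; group i is convolved
    together with the output of group i-1; results are concatenated."""
    def __init__(self, channels, scale=4, kernel_size=3, dilation=1):
        super().__init__()
        assert channels % scale == 0
        width = channels // scale
        padding = dilation * (kernel_size - 1) // 2
        # one conv per group, except the first group is passed through
        self.convs = nn.ModuleList(
            nn.Conv1d(width, width, kernel_size,
                      dilation=dilation, padding=padding)
            for _ in range(scale - 1)
        )
        self.scale = scale

    def forward(self, x):  # x: [batch, channels, time]
        chunks = torch.chunk(x, self.scale, dim=1)
        out, y = [chunks[0]], None
        for i, conv in enumerate(self.convs):
            xi = chunks[i + 1]
            y = conv(xi if y is None else xi + y)  # hierarchical residual
            out.append(y)
        return torch.cat(out, dim=1)

x = torch.rand(8, 64, 120)
assert MiniRes2Net(64, scale=4, dilation=3)(x).shape == (8, 64, 120)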
class speechbrain.lobes.models.ECAPA_TDNN.SEBlock(in_channels, se_channels, out_channels)[source]

Bases: torch.nn.modules.module.Module

An implementation of a squeeze-and-excitation block.

Parameters
  • in_channels (int) – The number of input channels.

  • se_channels (int) – The number of output channels after squeeze.

  • out_channels (int) – The number of output channels.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import SEBlock
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> se_layer = SEBlock(64, 16, 64)
>>> lengths = torch.rand((8,))
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
forward(x, lengths=None)[source]
training: bool
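
A squeeze-and-excitation block computes one descriptor per channel by averaging over time ("squeeze"), maps it through a small bottleneck network ("excitation"), and uses the resulting sigmoid gates to rescale each channel of the input. A minimal sketch (illustrative; the SpeechBrain block uses 1x1 convolutions and supports length masking for padded batches, both omitted here):

import torch
import torch.nn as nn

class MiniSEBlock(nn.Module):
    """Sketch: mean over time -> bottleneck -> sigmoid gate -> rescale."""
    def __init__(self, in_channels, se_channels):
        super().__init__()
        self.squeeze = nn.Linear(in_channels, se_channels)
        self.excite = nn.Linear(se_channels, in_channels)

    def forward(self, x):  # x: [batch, channels, time]
        s = x.mean(dim=2)                                  # [batch, channels]
        g = torch.sigmoid(self.excite(torch.relu(self.squeeze(s))))
        return x * g.unsqueeze(2)                          # gate each channel

x = torch.rand(8, 64, 120)
assert MiniSEBlock(64, 16)(x).shape == (8, 64, 120)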
class speechbrain.lobes.models.ECAPA_TDNN.AttentiveStatisticsPooling(channels, attention_channels=128, global_context=True)[source]

Bases: torch.nn.modules.module.Module

This class implements an attentive statistics pooling layer for each channel. It returns the concatenated mean and std of the input tensor.

Parameters
  • channels (int) – The number of input channels.

  • attention_channels (int) – The number of attention channels.

  • global_context (bool) – Whether to include global utterance-level context when computing the attention.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> asp_layer = AttentiveStatisticsPooling(64)
>>> lengths = torch.rand((8,))
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 1, 128])
forward(x, lengths=None)[source]

Calculates the attention-weighted mean and std for a batch (input tensor).

Parameters

x (torch.Tensor) – Tensor of shape [N, C, L].

training: bool
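
Attentive statistics pooling turns a variable-length sequence into a fixed-size vector: an attention network scores each frame, and the scores weight both the mean and the standard deviation, which are concatenated per channel. A minimal sketch (illustrative; the global-context branch and length masking of the actual layer are omitted):

import torch
import torch.nn as nn

class MiniASP(nn.Module):
    """Sketch: softmax attention over time, then weighted mean and std
    concatenated along the channel axis -> [batch, 2 * channels, 1]."""
    def __init__(self, channels, attention_channels=128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels, attention_channels, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attention_channels, channels, kernel_size=1),
        )

    def forward(self, x):  # x: [batch, channels, time]
        w = torch.softmax(self.attn(x), dim=2)             # weights over time
        mean = (w * x).sum(dim=2)
        var = (w * (x - mean.unsqueeze(2)) ** 2).sum(dim=2)
        std = var.clamp(min=1e-8).sqrt()
        return torch.cat([mean, std], dim=1).unsqueeze(2)

x = torch.rand(8, 64, 120)
assert MiniASP(64)(x).shape == (8, 128, 1)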
class speechbrain.lobes.models.ECAPA_TDNN.SERes2NetBlock(in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]

Bases: torch.nn.modules.module.Module

An implementation of the building block used in ECAPA-TDNN, i.e., TDNN-Res2Net-TDNN-SEBlock.

Parameters
  • in_channels (int) – The number of channels expected in the input.

  • out_channels (int) – The number of output channels.

  • res2net_scale (int) – The scale of the Res2Net block.

  • se_channels (int) – The number of output channels after squeeze in the SE block.

  • kernel_size (int) – The kernel size of the TDNN blocks.

  • dilation (int) – The dilation of the Res2Net block.

  • activation (torch class) – A class for constructing the activation layers.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import SERes2NetBlock
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
>>> out = conv(x).transpose(1, 2)
>>> out.shape
torch.Size([8, 120, 64])
forward(x, lengths=None)[source]
training: bool
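
The pattern this block implements can be composed by hand from the classes documented above; a sketch under the assumption that the block also carries a residual connection around the TDNN-Res2Net-TDNN-SE stack (hyperparameters follow the doctests on this page):

import torch
from speechbrain.lobes.models.ECAPA_TDNN import (
    TDNNBlock, Res2NetBlock, SEBlock,
)

x = torch.rand(8, 64, 120)                       # [batch, channels, time]
residual = x
y = TDNNBlock(64, 64, kernel_size=1, dilation=1)(x)
y = Res2NetBlock(64, 64, scale=4, dilation=3)(y)
y = TDNNBlock(64, 64, kernel_size=1, dilation=1)(y)
y = SEBlock(64, 16, 64)(y)
out = y + residual                               # assumed skip connection
assert out.shape == (8, 64, 120)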
class speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN(input_size, device='cpu', lin_neurons=192, activation=<class 'torch.nn.modules.activation.ReLU'>, channels=[512, 512, 512, 512, 1536], kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128, res2net_scale=8, se_channels=128, global_context=True)[source]

Bases: torch.nn.modules.module.Module

An implementation of the speaker embedding model described in the paper “ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification” (https://arxiv.org/abs/2005.07143).

Parameters
  • input_size (int) – Expected size of the input feature dimension.

  • device (str) – Device used, e.g., “cpu” or “cuda”.

  • lin_neurons (int) – Number of neurons in linear layers.

  • activation (torch class) – A class for constructing the activation layers.

  • channels (list of ints) – Output channels for each TDNN/SERes2Net layer.

  • kernel_sizes (list of ints) – List of kernel sizes for each layer.

  • dilations (list of ints) – List of dilations for kernels in each layer.

  • attention_channels (int) – The number of attention channels in the pooling layer.

  • res2net_scale (int) – The scale of the Res2Net blocks.

  • se_channels (int) – The number of channels after squeeze in the SE blocks.

  • global_context (bool) – Whether to use global utterance context in the pooling layer.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
>>> input_feats = torch.rand([5, 120, 80])
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
>>> outputs = compute_embedding(input_feats)
>>> outputs.shape
torch.Size([5, 1, 192])
forward(x, lengths=None)[source]

Returns the embedding vector.

Parameters

x (torch.Tensor) – Tensor of shape (batch, time, channel).

training: bool
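
Putting the two public classes of this module together gives a typical speaker-ID forward pass: features in, a fixed-size embedding out, then cosine scores over speaker classes. A sketch using only this module's API (shapes follow the doctests above; the scores shape is an assumption based on the Classifier defaults):

import torch
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN, Classifier

feats = torch.rand(5, 120, 80)                   # [batch, time, features]
embed_model = ECAPA_TDNN(input_size=80, lin_neurons=192)
classify = Classifier(input_size=192, lin_neurons=192, out_neurons=1211)

embeddings = embed_model(feats)                  # [5, 1, 192]
scores = classify(embeddings)                    # assumed: [5, 1, 1211]
assert embeddings.shape == (5, 1, 192)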
class speechbrain.lobes.models.ECAPA_TDNN.Classifier(input_size, device='cpu', lin_blocks=0, lin_neurons=192, out_neurons=1211)[source]

Bases: torch.nn.modules.module.Module

This class implements cosine similarity on top of features.

Parameters
  • input_size (int) – Expected size of the input embedding dimension.

  • device (str) – Device used, e.g., “cpu” or “cuda”.

  • lin_blocks (int) – Number of linear layers.

  • lin_neurons (int) – Number of neurons in linear layers.

  • out_neurons (int) – Number of classes.

Example

>>> import torch
>>> from speechbrain.lobes.models.ECAPA_TDNN import Classifier
>>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
>>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
>>> outputs = outputs.unsqueeze(1)
>>> cos = classify(outputs)
>>> (cos < -1.0).long().sum()
tensor(0)
>>> (cos > 1.0).long().sum()
tensor(0)
training: bool
forward(x)[source]

Returns the cosine similarity scores over speakers.

Parameters

x (torch.Tensor) – Input embedding tensor, e.g., of shape [batch, 1, lin_neurons].
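
Because the classifier's scores are cosine similarities, speaker verification with these embeddings reduces to thresholding the cosine between two utterance embeddings. A minimal sketch (the threshold value is purely illustrative; in practice it is tuned on a development set):

import torch
import torch.nn.functional as F

emb1 = torch.rand(1, 192)    # embedding of the enrolment utterance
emb2 = torch.rand(1, 192)    # embedding of the test utterance
score = F.cosine_similarity(emb1, emb2).item()   # value in [-1, 1]
same_speaker = score > 0.5                       # hypothetical threshold
print(score, same_speaker)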