speechbrain.processing.multi_mic module

Multi-microphone components.

This library contains functions for multi-microphone signal processing.

Example

>>> import torch
>>>
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, SrpPhat, Music
>>> from speechbrain.processing.multi_mic import DelaySum, Mvdr, Gev
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise_diff = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise_diff = xs_noise_diff.unsqueeze(0)
>>> xs_noise_loc = read_audio('tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac')
>>> xs_noise_loc = xs_noise_loc.unsqueeze(0)
>>> fs = 16000 # sampling rate
>>> ss = xs_speech
>>> nn_diff = 0.05 * xs_noise_diff
>>> nn_loc = 0.05 * xs_noise_loc
>>> xs_diffused_noise = ss + nn_diff
>>> xs_localized_noise = ss + nn_loc
>>> # Delay-and-Sum Beamforming with GCC-PHAT localization
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> delaysum = DelaySum()
>>> istft = ISTFT(sample_rate=fs)
>>> Xs = stft(xs_diffused_noise)
>>> Ns = stft(nn_diff)
>>> XXs = cov(Xs)
>>> NNs = cov(Ns)
>>> tdoas = gccphat(XXs)
>>> Ys_ds = delaysum(Xs, tdoas)
>>> ys_ds = istft(Ys_ds)
>>> # Mvdr Beamforming with SRP-PHAT localization
>>> mvdr = Mvdr()
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
>>> srpphat = SrpPhat(mics=mics)
>>> doas = srpphat(XXs)
>>> Ys_mvdr = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr = istft(Ys_mvdr)
>>> # Mvdr Beamforming with MUSIC localization
>>> music = Music(mics=mics)
>>> doas = music(XXs)
>>> Ys_mvdr2 = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr2 = istft(Ys_mvdr2)
>>> # GeV Beamforming
>>> gev = Gev()
>>> Xs = stft(xs_localized_noise)
>>> Ss = stft(ss)
>>> Ns = stft(nn_loc)
>>> SSs = cov(Ss)
>>> NNs = cov(Ns)
>>> Ys_gev = gev(Xs, SSs, NNs)
>>> ys_gev = istft(Ys_gev)
Authors:
  • William Aris

  • Francois Grondin

Summary

Classes:

Covariance

Computes the covariance matrices of the signals.

DelaySum

Performs delay and sum beamforming by using the TDOAs and the first channel as a reference.

GccPhat

Generalized Cross-Correlation with Phase Transform localization.

Gev

Generalized EigenValue decomposition (GEV) Beamforming.

Music

Multiple Signal Classification (MUSIC) localization.

Mvdr

Performs minimum variance distortionless response (MVDR) beamforming using an input signal in the frequency domain, the covariance matrices of the noise, and TDOAs or DOAs (to compute a steering vector).

SrpPhat

Steered-Response Power with Phase Transform Localization.

Functions:

doas2taus

This function converts directions of arrival (xyz coordinates expressed in meters) into time differences of arrival (expressed in samples).

sphere

This function generates cartesian coordinates (xyz) for a set of points forming a 3D sphere.

steering

This function computes a steering vector by using the time differences of arrival for each channel (in samples) and the number of bins (n_fft).

tdoas2taus

This function selects the TDOAs of each channel and puts them in a tensor.

Reference

class speechbrain.processing.multi_mic.Covariance(average=True)

Bases: Module

Computes the covariance matrices of the signals.

Arguments:

average : bool

Whether the module returns the covariance matrices averaged over the time dimension. The default value is True. In both cases the output keeps its time dimension, as the output shape in the example below shows.
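To make the layout concrete, here is a minimal sketch of the covariance entries for a single frequency bin, written in plain Python with complex numbers. It assumes the convention R_ij = X_i * conj(X_j) and the upper-triangular pair order documented for forward(); the library itself stores real and imaginary parts in a separate dimension of size 2, and triu_pairs and covariance_bin are hypothetical helpers, not part of speechbrain.

```python
def triu_pairs(n_mics):
    """Channel pairs (i, j) with i <= j, in row-major upper-triangular order."""
    return [(i, j) for i in range(n_mics) for j in range(i, n_mics)]


def covariance_bin(frames, average=True):
    """Covariance entries for one frequency bin.

    frames: list over time steps, each a list of complex STFT values (one per channel).
    Returns a list over time steps, each holding n_mics + n_pairs complex entries.
    """
    pairs = triu_pairs(len(frames[0]))
    # Assumed convention: R_ij = X_i * conj(X_j), computed for every frame
    per_frame = [[x[i] * x[j].conjugate() for i, j in pairs] for x in frames]
    if not average:
        return per_frame
    # Average over time, then repeat the mean so the time dimension is kept
    mean = [sum(column) / len(per_frame) for column in zip(*per_frame)]
    return [list(mean) for _ in per_frame]


frames = [[1 + 1j, 2 + 0j], [1 - 1j, 0j]]  # 2 time steps, 2 channels
print(triu_pairs(2))  # [(0, 0), (0, 1), (1, 1)]
```

With 4 channels, triu_pairs returns the 10 pairs listed for forward(), which is why the last dimension has size n_mics + n_pairs.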

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise = xs_noise.unsqueeze(0)
>>> xs = xs_speech + 0.05 * xs_noise
>>> fs = 16000
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>>
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> XXs.shape
torch.Size([1, 1001, 201, 2, 10])
forward(Xs)

This method uses the utility function _cov to compute covariance matrices. The result has the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics + n_pairs), where n_pairs = n_mics * (n_mics - 1) / 2 is the number of distinct microphone pairs.

The order on the last dimension corresponds to the triu_indices for a square matrix. For instance, if we have 4 channels, we get the following order: (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3) and (3, 3). Therefore, XXs[..., 0] corresponds to channels (0, 0) and XXs[..., 1] corresponds to channels (0, 1).

Arguments:

Xs : tensor

A batch of audio signals in the frequency domain. The tensor must have the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics)

training: bool
class speechbrain.processing.multi_mic.DelaySum

Bases: Module

Performs delay and sum beamforming by using the TDOAs and the first channel as a reference.

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
>>> fs = 16000
>>> xs = xs_speech + 0.05 * xs_noise
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> delaysum = DelaySum()
>>> istft = ISTFT(sample_rate=fs)
>>>
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> tdoas = gccphat(XXs)
>>> Ys = delaysum(Xs, tdoas)
>>> ys = istft(Ys)
forward(Xs, localization_tensor, doa_mode=False, mics=None, fs=None, c=343.0)

This method computes a steering vector by using the TDOAs/DOAs and then calls the utility function _delaysum to perform beamforming. The result has the following format: (batch, time_step, n_fft/2 + 1, 2, 1).

Parameters:
  • Xs (tensor) – A batch of audio signals in the frequency domain. The tensor must have the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics)

  • localization_tensor (tensor) – A tensor containing either time differences of arrival (TDOAs) (in samples) for each timestamp or directions of arrival (DOAs) (xyz coordinates in meters). If localization_tensor represents TDOAs, then its format is (batch, time_steps, n_mics + n_pairs). If localization_tensor represents DOAs, then its format is (batch, time_steps, 3)

  • doa_mode (bool) – The user needs to set this parameter to True if localization_tensor represents DOAs instead of TDOAs. Its default value is set to False.

  • mics (tensor) – The cartesian position (xyz coordinates in meters) of each microphone. The tensor must have the following format (n_mics, 3). This parameter is only mandatory when localization_tensor represents DOAs.

  • fs (int) – The sample rate in Hertz of the signals. This parameter is only mandatory when localization_tensor represents DOAs.

  • c (float) – The speed of sound in the medium. The speed is expressed in meters per second and the default value of this parameter is 343 m/s. This parameter is only used when localization_tensor represents DOAs.
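The steer-then-sum idea can be illustrated at a single frequency bin. This sketch assumes the steering convention a_m(k) = exp(-j * 2 * pi * k * tau_m / N) for an N-point frame, with channel 0 as the reference (tau_0 = 0); steering_vector and delay_and_sum_bin are hypothetical names, and the library's internal _delaysum may differ in normalization or sign convention.

```python
import cmath
import math


def steering_vector(taus, k, frame_size):
    """Assumed convention: a_m(k) = exp(-j 2 pi k tau_m / frame_size)."""
    return [cmath.exp(-2j * math.pi * k * tau / frame_size) for tau in taus]


def delay_and_sum_bin(x, taus, k, frame_size):
    """Align every channel of one STFT bin with the reference channel, then average."""
    a = steering_vector(taus, k, frame_size)
    return sum(am.conjugate() * xm for am, xm in zip(a, x)) / len(x)


taus = [0.0, 1.5, -0.7, 2.0]   # delays in samples, channel 0 is the reference
k, frame_size = 3, 400         # bin index and frame size (illustrative values)
source = 0.3 - 0.8j            # source STFT value at this bin
# Plane-wave observation: each channel sees the source with its own phase shift
x = [source * am for am in steering_vector(taus, k, frame_size)]
y = delay_and_sum_bin(x, taus, k, frame_size)
```

Because every |a_m(k)| = 1, multiplying by the conjugate steering vector realigns the channels before averaging, so a plane wave arriving with the assumed delays is recovered unchanged while uncorrelated noise is averaged down.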

training: bool
class speechbrain.processing.multi_mic.Mvdr(eps=1e-20)

Bases: Module

Performs minimum variance distortionless response (MVDR) beamforming using an input signal in the frequency domain, the covariance matrices of the noise, and TDOAs or DOAs (to compute a steering vector).

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, Mvdr
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
>>> fs = 16000
>>> xs = xs_speech + 0.05 * xs_noise
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> mvdr = Mvdr()
>>> istft = ISTFT(sample_rate=fs)
>>>
>>> Xs = stft(xs)
>>> Ns = stft(xs_noise)
>>> XXs = cov(Xs)
>>> NNs = cov(Ns)
>>> tdoas = gccphat(XXs)
>>> Ys = mvdr(Xs, NNs, tdoas)
>>> ys = istft(Ys)
forward(Xs, NNs, localization_tensor, doa_mode=False, mics=None, fs=None, c=343.0)

This method computes a steering vector before using the utility function _mvdr to perform beamforming. The result has the following format: (batch, time_step, n_fft/2 + 1, 2, 1).

Parameters:
  • Xs (tensor) – A batch of audio signals in the frequency domain. The tensor must have the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics)

  • NNs (tensor) – The covariance matrices of the noise signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs)

  • localization_tensor (tensor) – A tensor containing either time differences of arrival (TDOAs) (in samples) for each timestamp or directions of arrival (DOAs) (xyz coordinates in meters). If localization_tensor represents TDOAs, then its format is (batch, time_steps, n_mics + n_pairs). If localization_tensor represents DOAs, then its format is (batch, time_steps, 3)

  • doa_mode (bool) – The user needs to set this parameter to True if localization_tensor represents DOAs instead of TDOAs. Its default value is set to False.

  • mics (tensor) – The cartesian position (xyz coordinates in meters) of each microphone. The tensor must have the following format (n_mics, 3). This parameter is only mandatory when localization_tensor represents DOAs.

  • fs (int) – The sample rate in Hertz of the signals. This parameter is only mandatory when localization_tensor represents DOAs.

  • c (float) – The speed of sound in the medium. The speed is expressed in meters per second and the default value of this parameter is 343 m/s. This parameter is only used when localization_tensor represents DOAs.
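At a single frequency bin, the textbook MVDR weights are w = R^-1 a / (a^H R^-1 a), where R is the noise covariance and a the steering vector, and the output is y = w^H x. The two-microphone sketch below assumes that formulation; inv2, mvdr_weights and the matrix values are illustrative and not taken from the library, and the eps regularization of Mvdr is not modeled.

```python
import cmath


def inv2(m):
    """Inverse of a 2x2 complex matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def matvec(m, v):
    """Matrix-vector product for nested-list matrices."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def inner(u, v):
    """Hermitian inner product u^H v."""
    return sum(ui.conjugate() * vi for ui, vi in zip(u, v))


def mvdr_weights(noise_cov, a):
    """MVDR weights w = R^-1 a / (a^H R^-1 a) for noise covariance R and steering vector a."""
    rinv_a = matvec(inv2(noise_cov), a)
    denom = inner(a, rinv_a)  # real and positive for a Hermitian positive-definite R
    return [r / denom for r in rinv_a]


noise_cov = [[1.0 + 0j, 0.3 + 0.2j],
             [0.3 - 0.2j, 2.0 + 0j]]   # Hermitian, positive definite
a = [1 + 0j, cmath.exp(-0.4j)]         # channel 0 as reference
w = mvdr_weights(noise_cov, a)
# For an observed bin x, the beamformer output would be inner(w, x), i.e. w^H x.
```

The delay-and-sum weights a / n_mics also satisfy w^H a = 1, so for the same noise covariance MVDR can only lower the output noise power relative to them.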

training: bool
class speechbrain.processing.multi_mic.Gev

Bases: Module

Generalized EigenValue decomposition (GEV) Beamforming.

Example

>>> from speechbrain.dataio.dataio import read_audio
>>> import torch
>>>
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import Gev
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac')
>>> xs_noise = xs_noise.unsqueeze(0)
>>> fs = 16000
>>> ss = xs_speech
>>> nn = 0.05 * xs_noise
>>> xs = ss + nn
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gev = Gev()
>>> istft = ISTFT(sample_rate=fs)
>>>
>>> Ss = stft(ss)
>>> Nn = stft(nn)
>>> Xs = stft(xs)
>>>
>>> SSs = cov(Ss)
>>> NNs = cov(Nn)
>>>
>>> Ys = gev(Xs, SSs, NNs)
>>> ys = istft(Ys)
forward(Xs, SSs, NNs)

This method uses the utility function _gev to perform generalized eigenvalue decomposition beamforming. The result has the following format: (batch, time_step, n_fft/2 + 1, 2, 1).

Parameters:
  • Xs (tensor) – A batch of audio signals in the frequency domain. The tensor must have the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics).

  • SSs (tensor) – The covariance matrices of the target signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).

  • NNs (tensor) – The covariance matrices of the noise signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
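GEV beamforming is commonly formulated as maximizing the output SNR w^H S w / w^H N w, where S and N play the roles of SSs and NNs at one bin; the maximizer is the principal eigenvector of N^-1 S. The two-microphone sketch below implements that textbook formulation with hypothetical helper names and a closed-form 2x2 eigen-solution. Whether the library adds a normalization or postfilter on top is not covered here, and the weights are only defined up to a scale factor.

```python
import cmath


def inv2(m):
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def matmul2(p, q):
    return [[sum(p[i][k] * q[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def inner(u, v):
    """Hermitian inner product u^H v."""
    return sum(ui.conjugate() * vi for ui, vi in zip(u, v))


def gev_weights(S, N):
    """Principal eigenvector of N^-1 S (max-SNR weights, defined up to a scale factor)."""
    M = matmul2(inv2(N), S)
    tr = M[0][0] + M[1][1]
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    disc = cmath.sqrt(tr * tr - 4 * det)
    lam = max((tr + disc) / 2, (tr - disc) / 2, key=lambda z: z.real)
    if abs(M[0][1]) > 1e-12:  # solve row 0 of (M - lam I) v = 0
        return [M[0][1], lam - M[0][0]]
    if abs(M[1][0]) > 1e-12:  # solve row 1 of (M - lam I) v = 0
        return [lam - M[1][1], M[1][0]]
    # M is diagonal: pick the axis with the larger eigenvalue
    return [1 + 0j, 0j] if M[0][0].real >= M[1][1].real else [0j, 1 + 0j]


def output_snr(w, S, N):
    """Ratio of target to noise output power, w^H S w / w^H N w."""
    return inner(w, matvec(S, w)).real / inner(w, matvec(N, w)).real


a = [1 + 0j, cmath.exp(-0.7j)]                                       # target steering vector
S = [[a[i] * a[j].conjugate() for j in range(2)] for i in range(2)]  # rank-1 target covariance
N = [[1.5 + 0j, 0.4 - 0.1j], [0.4 + 0.1j, 1.0 + 0j]]                 # noise covariance
w = gev_weights(S, N)
```

Unlike MVDR, this criterion does not need a steering vector as input: it only uses the two covariance matrices, which matches the Gev.forward signature.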

training: bool
class speechbrain.processing.multi_mic.GccPhat(tdoa_max=None, eps=1e-20)

Bases: Module

Generalized Cross-Correlation with Phase Transform localization.

Parameters:
  • tdoa_max (int) – Specifies a range to search for delays. For example, if tdoa_max = 10, the method will restrict its search for delays between -10 and 10 samples. This parameter is optional and its default value is None. When tdoa_max is None, the method will search for delays between -n_fft/2 and n_fft/2 (full range).

  • eps (float) – A small value to avoid divisions by 0 with the phase transformation. The default value is 1e-20.

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise = xs_noise.unsqueeze(0) # [batch, time, channels]
>>> fs = 16000
>>> xs = xs_speech + 0.05 * xs_noise
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> tdoas = gccphat(XXs)
forward(XXs)

Perform generalized cross-correlation with phase transform localization by using the utility function _gcc_phat and by extracting the delays (in samples) before performing a quadratic interpolation to improve the accuracy. The result has the format: (batch, time_steps, n_mics + n_pairs).

The order on the last dimension corresponds to the triu_indices for a square matrix. For instance, if we have 4 channels, we get the following order: (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3) and (3, 3). Therefore, delays[..., 0] corresponds to channels (0, 0) and delays[..., 1] corresponds to channels (0, 1).

Arguments:

XXs : tensor

The covariance matrices of the input signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
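The two steps described for forward(), a phase-transform-weighted cross-correlation followed by quadratic interpolation around the peak, can be sketched for a single pair of channels. This version works on one frame with an O(N^2) DFT in plain Python and defines the result as the delay of x2 relative to x1; gcc_phat_delay is a hypothetical name, and the library's sign convention and windowing may differ.

```python
import cmath
import math
import random


def dft(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n)) for k in range(n)]


def idft(spectrum):
    n = len(spectrum)
    return [sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]


def gcc_phat_delay(x1, x2, tdoa_max=None, eps=1e-20):
    """Delay of x2 relative to x1 (in samples), refined by quadratic interpolation."""
    n = len(x1)

    def lag_of(t):
        # Map a circular index to a signed lag in samples
        return t if t <= n // 2 else t - n

    X1, X2 = dft(x1), dft(x2)
    cross = [b * a.conjugate() for a, b in zip(X1, X2)]
    phat = [c / (abs(c) + eps) for c in cross]  # phase transform: keep only the phase
    cc = [v.real for v in idft(phat)]           # generalized cross-correlation
    limit = n // 2 if tdoa_max is None else tdoa_max
    t0 = max((t for t in range(n) if abs(lag_of(t)) <= limit), key=lambda t: cc[t])
    # Fit a parabola through the peak and its two neighbours
    y_m, y_0, y_p = cc[(t0 - 1) % n], cc[t0], cc[(t0 + 1) % n]
    denom = y_m - 2 * y_0 + y_p
    offset = 0.0 if denom == 0 else 0.5 * (y_m - y_p) / denom
    return lag_of(t0) + offset


rng = random.Random(0)
x1 = [rng.gauss(0.0, 1.0) for _ in range(64)]
x2 = [x1[(t - 3) % 64] for t in range(64)]  # x2 is x1 delayed by 3 samples (circularly)
print(round(gcc_phat_delay(x1, x2), 6))
```

A circular shift makes the PHAT-weighted correlation an exact impulse. With fractional or non-circular delays the peak spreads over neighbouring samples, which is where the quadratic interpolation refines the integer estimate.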

training: bool
class speechbrain.processing.multi_mic.SrpPhat(mics, space='sphere', sample_rate=16000, speed_sound=343.0, eps=1e-20)

Bases: Module

Steered-Response Power with Phase Transform Localization.

Parameters:
  • mics (tensor) – The cartesian coordinates (xyz) in meters of each microphone. The tensor must have the following format (n_mics, 3).

  • space (string) – If this parameter is set to ‘sphere’, the localization will be done in 3D by searching in a sphere of possible DOAs. If it is set to ‘circle’, the search will be done in 2D by searching in a circle. By default, this parameter is set to ‘sphere’. Note: The ‘circle’ option isn’t implemented yet.

  • sample_rate (int) – The sample rate in Hertz of the signals to perform SRP-PHAT on. By default, this parameter is set to 16000 Hz.

  • speed_sound (float) – The speed of sound in the medium. The speed is expressed in meters per second and the default value of this parameter is 343 m/s.

  • eps (float) – A small value to avoid errors like division by 0. The default value of this parameter is 1e-20.

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import SrpPhat
>>> xs_speech = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> fs = 16000
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = xs_noise.unsqueeze(0)
>>> ss1 = xs_speech
>>> ns1 = 0.05 * xs_noise
>>> xs1 = ss1 + ns1
>>> ss2 = xs_speech
>>> ns2 = 0.20 * xs_noise
>>> xs2 = ss2 + ns2
>>> ss = torch.cat((ss1,ss2), dim=0)
>>> ns = torch.cat((ns1,ns2), dim=0)
>>> xs = torch.cat((xs1,xs2), dim=0)
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> srpphat = SrpPhat(mics=mics)
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> doas = srpphat(XXs)
forward(XXs)

Perform SRP-PHAT localization on a signal by computing a steering vector and then by using the utility function _srp_phat to extract the doas. The result is a tensor containing the directions of arrival (xyz coordinates (in meters) in the direction of the sound source). The output tensor has the format (batch, time_steps, 3).

This localization method uses Global Coherence Field (GCF): https://www.researchgate.net/publication/221491705_Speaker_localization_based_on_oriented_global_coherence_field

Parameters:

XXs (tensor) – The covariance matrices of the input signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
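SRP-PHAT scans a grid of candidate directions and keeps the one whose predicted inter-microphone delays best align the PHAT-weighted cross-spectra. The sketch below illustrates that search for a single frame, using a horizontal circle of candidates instead of the library's sphere() grid. The delay convention tau_m = (fs / c) * <doa, p_m>, the frame size, and all helper names are assumptions, and it simply sums the steered power over pairs and bins rather than reproducing the library's internal _srp_phat.

```python
import cmath
import math
import random

FS, C, N_FFT = 16000.0, 343.0, 64  # assumed sample rate, speed of sound, frame size


def doa_to_taus(doa, mics):
    """Assumed convention: tau_m = (fs / c) * <doa, p_m>, in samples."""
    return [FS / C * sum(u * p for u, p in zip(doa, mic)) for mic in mics]


def simulate_frame(doa, mics, rng):
    """One-sided spectra of a random source observed with the delays implied by `doa`."""
    source = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(N_FFT // 2 + 1)]
    return [[s * cmath.exp(-2j * math.pi * k * tau / N_FFT) for k, s in enumerate(source)]
            for tau in doa_to_taus(doa, mics)]


def srp_phat(frame, mics, candidates, eps=1e-20):
    """Return the candidate DOA with the largest steered response power."""
    n = len(mics)
    cross = {}
    for i in range(n):
        for j in range(i + 1, n):
            # PHAT: keep only the phase of each cross-spectrum bin
            cross[(i, j)] = [xi * xj.conjugate() / (abs(xi * xj.conjugate()) + eps)
                             for xi, xj in zip(frame[i], frame[j])]
    best, best_power = None, -math.inf
    for doa in candidates:
        taus = doa_to_taus(doa, mics)
        # Undo the predicted phase difference of each pair and accumulate the real part
        power = sum((r * cmath.exp(2j * math.pi * k * (taus[i] - taus[j]) / N_FFT)).real
                    for (i, j), bins in cross.items()
                    for k, r in enumerate(bins))
        if power > best_power:
            best, best_power = doa, power
    return best


mics = [(-0.05, -0.05, 0.0), (-0.05, 0.05, 0.0), (0.05, 0.05, 0.0), (0.05, -0.05, 0.0)]
# Candidate DOAs on the horizontal circle, every 10 degrees
candidates = [(math.cos(math.radians(d)), math.sin(math.radians(d)), 0.0) for d in range(0, 360, 10)]
frame = simulate_frame(candidates[4], mics, random.Random(1))  # source at 40 degrees
estimate = srp_phat(frame, mics, candidates)
```

At the true direction every PHAT-weighted term is realigned to 1, so the steered power reaches its maximum of n_pairs * n_bins there; any other candidate leaves residual phase and scores lower.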

training: bool
class speechbrain.processing.multi_mic.Music(mics, space='sphere', sample_rate=16000, speed_sound=343.0, eps=1e-20, n_sig=1)

Bases: Module

Multiple Signal Classification (MUSIC) localization.

Parameters:
  • mics (tensor) – The cartesian coordinates (xyz) in meters of each microphone. The tensor must have the following format (n_mics, 3).

  • space (string) – If this parameter is set to ‘sphere’, the localization will be done in 3D by searching in a sphere of possible DOAs. If it is set to ‘circle’, the search will be done in 2D by searching in a circle. By default, this parameter is set to ‘sphere’. Note: The ‘circle’ option isn’t implemented yet.

  • sample_rate (int) – The sample rate in Hertz of the signals to perform MUSIC on. By default, this parameter is set to 16000 Hz.

  • speed_sound (float) – The speed of sound in the medium. The speed is expressed in meters per second and the default value of this parameter is 343 m/s.

  • eps (float) – A small value to avoid errors like division by 0. The default value of this parameter is 1e-20.

  • n_sig (int) – An estimation of the number of sound sources. The default value is set to one source.

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import Music
>>> xs_speech = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> fs = 16000
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = xs_noise.unsqueeze(0)
>>> ss1 = xs_speech
>>> ns1 = 0.05 * xs_noise
>>> xs1 = ss1 + ns1
>>> ss2 = xs_speech
>>> ns2 = 0.20 * xs_noise
>>> xs2 = ss2 + ns2
>>> ss = torch.cat((ss1,ss2), dim=0)
>>> ns = torch.cat((ns1,ns2), dim=0)
>>> xs = torch.cat((xs1,xs2), dim=0)
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> music = Music(mics=mics)
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> doas = music(XXs)
forward(XXs)

Perform MUSIC localization on a signal by computing a steering vector and then by using the utility function _music to extract the doas. The result is a tensor containing the directions of arrival (xyz coordinates (in meters) in the direction of the sound source). The output tensor has the format (batch, time_steps, 3).

Parameters:

XXs (tensor) – The covariance matrices of the input signal. The tensor must have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
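MUSIC splits the spatial covariance into a signal subspace and a noise subspace and scores each candidate steering vector by how close to orthogonal it is to the noise subspace. The sketch below shows that idea in the simplest setting, assuming two microphones, one source (so, as with n_sig = 1, the noise subspace is a single eigenvector), and one frequency bin, with a closed-form 2x2 eigenvector. It scans candidate delays rather than DOAs, and the helper names are hypothetical.

```python
import cmath
import math


def noise_eigvec_2x2(R):
    """Unit eigenvector of the smaller eigenvalue of a 2x2 Hermitian matrix."""
    p, q, r = R[0][0].real, R[0][1], R[1][1].real
    lam_min = (p + r) / 2 - math.sqrt(((p - r) / 2) ** 2 + abs(q) ** 2)
    if abs(q) < 1e-12:  # already diagonal
        return [0j, 1 + 0j] if r <= p else [1 + 0j, 0j]
    v = [q, lam_min - p]  # solves row 0 of (R - lam_min I) v = 0
    norm = math.sqrt(sum(abs(x) ** 2 for x in v))
    return [x / norm for x in v]


def music_spectrum(R, omega, candidate_taus, eps=1e-20):
    """MUSIC pseudo-spectrum 1 / |e_n^H a(tau)|^2 for a two-microphone array."""
    e_n = noise_eigvec_2x2(R)
    spectrum = []
    for tau in candidate_taus:
        a = [1 + 0j, cmath.exp(-1j * omega * tau)]  # channel 0 as reference
        proj = sum(e.conjugate() * ai for e, ai in zip(e_n, a))
        spectrum.append(1.0 / (abs(proj) ** 2 + eps))
    return spectrum


omega = 2 * math.pi * 4 / 64  # bin k = 4 of a 64-point frame (illustrative)
true_tau, sigma2 = 1.5, 0.1   # source delay in samples and noise power
a = [1 + 0j, cmath.exp(-1j * omega * true_tau)]
# Covariance of one source plus spatially white noise
R = [[a[i] * a[j].conjugate() + (sigma2 if i == j else 0) for j in range(2)] for i in range(2)]
candidates = [x / 2 for x in range(-6, 7)]  # -3.0 to 3.0 samples in steps of 0.5
spectrum = music_spectrum(R, omega, candidates)
best = candidates[spectrum.index(max(spectrum))]
```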

training: bool
speechbrain.processing.multi_mic.doas2taus(doas, mics, fs, c=343.0)

This function converts directions of arrival (xyz coordinates expressed in meters) into time differences of arrival (expressed in samples). The result has the following format: (batch, time_steps, n_mics).

Parameters:
  • doas (tensor) – The directions of arrival expressed with cartesian coordinates (xyz) in meters. The tensor must have the following format: (batch, time_steps, 3).

  • mics (tensor) – The cartesian position (xyz) in meters of each microphone. The tensor must have the following format (n_mics, 3).

  • fs (int) – The sample rate in Hertz of the signals.

  • c (float) – The speed of sound in the medium. The speed is expressed in meters per second and the default value of this parameter is 343 m/s.
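For a single unit direction vector, the conversion amounts to projecting each microphone position onto that direction and scaling by fs / c. This sketch assumes tau_m = (fs / c) * (u . p_m); the sign convention used by the library is an assumption here, and doas_to_taus is a hypothetical name.

```python
def doas_to_taus(doa, mics, fs, c=343.0):
    """Per-microphone delays (in samples) for one unit direction-of-arrival vector."""
    return [fs / c * sum(u * p for u, p in zip(doa, mic)) for mic in mics]


mics = [(-0.05, -0.05, 0.0), (-0.05, 0.05, 0.0), (0.05, 0.05, 0.0), (0.05, -0.05, 0.0)]
taus = doas_to_taus((1.0, 0.0, 0.0), mics, fs=16000)  # source along the +x axis
```

For this 10 cm square array, the largest possible delay magnitude is (16000 / 343) * 0.05 * sqrt(2), about 3.3 samples, which is a useful bound when choosing tdoa_max for GccPhat.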

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.multi_mic import sphere, doas2taus
>>> xs = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
>>> xs = xs.unsqueeze(0) # [batch, time, channels]
>>> fs = 16000
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
>>> doas = sphere()
>>> taus = doas2taus(doas, mics, fs)
speechbrain.processing.multi_mic.tdoas2taus(tdoas)

This function selects the TDOAs of each channel and puts them in a tensor. The result has the following format: (batch, time_steps, n_mics).

Arguments:

tdoas : tensor

The time differences of arrival (TDOAs) (in samples) for each time step. The tensor has the format (batch, time_steps, n_mics + n_pairs).
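In the upper-triangular order used by GccPhat, the first n_mics entries of the last dimension are the pairs (0, 0), (0, 1), ..., (0, n_mics - 1), i.e. the delay of every channel relative to the first one. The sketch below assumes that this is what tdoas2taus keeps, and recovers n_mics from the size of the last dimension by solving n_mics * (n_mics + 1) / 2 = n_mics + n_pairs. Both helper names are hypothetical.

```python
import math


def n_mics_from_size(size):
    """Solve n * (n + 1) / 2 == size for the number of microphones n."""
    n = (math.isqrt(1 + 8 * size) - 1) // 2
    if n * (n + 1) // 2 != size:
        raise ValueError(f"{size} is not a valid n_mics + n_pairs value")
    return n


def tdoas_to_taus(tdoas_frame):
    """Keep the (0, m) entries of one frame, i.e. the delays relative to channel 0."""
    return tdoas_frame[: n_mics_from_size(len(tdoas_frame))]


print(n_mics_from_size(10))  # 4
```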

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, tdoas2taus
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs = xs_speech + 0.05 * xs_noise
>>> xs = xs.unsqueeze(0)
>>> fs = 16000
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>>
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> tdoas = gccphat(XXs)
>>> taus = tdoas2taus(tdoas)
speechbrain.processing.multi_mic.steering(taus, n_fft)

This function computes a steering vector by using the time differences of arrival for each channel (in samples) and the number of bins (n_fft). The result has the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics).

Arguments:

taus : tensor

The time differences of arrival for each channel. The tensor must have the following format: (batch, time_steps, n_mics).

n_fft : int

The number of frequency bins produced by the STFT (for instance, Xs.shape[2]). It is assumed that the argument “onesided” was set to True for the STFT.
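For each channel m and bin k, the steering vector is a pure phase term. This sketch assumes a_m(k) = exp(-j * 2 * pi * k * tau_m / N) with frame size N = 2 * (n_fft - 1), where n_fft is the one-sided bin count as in the example (n_fft = Xs.shape[2]). It stores real and imaginary parts as a pair, in the spirit of the size-2 real/imaginary dimension. The phase sign is an assumption, and this plain-Python steering is a stand-in, not the library function.

```python
import math


def steering(taus, n_fft):
    """Steering vector per bin and channel as (real, imag) pairs.

    n_fft is the number of one-sided STFT bins, so the frame size is
    assumed to be 2 * (n_fft - 1).
    """
    frame_size = 2 * (n_fft - 1)
    result = []
    for k in range(n_fft):
        omega = 2 * math.pi * k / frame_size
        result.append([(math.cos(-omega * tau), math.sin(-omega * tau)) for tau in taus])
    return result  # indexed as [bin][channel] -> (re, im)


A = steering([0.0, 2.0], n_fft=5)  # frame size 8, channel 1 delayed by 2 samples
```

The reference channel (tau = 0) gets a phase of zero in every bin, while a delay of 2 samples in an 8-sample frame rotates bin k by -pi * k / 2.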

Example

>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, tdoas2taus, steering
>>>
>>> xs_speech = read_audio(
...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs = xs_speech + 0.05 * xs_noise
>>> xs = xs.unsqueeze(0) # [batch, time, channels]
>>> fs = 16000
>>>
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>>
>>> Xs = stft(xs)
>>> n_fft = Xs.shape[2]
>>> XXs = cov(Xs)
>>> tdoas = gccphat(XXs)
>>> taus = tdoas2taus(tdoas)
>>> As = steering(taus, n_fft)
speechbrain.processing.multi_mic.sphere(levels_count=4)

This function generates cartesian coordinates (xyz) for a set of points forming a 3D sphere. The coordinates are expressed in meters and can be used as doas. The result has the format: (n_points, 3).

Parameters:

levels_count (int) –

Controls how many points are generated. The sphere has 10 * 4^levels_count + 2 points, so each additional level roughly quadruples the count.

  • If levels_count = 1, then the sphere will have 42 points

  • If levels_count = 2, then the sphere will have 162 points

  • If levels_count = 3, then the sphere will have 642 points

  • If levels_count = 4, then the sphere will have 2562 points

  • If levels_count = 5, then the sphere will have 10242 points

By default, levels_count is set to 4.
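The listed counts follow 10 * 4^levels_count + 2, which is what repeatedly subdividing an icosahedron (20 triangles, 12 vertices at level 0) into four triangles per face produces. The helper below only reproduces the counts via Euler's formula V - E + F = 2; it does not generate coordinates, and the icosahedron interpretation is inferred from the numbers rather than stated by the library.

```python
def sphere_point_count(levels_count):
    """Number of points after `levels_count` subdivisions of an icosahedron."""
    faces = 20 * 4 ** levels_count  # each subdivision splits every triangle into 4
    edges = 3 * faces // 2          # every edge is shared by two triangles
    return edges - faces + 2        # Euler's formula for a closed surface: V = E - F + 2


expected = {1: 42, 2: 162, 3: 642, 4: 2562, 5: 10242}
```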

Example

>>> import torch
>>> from speechbrain.processing.multi_mic import sphere
>>> doas = sphere()