"""Multi-microphone components.
This library contains functions for multi-microphone signal processing.
Example
-------
>>> import torch
>>>
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT, ISTFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import GccPhat, SrpPhat, Music
>>> from speechbrain.processing.multi_mic import DelaySum, Mvdr, Gev
>>>
>>> xs_speech = read_audio(
... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
... )
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise_diff = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> xs_noise_diff = xs_noise_diff.unsqueeze(0)
>>> xs_noise_loc = read_audio('tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac')
>>> xs_noise_loc = xs_noise_loc.unsqueeze(0)
>>> fs = 16000 # sampling rate
>>> ss = xs_speech
>>> nn_diff = 0.05 * xs_noise_diff
>>> nn_loc = 0.05 * xs_noise_loc
>>> xs_diffused_noise = ss + nn_diff
>>> xs_localized_noise = ss + nn_loc
>>> # Delay-and-Sum Beamforming with GCC-PHAT localization
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> gccphat = GccPhat()
>>> delaysum = DelaySum()
>>> istft = ISTFT(sample_rate=fs)
>>> Xs = stft(xs_diffused_noise)
>>> Ns = stft(nn_diff)
>>> XXs = cov(Xs)
>>> NNs = cov(Ns)
>>> tdoas = gccphat(XXs)
>>> Ys_ds = delaysum(Xs, tdoas)
>>> ys_ds = istft(Ys_ds)
>>> # Mvdr Beamforming with SRP-PHAT localization
>>> mvdr = Mvdr()
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
>>> srpphat = SrpPhat(mics=mics)
>>> doas = srpphat(XXs)
>>> Ys_mvdr = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr = istft(Ys_mvdr)
>>> # Mvdr Beamforming with MUSIC localization
>>> music = Music(mics=mics)
>>> doas = music(XXs)
>>> Ys_mvdr2 = mvdr(Xs, NNs, doas, doa_mode=True, mics=mics, fs=fs)
>>> ys_mvdr2 = istft(Ys_mvdr2)
>>> # GeV Beamforming
>>> gev = Gev()
>>> Xs = stft(xs_localized_noise)
>>> Ss = stft(ss)
>>> Ns = stft(nn_loc)
>>> SSs = cov(Ss)
>>> NNs = cov(Ns)
>>> Ys_gev = gev(Xs, SSs, NNs)
>>> ys_gev = istft(Ys_gev)
Authors:
* William Aris
* Francois Grondin
"""
import torch
from packaging import version
import speechbrain.processing.decomposition as eig
class Covariance(torch.nn.Module):
    """Computes the covariance matrices of the signals.

    Arguments
    ---------
    average : bool
        Informs the module if it should return an average
        (computed on the time dimension) of the covariance
        matrices. The default value is True.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>>
    >>> xs_speech = read_audio(
    ... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs_noise = xs_noise.unsqueeze(0)
    >>> xs = xs_speech + 0.05 * xs_noise
    >>> fs = 16000
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>>
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> XXs.shape
    torch.Size([1, 1001, 201, 2, 10])
    """

    def __init__(self, average=True):
        super().__init__()
        self.average = average

    def forward(self, Xs):
        """Uses the utility function _cov to compute covariance matrices.

        The result has the following format:
        (batch, time_step, n_fft/2 + 1, 2, n_mics + n_pairs).

        The order on the last dimension corresponds to the triu_indices for a
        square matrix. For instance, if we have 4 channels, we get the
        following order: (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2),
        (1, 3), (2, 2), (2, 3) and (3, 3). Therefore, XXs[..., 0] corresponds
        to channels (0, 0) and XXs[..., 1] corresponds to channels (0, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        """
        XXs = Covariance._cov(Xs=Xs, average=self.average)
        return XXs

    @staticmethod
    def _cov(Xs, average=True):
        """Computes the covariance matrices (XXs) of the signals.

        The result will have the following format:
        (batch, time_step, n_fft/2 + 1, 2, n_mics + n_pairs).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        average : bool
            Informs the function if it should return an average
            (computed on the time dimension) of the covariance
            matrices. Default value is True.
        """
        # Get useful dimensions
        n_mics = Xs.shape[4]

        # Split into real and imaginary parts, shaped as column vectors:
        # (batch, time, freq, n_mics, 1)
        Xs_re = Xs[..., 0, :].unsqueeze(4)
        Xs_im = Xs[..., 1, :].unsqueeze(4)

        # Complex outer product Xs @ Xs^H expressed with real tensors
        Rxx_re = torch.matmul(Xs_re, Xs_re.transpose(3, 4)) + torch.matmul(
            Xs_im, Xs_im.transpose(3, 4)
        )
        Rxx_im = torch.matmul(Xs_re, Xs_im.transpose(3, 4)) - torch.matmul(
            Xs_im, Xs_re.transpose(3, 4)
        )

        # Keep only the upper triangular part (the matrix is Hermitian)
        idx = torch.triu_indices(n_mics, n_mics)
        XXs_re = Rxx_re[..., idx[0], idx[1]]
        XXs_im = Rxx_im[..., idx[0], idx[1]]
        XXs = torch.stack((XXs_re, XXs_im), 3)

        # Average over time if desired; the mean is repeated along time so
        # downstream code sees the same shape either way
        if average is True:
            n_time_frames = XXs.shape[1]
            XXs = torch.mean(XXs, 1, keepdim=True)
            XXs = XXs.repeat(1, n_time_frames, 1, 1, 1)

        return XXs
class DelaySum(torch.nn.Module):
    """Performs delay and sum beamforming by using the TDOAs and
    the first channel as a reference.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
    >>>
    >>> xs_speech = read_audio(
    ... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs_noise = xs_noise.unsqueeze(0) #[batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> delaysum = DelaySum()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    >>> Ys = delaysum(Xs, tdoas)
    >>> ys = istft(Ys)
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        Xs,
        localization_tensor,
        doa_mode=False,
        mics=None,
        fs=None,
        c=343.0,
    ):
        """Computes a steering vector by using the TDOAs/DOAs and
        then calls the utility function _delaysum to perform beamforming.
        The result has the following format: (batch, time_step, n_fft, 2, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        localization_tensor : tensor
            A tensor containing either time differences of arrival (TDOAs)
            (in samples) for each timestamp or directions of arrival (DOAs)
            (xyz coordinates in meters). If localization_tensor represents
            TDOAs, then its format is (batch, time_steps, n_mics + n_pairs).
            If localization_tensor represents DOAs, then its format is
            (batch, time_steps, 3).
        doa_mode : bool
            Set this parameter to True if localization_tensor represents
            DOAs instead of TDOAs. Its default value is False.
        mics : tensor
            The cartesian position (xyz coordinates in meters) of each
            microphone. The tensor must have the format (n_mics, 3). This
            parameter is only mandatory when localization_tensor represents
            DOAs.
        fs : int
            The sample rate in Hertz of the signals. This parameter is only
            mandatory when localization_tensor represents DOAs.
        c : float
            The speed of sound in the medium, in meters per second (default
            343 m/s). Only used when localization_tensor represents DOAs.
        """
        # Get useful dimensions
        n_fft = Xs.shape[2]
        localization_tensor = localization_tensor.to(Xs.device)

        # Convert the localization tensor into delays (in samples)
        if doa_mode:
            taus = doas2taus(doas=localization_tensor, mics=mics, fs=fs, c=c)
        else:
            taus = tdoas2taus(tdoas=localization_tensor)

        # Generate the steering vector
        As = steering(taus=taus, n_fft=n_fft)

        # Apply delay and sum
        Ys = DelaySum._delaysum(Xs=Xs, As=As)

        return Ys

    @staticmethod
    def _delaysum(Xs, As):
        """Performs delay and sum beamforming. The result has
        the following format: (batch, time_step, n_fft, 2, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        As : tensor
            The steering vector to point in the direction of
            the target source. The tensor must have the format
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        """
        # Get useful dimensions
        n_mics = Xs.shape[4]

        # Unmixing weights: conjugate of the steering vector scaled by 1/M
        Ws_re = As[..., 0, :] / n_mics
        Ws_im = -1 * As[..., 1, :] / n_mics

        # Get input signal
        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]

        # Complex multiply-accumulate over the microphone dimension
        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        # Assembling the result
        Ys = torch.stack((Ys_re, Ys_im), 3)

        return Ys
class Mvdr(torch.nn.Module):
    """Performs minimum variance distortionless response (MVDR) beamforming
    by using an input signal in the frequency domain, its covariance matrices
    and tdoas (to compute a steering vector).

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
    >>>
    >>> xs_speech = read_audio(
    ... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs_noise = xs_noise.unsqueeze(0) #[batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> mvdr = Mvdr()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Xs = stft(xs)
    >>> Ns = stft(xs_noise)
    >>> XXs = cov(Xs)
    >>> NNs = cov(Ns)
    >>> tdoas = gccphat(XXs)
    >>> Ys = mvdr(Xs, NNs, tdoas)
    >>> ys = istft(Ys)
    """

    def __init__(self, eps=1e-20):
        super().__init__()

        # NOTE(review): eps is stored but never forwarded to _mvdr (which
        # has its own eps default that is itself unused) — confirm whether
        # it was meant to regularize the denominator of alpha.
        self.eps = eps

    def forward(
        self,
        Xs,
        NNs,
        localization_tensor,
        doa_mode=False,
        mics=None,
        fs=None,
        c=343.0,
    ):
        """Computes a steering vector before using the utility function
        _mvdr to perform beamforming. The result has the following
        format: (batch, time_step, n_fft, 2, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics)
        NNs : tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs)
        localization_tensor : tensor
            A tensor containing either time differences of arrival (TDOAs)
            (in samples) for each timestamp or directions of arrival (DOAs)
            (xyz coordinates in meters). If localization_tensor represents
            TDOAs, then its format is (batch, time_steps, n_mics + n_pairs).
            If localization_tensor represents DOAs, then its format is
            (batch, time_steps, 3).
        doa_mode : bool
            Set this parameter to True if localization_tensor represents
            DOAs instead of TDOAs. Its default value is False.
        mics : tensor
            The cartesian position (xyz coordinates in meters) of each
            microphone. The tensor must have the format (n_mics, 3). This
            parameter is only mandatory when localization_tensor represents
            DOAs.
        fs : int
            The sample rate in Hertz of the signals. This parameter is only
            mandatory when localization_tensor represents DOAs.
        c : float
            The speed of sound in the medium, in meters per second (default
            343 m/s). Only used when localization_tensor represents DOAs.
        """
        # Get useful dimensions
        n_fft = Xs.shape[2]

        # Move the inputs onto the same device as the signal
        localization_tensor = localization_tensor.to(Xs.device)
        NNs = NNs.to(Xs.device)

        if mics is not None:
            mics = mics.to(Xs.device)

        # Convert the localization tensor into delays (in samples)
        if doa_mode:
            taus = doas2taus(doas=localization_tensor, mics=mics, fs=fs, c=c)
        else:
            taus = tdoas2taus(tdoas=localization_tensor)

        # Generate the steering vector
        As = steering(taus=taus, n_fft=n_fft)

        # Perform mvdr
        Ys = Mvdr._mvdr(Xs=Xs, NNs=NNs, As=As)

        return Ys

    @staticmethod
    def _mvdr(Xs, NNs, As, eps=1e-20):
        """Performs minimum variance distortionless response beamforming.

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        NNs : tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        As : tensor
            The steering vector to point in the direction of
            the target source. The tensor must have the format
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        eps : float
            Currently unused; kept for backward compatibility.
        """
        # Get unique covariance values to reduce the number of computations
        NNs_val, NNs_idx = torch.unique(NNs, return_inverse=True, dim=1)

        # Inverse covariance matrices
        NNs_inv = eig.inv(NNs_val)

        # Capture real and imaginary parts, and restore time steps
        NNs_inv_re = NNs_inv[..., 0][:, NNs_idx]
        NNs_inv_im = NNs_inv[..., 1][:, NNs_idx]

        # Decompose steering vector into a column vector (AsC) and its
        # Hermitian transpose (AsT)
        AsC_re = As[..., 0, :].unsqueeze(4)
        AsC_im = 1.0 * As[..., 1, :].unsqueeze(4)
        AsT_re = AsC_re.transpose(3, 4)
        AsT_im = -1.0 * AsC_im.transpose(3, 4)

        # Project: NNs^-1 @ As
        NNs_inv_AsC_re = torch.matmul(NNs_inv_re, AsC_re) - torch.matmul(
            NNs_inv_im, AsC_im
        )
        NNs_inv_AsC_im = torch.matmul(NNs_inv_re, AsC_im) + torch.matmul(
            NNs_inv_im, AsC_re
        )

        # Compute the gain: alpha = 1 / (As^H @ NNs^-1 @ As)
        alpha = 1.0 / (
            torch.matmul(AsT_re, NNs_inv_AsC_re)
            - torch.matmul(AsT_im, NNs_inv_AsC_im)
        )

        # Get the unmixing coefficients (conjugated for the projection below)
        Ws_re = torch.matmul(NNs_inv_AsC_re, alpha).squeeze(4)
        Ws_im = -torch.matmul(NNs_inv_AsC_im, alpha).squeeze(4)

        # Applying MVDR
        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]

        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        Ys = torch.stack((Ys_re, Ys_im), -2)

        return Ys
class Gev(torch.nn.Module):
    """Generalized EigenValue decomposition (GEV) Beamforming.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> import torch
    >>>
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import Gev
    >>>
    >>> xs_speech = read_audio(
    ... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_0.70225_-0.70225_0.11704.flac')
    >>> xs_noise = xs_noise.unsqueeze(0)
    >>> fs = 16000
    >>> ss = xs_speech
    >>> nn = 0.05 * xs_noise
    >>> xs = ss + nn
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gev = Gev()
    >>> istft = ISTFT(sample_rate=fs)
    >>>
    >>> Ss = stft(ss)
    >>> Nn = stft(nn)
    >>> Xs = stft(xs)
    >>>
    >>> SSs = cov(Ss)
    >>> NNs = cov(Nn)
    >>>
    >>> Ys = gev(Xs, SSs, NNs)
    >>> ys = istft(Ys)
    """

    def __init__(self):
        super().__init__()

    def forward(self, Xs, SSs, NNs):
        """Uses the utility function _gev to perform generalized
        eigenvalue decomposition beamforming. Therefore, the result has
        the following format: (batch, time_step, n_fft, 2, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        SSs : tensor
            The covariance matrices of the target signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        NNs : tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        """
        Ys = Gev._gev(Xs=Xs, SSs=SSs, NNs=NNs)
        return Ys

    @staticmethod
    def _gev(Xs, SSs, NNs):
        """Performs generalized eigenvalue decomposition beamforming. The
        result has the following format: (batch, time_step, n_fft, 2, 1).

        Arguments
        ---------
        Xs : tensor
            A batch of audio signals in the frequency domain.
            The tensor must have the following format:
            (batch, time_step, n_fft/2 + 1, 2, n_mics).
        SSs : tensor
            The covariance matrices of the target signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        NNs : tensor
            The covariance matrices of the noise signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        """
        # Putting on the right device
        SSs = SSs.to(Xs.device)
        NNs = NNs.to(Xs.device)

        # Get useful dimensions
        n_mics = Xs.shape[4]
        n_mics_pairs = SSs.shape[4]

        # Deduplicate identical time steps before the (expensive)
        # eigendecomposition; SSs_NNs_idx restores the full time axis later
        SSs_NNs = torch.cat((SSs, NNs), dim=4)
        SSs_NNs_val, SSs_NNs_idx = torch.unique(
            SSs_NNs, return_inverse=True, dim=1
        )

        SSs = SSs_NNs_val[..., range(0, n_mics_pairs)]
        NNs = SSs_NNs_val[..., range(n_mics_pairs, 2 * n_mics_pairs)]

        # Regularize the noise covariance, then solve the generalized
        # eigenvalue problem (eigenvalues Ds are not needed afterwards)
        NNs = eig.pos_def(NNs)
        Vs, Ds = eig.gevd(SSs, NNs)

        # Beamforming weights: principal eigenvector (last column)
        F_re = Vs[..., (n_mics - 1), 0]
        F_im = Vs[..., (n_mics - 1), 1]

        # Normalize to unit norm
        F_norm = 1.0 / (
            torch.sum(F_re ** 2 + F_im ** 2, dim=3, keepdim=True) ** 0.5
        ).repeat(1, 1, 1, n_mics)
        F_re *= F_norm
        F_im *= F_norm

        # Restore the full time axis
        Ws_re = F_re[:, SSs_NNs_idx]
        Ws_im = F_im[:, SSs_NNs_idx]

        Xs_re = Xs[..., 0, :]
        Xs_im = Xs[..., 1, :]

        Ys_re = torch.sum((Ws_re * Xs_re - Ws_im * Xs_im), dim=3, keepdim=True)
        Ys_im = torch.sum((Ws_re * Xs_im + Ws_im * Xs_re), dim=3, keepdim=True)

        # Assembling the output
        Ys = torch.stack((Ys_re, Ys_im), 3)

        return Ys
class GccPhat(torch.nn.Module):
    """Generalized Cross-Correlation with Phase Transform localization.

    Arguments
    ---------
    tdoa_max : int
        Specifies a range to search for delays. For example, if
        tdoa_max = 10, the method will restrict its search for delays
        between -10 and 10 samples. This parameter is optional and its
        default value is None. When tdoa_max is None, the method will
        search for delays between -n_fft/2 and n_fft/2 (full range).
    eps : float
        A small value to avoid divisions by 0 with the phase transformation.
        The default value is 1e-20.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT, ISTFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, DelaySum
    >>>
    >>> xs_speech = read_audio(
    ... 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channel]
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs_noise = xs_noise.unsqueeze(0) #[batch, time, channels]
    >>> fs = 16000
    >>> xs = xs_speech + 0.05 * xs_noise
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    """

    def __init__(self, tdoa_max=None, eps=1e-20):
        super().__init__()
        self.tdoa_max = tdoa_max
        self.eps = eps

    def forward(self, XXs):
        """Performs generalized cross-correlation with phase transform
        localization by using the utility function _gcc_phat and by
        extracting the delays (in samples) before performing a quadratic
        interpolation to improve the accuracy.
        The result has the format: (batch, time_steps, n_mics + n_pairs).

        The order on the last dimension corresponds to the triu_indices for
        a square matrix. For instance, if we have 4 channels, we get the
        following order: (0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2),
        (1, 3), (2, 2), (2, 3) and (3, 3). Therefore, delays[..., 0]
        corresponds to channels (0, 0) and delays[..., 1] corresponds to
        channels (0, 1).

        Arguments
        ---------
        XXs : tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        """
        xxs = GccPhat._gcc_phat(XXs=XXs, eps=self.eps)
        delays = GccPhat._extract_delays(xxs=xxs, tdoa_max=self.tdoa_max)
        tdoas = GccPhat._interpolate(xxs=xxs, delays=delays)
        return tdoas

    @staticmethod
    def _gcc_phat(XXs, eps=1e-20):
        """Evaluates GCC-PHAT for each timestamp. It returns the result in
        the time domain. The result has the format:
        (batch, time_steps, n_fft, n_mics + n_pairs).

        Arguments
        ---------
        XXs : tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        eps : float
            A small value to avoid divisions by 0 with the phase transform.
            The default value is 1e-20.
        """
        # Get useful dimensions
        n_samples = (XXs.shape[2] - 1) * 2

        # Deduplicate identical channel pairs to reduce computations
        XXs_val, XXs_idx = torch.unique(XXs, return_inverse=True, dim=4)
        XXs_re = XXs_val[..., 0, :]
        XXs_im = XXs_val[..., 1, :]

        # Applying the phase transform (whiten the magnitude, keep the phase)
        XXs_abs = torch.sqrt(XXs_re ** 2 + XXs_im ** 2) + eps
        XXs_re_phat = XXs_re / XXs_abs
        XXs_im_phat = XXs_im / XXs_abs
        XXs_phat = torch.stack((XXs_re_phat, XXs_im_phat), 4)

        # Returning in the temporal domain. Feature-detect the modern
        # torch.fft module instead of parsing version strings: on old
        # torch, `torch.fft` is a function without an `irfft` attribute,
        # so the deprecated torch.irfft branch is taken.
        XXs_phat = XXs_phat.transpose(2, 3)

        if hasattr(torch, "fft") and hasattr(torch.fft, "irfft"):
            XXs_phat = torch.complex(XXs_phat[..., 0], XXs_phat[..., 1])
            xxs = torch.fft.irfft(XXs_phat, n=n_samples)
        else:
            xxs = torch.irfft(XXs_phat, signal_ndim=1, signal_sizes=[n_samples])

        # Restore the deduplicated pairs
        xxs = xxs[..., XXs_idx, :]

        # Formatting the output
        xxs = xxs.transpose(2, 3)

        return xxs

    @staticmethod
    def _extract_delays(xxs, tdoa_max=None):
        """Extracts the rounded delays from the cross-correlation for each
        timestamp. The result has the format:
        (batch, time_steps, n_mics + n_pairs).

        Arguments
        ---------
        xxs : tensor
            The correlation signals obtained after a gcc-phat operation.
            The tensor must have the format
            (batch, time_steps, n_fft, n_mics + n_pairs).
        tdoa_max : int
            Specifies a range to search for delays. For example, if
            tdoa_max = 10, the method will restrict its search for delays
            between -10 and 10 samples. This parameter is optional and its
            default value is None. When tdoa_max is None, the method will
            search for delays between -n_fft/2 and +n_fft/2 (full range).
        """
        # Get useful dimensions
        n_fft = xxs.shape[2]

        # If no tdoa specified, cover the whole frame
        if tdoa_max is None:
            tdoa_max = torch.div(n_fft, 2, rounding_mode="floor")

        # Keep only the searched lags: positive lags are at the start of
        # the frame, negative lags wrap around to the end
        slice_1 = xxs[..., 0:tdoa_max, :]
        slice_2 = xxs[..., -tdoa_max:, :]
        xxs_sliced = torch.cat((slice_1, slice_2), 2)

        # Extracting the delays in the range
        _, delays = torch.max(xxs_sliced, 2)

        # Adjusting the delays that were affected by the slicing
        offset = n_fft - xxs_sliced.shape[2]
        idx = delays >= slice_1.shape[2]
        delays[idx] += offset

        # Centering the delays around 0 (wrapped lags become negative)
        delays[idx] -= n_fft

        return delays

    @staticmethod
    def _interpolate(xxs, delays):
        """Performs quadratic interpolation on the cross-correlation to
        improve the tdoa accuracy. The result has the format:
        (batch, time_steps, n_mics + n_pairs).

        Arguments
        ---------
        xxs : tensor
            The correlation signals obtained after a gcc-phat operation.
            The tensor must have the format
            (batch, time_steps, n_fft, n_mics + n_pairs).
        delays : tensor
            The rounded tdoas obtained by selecting the sample with the
            highest amplitude. The tensor must have the format
            (batch, time_steps, n_mics + n_pairs).
        """
        # Get useful dimensions
        n_fft = xxs.shape[2]

        # Get the max amplitude and its neighbours (modulo n_fft to handle
        # negative delays stored at the end of the frame)
        tp = torch.fmod((delays - 1) + n_fft, n_fft).unsqueeze(2)
        y1 = torch.gather(xxs, 2, tp).squeeze(2)
        tp = torch.fmod(delays + n_fft, n_fft).unsqueeze(2)
        y2 = torch.gather(xxs, 2, tp).squeeze(2)
        tp = torch.fmod((delays + 1) + n_fft, n_fft).unsqueeze(2)
        y3 = torch.gather(xxs, 2, tp).squeeze(2)

        # Vertex of the parabola through (y1, y2, y3) gives the fractional
        # part added to the initially rounded delay
        delays_frac = delays + (y1 - y3) / (2 * y1 - 4 * y2 + 2 * y3)

        return delays_frac
class SrpPhat(torch.nn.Module):
    """Steered-Response Power with Phase Transform Localization.

    Arguments
    ---------
    mics : tensor
        The cartesian coordinates (xyz) in meters of each microphone.
        The tensor must have the following format (n_mics, 3).
    space : string
        If this parameter is set to 'sphere', the localization will
        be done in 3D by searching in a sphere of possible doas. If
        it is set to 'circle', the search will be done in 2D by searching
        in a circle. By default, this parameter is set to 'sphere'.
        Note: The 'circle' option isn't implemented yet.
    sample_rate : int
        The sample rate in Hertz of the signals to perform SRP-PHAT on.
        By default, this parameter is set to 16000 Hz.
    speed_sound : float
        The speed of sound in the medium. The speed is expressed in meters
        per second and the default value of this parameter is 343 m/s.
    eps : float
        A small value to avoid errors like division by 0. The default value
        of this parameter is 1e-20.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import SrpPhat
    >>> xs_speech = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> fs = 16000
    >>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
    >>> xs_noise = xs_noise.unsqueeze(0)
    >>> ss1 = xs_speech
    >>> ns1 = 0.05 * xs_noise
    >>> xs1 = ss1 + ns1
    >>> ss2 = xs_speech
    >>> ns2 = 0.20 * xs_noise
    >>> xs2 = ss2 + ns2
    >>> ss = torch.cat((ss1,ss2), dim=0)
    >>> ns = torch.cat((ns1,ns2), dim=0)
    >>> xs = torch.cat((xs1,xs2), dim=0)
    >>> mics = torch.zeros((4,3), dtype=torch.float)
    >>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
    >>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
    >>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
    >>> mics[3,:] = torch.FloatTensor([+0.05, -0.05, +0.00])
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> srpphat = SrpPhat(mics=mics)
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> doas = srpphat(XXs)
    """

    def __init__(
        self,
        mics,
        space="sphere",
        sample_rate=16000,
        speed_sound=343.0,
        eps=1e-20,
    ):
        super().__init__()

        # Generate the doas
        # NOTE(review): 'circle' is accepted but unimplemented; in that
        # case self.doas is never set and doas2taus below will fail.
        if space == "sphere":
            self.doas = sphere()

        if space == "circle":
            pass

        # Generate associated taus with the doas
        self.taus = doas2taus(
            self.doas, mics=mics, fs=sample_rate, c=speed_sound
        )

        # Save epsilon
        self.eps = eps

    def forward(self, XXs):
        """Performs SRP-PHAT localization on a signal by computing a steering
        vector and then by using the utility function _srp_phat to extract
        the doas. The result is a tensor containing the directions of arrival
        (xyz coordinates (in meters) in the direction of the sound source).
        The output tensor has the format (batch, time_steps, 3).

        This localization method uses Global Coherence Field (GCF):
        https://www.researchgate.net/publication/221491705_Speaker_localization_based_on_oriented_global_coherence_field

        Arguments
        ---------
        XXs : tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        """
        # Get useful dimensions
        n_fft = XXs.shape[2]

        # Generate the steering vector
        As = steering(self.taus.to(XXs.device), n_fft)

        # Perform srp-phat
        doas = SrpPhat._srp_phat(XXs=XXs, As=As, doas=self.doas, eps=self.eps)

        return doas

    @staticmethod
    def _srp_phat(XXs, As, doas, eps=1e-20):
        """Performs srp-phat to find the direction of arrival
        of the sound source. The result is a tensor containing the directions
        of arrival (xyz coordinates (in meters) in the direction of the sound
        source). The output tensor has the format: (batch, time_steps, 3).

        Arguments
        ---------
        XXs : tensor
            The covariance matrices of the input signal. The tensor must
            have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
        As : tensor
            The steering vectors that cover all the potential directions
            of arrival. The tensor must have the format
            (n_doas, n_fft/2 + 1, 2, n_mics).
        doas : tensor
            All the possible directions of arrival that will be scanned. The
            tensor must have the format (n_doas, 3).
        eps : float
            A small value to avoid divisions by 0 in the phase transform.
        """
        # Putting on the right device
        As = As.to(XXs.device)
        doas = doas.to(XXs.device)

        # Get useful dimensions
        n_mics = As.shape[3]

        # Get the indices for the pairs of microphones
        idx = torch.triu_indices(n_mics, n_mics)

        # Generate the demixing vector from the steering vector: one
        # complex weight per microphone pair and frequency bin
        As_1_re = As[:, :, 0, idx[0, :]]
        As_1_im = As[:, :, 1, idx[0, :]]
        As_2_re = As[:, :, 0, idx[1, :]]
        As_2_im = As[:, :, 1, idx[1, :]]
        Ws_re = As_1_re * As_2_re + As_1_im * As_2_im
        Ws_im = As_1_re * As_2_im - As_1_im * As_2_re
        Ws_re = Ws_re.reshape(Ws_re.shape[0], -1)
        Ws_im = Ws_im.reshape(Ws_im.shape[0], -1)

        # Get unique covariance values to reduce the number of computations
        XXs_val, XXs_idx = torch.unique(XXs, return_inverse=True, dim=1)

        # Perform the phase transform
        XXs_re = XXs_val[:, :, :, 0, :]
        XXs_im = XXs_val[:, :, :, 1, :]
        XXs_re = XXs_re.reshape((XXs_re.shape[0], XXs_re.shape[1], -1))
        XXs_im = XXs_im.reshape((XXs_im.shape[0], XXs_im.shape[1], -1))
        XXs_abs = torch.sqrt(XXs_re ** 2 + XXs_im ** 2) + eps
        XXs_re_norm = XXs_re / XXs_abs
        XXs_im_norm = XXs_im / XXs_abs

        # Project on the demixing vectors, and keep only the real part
        Ys_A = torch.matmul(XXs_re_norm, Ws_re.transpose(0, 1))
        Ys_B = torch.matmul(XXs_im_norm, Ws_im.transpose(0, 1))
        Ys = Ys_A - Ys_B

        # Get maximum points
        _, doas_idx = torch.max(Ys, dim=2)

        # Repeat for each frame
        doas = (doas[doas_idx, :])[:, XXs_idx, :]

        return doas
[docs]
class Music(torch.nn.Module):
"""Multiple Signal Classification (MUSIC) localization.
Arguments
---------
mics : tensor
The cartesian coordinates (xyz) in meters of each microphone.
The tensor must have the following format (n_mics, 3).
space : string
If this parameter is set to 'sphere', the localization will
be done in 3D by searching in a sphere of possible doas. If
it set to 'circle', the search will be done in 2D by searching
in a circle. By default, this parameter is set to 'sphere'.
Note: The 'circle' option isn't implemented yet.
sample_rate : int
The sample rate in Hertz of the signals to perform SRP-PHAT on.
By default, this parameter is set to 16000 Hz.
speed_sound : float
The speed of sound in the medium. The speed is expressed in meters
per second and the default value of this parameter is 343 m/s.
eps : float
A small value to avoid errors like division by 0. The default value
of this parameter is 1e-20.
n_sig : int
An estimation of the number of sound sources. The default value is set
to one source.
Example
-------
>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import STFT
>>> from speechbrain.processing.multi_mic import Covariance
>>> from speechbrain.processing.multi_mic import SrpPhat
>>> xs_speech = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
>>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
>>> fs = 16000
>>> xs_speech = xs_speech.unsqueeze(0) # [batch, time, channels]
>>> xs_noise = xs_noise.unsqueeze(0)
>>> ss1 = xs_speech
>>> ns1 = 0.05 * xs_noise
>>> xs1 = ss1 + ns1
>>> ss2 = xs_speech
>>> ns2 = 0.20 * xs_noise
>>> xs2 = ss2 + ns2
>>> ss = torch.cat((ss1,ss2), dim=0)
>>> ns = torch.cat((ns1,ns2), dim=0)
>>> xs = torch.cat((xs1,xs2), dim=0)
>>> mics = torch.zeros((4,3), dtype=torch.float)
>>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
>>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
>>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> mics[3,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
>>> stft = STFT(sample_rate=fs)
>>> cov = Covariance()
>>> music = Music(mics=mics)
>>> Xs = stft(xs)
>>> XXs = cov(Xs)
>>> doas = music(XXs)
"""
def __init__(
    self,
    mics,
    space="sphere",
    sample_rate=16000,
    speed_sound=343.0,
    eps=1e-20,
    n_sig=1,
):
    """Initialize the MUSIC localizer (arguments documented on the class).

    Raises
    ------
    NotImplementedError
        If ``space`` is set to 'circle' (not implemented yet).
    ValueError
        If ``space`` is neither 'sphere' nor 'circle'.
    """
    super().__init__()

    # Generate the grid of candidate directions of arrival.
    # The original code silently did nothing for 'circle' (and for any
    # unknown value), which made the doas2taus call below fail later
    # with a confusing AttributeError; fail fast instead.
    if space == "sphere":
        self.doas = sphere()
    elif space == "circle":
        raise NotImplementedError(
            "The 'circle' search space is not implemented yet."
        )
    else:
        raise ValueError(
            "space must be 'sphere' or 'circle', got %r" % space
        )

    # Generate the delays (in samples) associated with each candidate doa.
    self.taus = doas2taus(
        self.doas, mics=mics, fs=sample_rate, c=speed_sound
    )

    # Small constant guarding against division by zero in _music.
    self.eps = eps

    # Estimated number of sources (dimension of the signal subspace).
    self.n_sig = n_sig
[docs]
def forward(self, XXs):
    """Perform MUSIC localization on a signal by computing a steering
    vector and then by using the utility function _music to extract the
    doas. The result is a tensor containing the directions of arrival
    (xyz coordinates, in meters, pointing toward the sound source), with
    the format (batch, time_steps, 3).

    Arguments
    ---------
    XXs : tensor
        The covariance matrices of the input signal. The tensor must
        have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
    """
    # Number of frequency bins (n_fft/2 + 1) carried by the covariance input.
    bins_count = XXs.shape[2]

    # Steering vectors for every candidate direction, on the input's device.
    steering_vectors = steering(self.taus.to(XXs.device), bins_count)

    # Scan the MUSIC pseudo-spectrum and return the best doa per time step.
    return Music._music(
        XXs=XXs,
        As=steering_vectors,
        doas=self.doas,
        n_sig=self.n_sig,
        eps=self.eps,
    )
@staticmethod
def _music(XXs, As, doas, n_sig, eps=1e-20):
    """Perform multiple signal classification to find the
    direction of arrival of the sound source. The result
    has the format: (batch, time_steps, 3).

    Arguments
    ---------
    XXs : tensor
        The covariance matrices of the input signal. The tensor must
        have the format (batch, time_steps, n_fft/2 + 1, 2, n_mics + n_pairs).
    As : tensor
        The steering vector that covers the all the potential directions
        of arrival. The tensor must have the format.
        (n_doas, n_fft/2 + 1, 2, n_mics).
    doas : tensor
        All the possible directions of arrival that will be scanned. The
        tensor must have the format (n_doas, 3).
    n_sig : int
        The number of signals in the signal + noise subspace (default is 1).
    eps : float
        Small constant added to the denominator of the MUSIC spectrum
        to avoid division by zero.
    """
    # Putting on the right device
    As = As.to(XXs.device)
    doas = doas.to(XXs.device)

    # Collecting data
    n_mics = As.shape[3]
    n_doas = As.shape[0]
    n_bins = As.shape[2]

    # Keep the first (n_mics - n_sig) singular vectors: with eig.svdl's
    # ordering these span the noise subspace used by MUSIC.
    svd_range = n_mics - n_sig

    # Get unique values to reduce computations: decompose only the
    # distinct covariance matrices along the time axis, and remember
    # (XXs_idx) how to scatter the results back to every time step.
    XXs_val, XXs_idx = torch.unique(XXs, return_inverse=True, dim=1)

    # Singular value decomposition
    Us, _ = eig.svdl(XXs_val)

    # Format for the projection: replicate the singular vectors across
    # the n_doas candidate directions.
    Us = Us.unsqueeze(2).repeat(1, 1, n_doas, 1, 1, 1, 1)
    # Split real/imaginary parts of the retained (noise-subspace) vectors.
    Us_re = Us[..., range(0, svd_range), 0]
    Us_im = Us[..., range(0, svd_range), 1]

    # Fixing the format of the steering vector so it can be matmul'ed
    # against the singular vectors (broadcast over batch and time).
    As = (
        As.unsqueeze(0)
        .unsqueeze(0)
        .unsqueeze(6)
        .permute(0, 1, 2, 3, 6, 5, 4)
    )
    As = As.repeat(Us.shape[0], Us.shape[1], 1, 1, 1, 1, 1)
    As_re = As[..., 0]
    As_im = As[..., 1]

    # Applying MUSIC's formula: complex projection a^H * U, done with
    # separate real/imaginary matmuls.
    As_mm_Us_re = torch.matmul(As_re, Us_re) + torch.matmul(As_im, Us_im)
    As_mm_Us_im = torch.matmul(As_re, Us_im) - torch.matmul(As_im, Us_re)
    As_mm_Us_abs = torch.sqrt(As_mm_Us_re ** 2 + As_mm_Us_im ** 2)
    As_mm_Us_sum = torch.sum(As_mm_Us_abs, dim=5)
    # Norm of the steering vector (numerator of the pseudo-spectrum).
    As_As_abs = torch.sum(As_re ** 2, dim=5) + torch.sum(As_im ** 2, dim=5)
    # Pseudo-spectrum: large where the steering vector is orthogonal to
    # the noise subspace (i.e. at the true doa); eps guards the division.
    Ps = (As_As_abs / (As_mm_Us_sum + eps)).squeeze(4)
    # Average over frequency bins.
    Ys = torch.sum(Ps, dim=3) / n_bins

    # Get maximum points, then scatter the per-unique-frame winners back
    # to every original time step via XXs_idx.
    _, doas_idx = torch.max(Ys, dim=2)
    doas = (doas[doas_idx, :])[:, XXs_idx, :]

    return doas
[docs]
def doas2taus(doas, mics, fs, c=343.0):
    """Convert directions of arrival (xyz coordinates expressed in meters)
    into time differences of arrival (expressed in samples). The result
    has the following format: (batch, time_steps, n_mics).

    Arguments
    ---------
    doas : tensor
        The directions of arrival expressed with cartesian coordinates (xyz)
        in meters. The tensor must have the following format: (batch, time_steps, 3).
    mics : tensor
        The cartesian position (xyz) in meters of each microphone.
        The tensor must have the following format (n_mics, 3).
    fs : int
        The sample rate in Hertz of the signals.
    c : float
        The speed of sound in the medium. The speed is expressed in meters
        per second and the default value of this parameter is 343 m/s.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.multi_mic import sphere, doas2taus
    >>> xs = read_audio('tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac')
    >>> xs = xs.unsqueeze(0) # [batch, time, channels]
    >>> fs = 16000
    >>> mics = torch.zeros((4,3), dtype=torch.float)
    >>> mics[0,:] = torch.FloatTensor([-0.05, -0.05, +0.00])
    >>> mics[1,:] = torch.FloatTensor([-0.05, +0.05, +0.00])
    >>> mics[2,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
    >>> mics[3,:] = torch.FloatTensor([+0.05, +0.05, +0.00])
    >>> doas = sphere()
    >>> taus = doas2taus(doas, mics, fs)
    """
    # Projecting each doa (a unit direction) onto the microphone positions
    # gives the path-length difference in meters; fs / c converts meters
    # into samples.
    samples_per_meter = fs / c
    projections = torch.matmul(doas.to(mics.device), mics.transpose(0, 1))
    return samples_per_meter * projections
[docs]
def tdoas2taus(tdoas):
    """Select the tdoas of each channel and put them in a tensor.
    The result has the following format: (batch, time_steps, n_mics).

    Arguments
    ---------
    tdoas : tensor
        The time difference of arrival (TDOA) (in samples) for
        each timestamp. The tensor has the format
        (batch, time_steps, n_mics + n_pairs).

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, tdoas2taus
    >>>
    >>> xs_speech = read_audio(
    ...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs = xs_speech + 0.05 * xs_noise
    >>> xs = xs.unsqueeze(0)
    >>> fs = 16000
    >>>
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>>
    >>> Xs = stft(xs)
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    >>> taus = tdoas2taus(tdoas)
    """
    # The last dimension holds n_mics + n_pairs entries, which equals
    # n_mics * (n_mics + 1) / 2; invert that relation to recover the
    # channel count, then keep only the first n_mics (per-channel) tdoas.
    n_total = tdoas.shape[-1]
    n_channels = int(((1 + 8 * n_total) ** 0.5 - 1) / 2)
    return tdoas[..., range(0, n_channels)]
[docs]
def steering(taus, n_fft):
    """Compute a steering vector from the time differences of arrival
    of each channel (in samples) and the number of bins (n_fft).
    The result has the following format: (batch, time_step, n_fft/2 + 1, 2, n_mics).

    Arguments
    ---------
    taus : tensor
        The time differences of arrival for each channel. The tensor must have
        the following format: (batch, time_steps, n_mics).
    n_fft : int
        The number of bins resulting of the STFT. It is assumed that the
        argument "onesided" was set to True for the STFT.

    Example
    -------
    >>> import torch
    >>> from speechbrain.dataio.dataio import read_audio
    >>> from speechbrain.processing.features import STFT
    >>> from speechbrain.processing.multi_mic import Covariance
    >>> from speechbrain.processing.multi_mic import GccPhat, tdoas2taus, steering
    >>>
    >>> xs_speech = read_audio(
    ...    'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    ... )
    >>> xs_noise = read_audio('tests/samples/multi-mic/noise_diffuse.flac')
    >>> xs = xs_speech + 0.05 * xs_noise
    >>> xs = xs.unsqueeze(0) # [batch, time, channels]
    >>> fs = 16000
    >>> stft = STFT(sample_rate=fs)
    >>> cov = Covariance()
    >>> gccphat = GccPhat()
    >>>
    >>> Xs = stft(xs)
    >>> n_fft = Xs.shape[2]
    >>> XXs = cov(Xs)
    >>> tdoas = gccphat(XXs)
    >>> taus = tdoas2taus(tdoas)
    >>> As = steering(taus, n_fft)
    """
    # Angular frequency of each STFT bin, assuming a one-sided spectrum
    # computed on frames of (n_fft - 1) * 2 samples.
    pi = 3.141592653589793
    frame_size = int((n_fft - 1) * 2)
    omegas = 2 * pi * torch.arange(0, n_fft, device=taus.device) / frame_size

    # Broadcast delays against bins: phases has shape taus.shape + (n_fft,),
    # with phases[..., m, f] = -omega[f] * tau[..., m].
    phases = -1.0 * omegas * taus.unsqueeze(-1)

    # Complex exponential exp(j * phase), stored as a trailing (re, im) pair.
    a = torch.stack((torch.cos(phases), torch.sin(phases)), dim=-1)

    # Reorder (..., n_mics, n_fft, 2) into (..., n_fft, 2, n_mics).
    return a.transpose(-3, -1).transpose(-3, -2)
[docs]
def sphere(levels_count=4):
    """ This function generates cartesian coordinates (xyz) for a set
    of points forming a 3D sphere. The coordinates are expressed in
    meters and can be used as doas. The result has the format:
    (n_points, 3).

    Arguments
    ---------
    levels_count : int
        A number proportional to the number of points that the user
        wants to generate.
        - If levels_count = 1, then the sphere will have 42 points
        - If levels_count = 2, then the sphere will have 162 points
        - If levels_count = 3, then the sphere will have 642 points
        - If levels_count = 4, then the sphere will have 2562 points
        - If levels_count = 5, then the sphere will have 10242 points
        - ...
        By default, levels_count is set to 4.

    Example
    -------
    >>> import torch
    >>> from speechbrain.processing.multi_mic import sphere
    >>> doas = sphere()
    """
    # Generate points at level 0: the 12 vertices of an icosahedron
    # (two poles plus two staggered rings of five points each).
    h = (5.0 ** 0.5) / 5.0
    r = (2.0 / 5.0) * (5.0 ** 0.5)
    pi = 3.141592654
    pts = torch.zeros((12, 3), dtype=torch.float)
    pts[0, :] = torch.FloatTensor([0, 0, 1])
    pts[11, :] = torch.FloatTensor([0, 0, -1])
    pts[range(1, 6), 0] = r * torch.sin(2.0 * pi * torch.arange(0, 5) / 5.0)
    pts[range(1, 6), 1] = r * torch.cos(2.0 * pi * torch.arange(0, 5) / 5.0)
    pts[range(1, 6), 2] = h
    pts[range(6, 11), 0] = (
        -1.0 * r * torch.sin(2.0 * pi * torch.arange(0, 5) / 5.0)
    )
    pts[range(6, 11), 1] = (
        -1.0 * r * torch.cos(2.0 * pi * torch.arange(0, 5) / 5.0)
    )
    pts[range(6, 11), 2] = -1.0 * h

    # Generate triangles at level 0: the 20 icosahedron faces, stored
    # as triplets of indices into pts.
    trs = torch.zeros((20, 3), dtype=torch.long)

    trs[0, :] = torch.LongTensor([0, 2, 1])
    trs[1, :] = torch.LongTensor([0, 3, 2])
    trs[2, :] = torch.LongTensor([0, 4, 3])
    trs[3, :] = torch.LongTensor([0, 5, 4])
    trs[4, :] = torch.LongTensor([0, 1, 5])

    trs[5, :] = torch.LongTensor([9, 1, 2])
    trs[6, :] = torch.LongTensor([10, 2, 3])
    trs[7, :] = torch.LongTensor([6, 3, 4])
    trs[8, :] = torch.LongTensor([7, 4, 5])
    trs[9, :] = torch.LongTensor([8, 5, 1])

    trs[10, :] = torch.LongTensor([4, 7, 6])
    trs[11, :] = torch.LongTensor([5, 8, 7])
    trs[12, :] = torch.LongTensor([1, 9, 8])
    trs[13, :] = torch.LongTensor([2, 10, 9])
    trs[14, :] = torch.LongTensor([3, 6, 10])

    trs[15, :] = torch.LongTensor([11, 6, 7])
    trs[16, :] = torch.LongTensor([11, 7, 8])
    trs[17, :] = torch.LongTensor([11, 8, 9])
    trs[18, :] = torch.LongTensor([11, 9, 10])
    trs[19, :] = torch.LongTensor([11, 10, 6])

    # Generate next levels: each pass splits every triangle into four by
    # inserting the midpoint of each edge, then projects the new points
    # back onto the unit sphere.
    for levels_index in range(0, levels_count):
        #      0
        #     / \
        #    A---B
        #   / \ / \
        #  1---C---2

        trs_count = trs.shape[0]
        subtrs_count = trs_count * 4

        # Each row of subtrs encodes one sub-triangle as three vertex
        # *pairs* (6 columns): a pair (i, j) denotes the midpoint of
        # edge i-j, and a pair (i, i) denotes the original vertex i.
        subtrs = torch.zeros((subtrs_count, 6), dtype=torch.long)

        # Corner sub-triangle at vertex 0: (0, A, B).
        subtrs[0 * trs_count + torch.arange(0, trs_count), 0] = trs[:, 0]
        subtrs[0 * trs_count + torch.arange(0, trs_count), 1] = trs[:, 0]
        subtrs[0 * trs_count + torch.arange(0, trs_count), 2] = trs[:, 0]
        subtrs[0 * trs_count + torch.arange(0, trs_count), 3] = trs[:, 1]
        subtrs[0 * trs_count + torch.arange(0, trs_count), 4] = trs[:, 2]
        subtrs[0 * trs_count + torch.arange(0, trs_count), 5] = trs[:, 0]

        # Corner sub-triangle at vertex 1: (A, 1, C).
        subtrs[1 * trs_count + torch.arange(0, trs_count), 0] = trs[:, 0]
        subtrs[1 * trs_count + torch.arange(0, trs_count), 1] = trs[:, 1]
        subtrs[1 * trs_count + torch.arange(0, trs_count), 2] = trs[:, 1]
        subtrs[1 * trs_count + torch.arange(0, trs_count), 3] = trs[:, 1]
        subtrs[1 * trs_count + torch.arange(0, trs_count), 4] = trs[:, 1]
        subtrs[1 * trs_count + torch.arange(0, trs_count), 5] = trs[:, 2]

        # Corner sub-triangle at vertex 2: (B, C, 2).
        subtrs[2 * trs_count + torch.arange(0, trs_count), 0] = trs[:, 2]
        subtrs[2 * trs_count + torch.arange(0, trs_count), 1] = trs[:, 0]
        subtrs[2 * trs_count + torch.arange(0, trs_count), 2] = trs[:, 1]
        subtrs[2 * trs_count + torch.arange(0, trs_count), 3] = trs[:, 2]
        subtrs[2 * trs_count + torch.arange(0, trs_count), 4] = trs[:, 2]
        subtrs[2 * trs_count + torch.arange(0, trs_count), 5] = trs[:, 2]

        # Central sub-triangle: (A, C, B).
        subtrs[3 * trs_count + torch.arange(0, trs_count), 0] = trs[:, 0]
        subtrs[3 * trs_count + torch.arange(0, trs_count), 1] = trs[:, 1]
        subtrs[3 * trs_count + torch.arange(0, trs_count), 2] = trs[:, 1]
        subtrs[3 * trs_count + torch.arange(0, trs_count), 3] = trs[:, 2]
        subtrs[3 * trs_count + torch.arange(0, trs_count), 4] = trs[:, 2]
        subtrs[3 * trs_count + torch.arange(0, trs_count), 5] = trs[:, 0]

        # Flatten the vertex pairs, sort each pair so (i, j) and (j, i)
        # collapse to the same edge, and encode each pair as a single
        # scalar so midpoints shared by adjacent faces deduplicate.
        subtrs_flatten = torch.cat(
            (subtrs[:, [0, 1]], subtrs[:, [2, 3]], subtrs[:, [4, 5]]), axis=0
        )
        subtrs_sorted, _ = torch.sort(subtrs_flatten, axis=1)

        index_max = torch.max(subtrs_sorted)

        subtrs_scalar = (
            subtrs_sorted[:, 0] * (index_max + 1) + subtrs_sorted[:, 1]
        )

        unique_scalar, unique_indices = torch.unique(
            subtrs_scalar, return_inverse=True
        )

        # Decode the unique scalar codes back into vertex-index pairs.
        unique_values = torch.zeros(
            (unique_scalar.shape[0], 2), dtype=unique_scalar.dtype
        )

        unique_values[:, 0] = torch.div(
            unique_scalar, index_max + 1, rounding_mode="floor"
        )
        unique_values[:, 1] = unique_scalar - unique_values[:, 0] * (
            index_max + 1
        )

        # New triangle list referring to the deduplicated vertex set.
        trs = torch.transpose(torch.reshape(unique_indices, (3, -1)), 0, 1)

        # Sum of each vertex pair gives (twice) the edge midpoint;
        # normalizing projects every point back onto the unit sphere.
        pts = pts[unique_values[:, 0], :] + pts[unique_values[:, 1], :]
        pts /= torch.repeat_interleave(
            torch.unsqueeze(torch.sum(pts ** 2, axis=1) ** 0.5, 1), 3, 1
        )

    return pts