# Source code for speechbrain.inference.diarization

""" Specifies the inference interfaces for diarization modules.

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
 * Abdel Heba 2021
 * Andreas Nautsch 2022, 2023
 * Pooneh Mousavi 2023
 * Sylvain de Langen 2023
 * Adel Moumen 2023
 * Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained


class Speech_Emotion_Diarization(Pretrained):
    """A ready-to-use SED interface (audio -> emotions and their durations)

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)

    Example
    -------
    >>> from speechbrain.inference.diarization import Speech_Emotion_Diarization
    >>> tmpdir = getfixture("tmpdir")
    >>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
    >>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
    """

    # Modules accessed through ``self.mods`` by this interface:
    # ``input_norm`` and ``wav2vec2`` in ``encode_batch``, ``output_mlp`` in
    # ``diarize_batch``. The encoder entry must be named "wav2vec2" because
    # that is the attribute ``encode_batch`` reads.
    MODULES_NEEDED = ["input_norm", "wav2vec2", "output_mlp"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def diarize_file(self, path):
        """Get emotion diarization of a spoken utterance.

        Arguments
        ---------
        path : str
            Path to audio file which to diarize.

        Returns
        -------
        dict
            Maps ``path`` to a list of ``{"start", "end", "emotion"}``
            dictionaries describing the emotions and their temporal
            boundaries (see ``preds_to_diarization``).
        """
        waveform = self.load_audio(path)
        # Fake a batch of size one; the single utterance spans its full
        # length, so its relative length is 1.0.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.diarize_batch(batch, rel_length, [path])

    def encode_batch(self, wavs, wav_lens):
        """Encodes audios into fine-grained emotional embeddings

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels]. A 1-D tensor is
            treated as a single waveform and given a batch dimension.
        wav_lens : torch.Tensor or None
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. If None, every waveform is assumed
            to have full length.

        Returns
        -------
        torch.Tensor
            The encoded batch (output of the ``wav2vec2`` module).
        """
        if len(wavs.shape) == 1:
            # Single un-batched waveform: add the batch dimension.
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = self.mods.input_norm(wavs, wav_lens)
        outputs = self.mods.wav2vec2(wavs)
        return outputs

    def diarize_batch(self, wavs, wav_lens, batch_id):
        """Get emotion diarization of a batch of waveforms.

        The waveforms should already be in the model's desired format.
        ``diarize_file`` prepares its input with ``self.load_audio``; use
        the same loading path to prepare inputs for this method.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.
        batch_id : list
            ID of each waveform in the batch (file names etc.); used as the
            keys of the returned dictionary.

        Returns
        -------
        dict
            Maps each ID in ``batch_id`` to a list of
            ``{"start", "end", "emotion"}`` dictionaries describing the
            emotions and their temporal boundaries.
        """
        outputs = self.encode_batch(wavs, wav_lens)
        averaged_out = self.hparams.avg_pool(outputs)
        outputs = self.mods.output_mlp(averaged_out)
        outputs = self.hparams.log_softmax(outputs)
        # Only the arg-max class index per frame is needed; scores are unused.
        _, index = torch.max(outputs, dim=-1)
        preds = self.hparams.label_encoder.decode_torch(index)
        return self.preds_to_diarization(preds, batch_id)

    def preds_to_diarization(self, prediction, batch_id):
        """Convert frame-wise predictions into a dictionary of diarization results.

        Arguments
        ---------
        prediction : sequence of sequences
            Decoded emotion label of every frame, one sequence per
            utterance in the batch.
        batch_id : list
            ID of each utterance; ``batch_id[i]`` labels ``prediction[i]``.

        Returns
        -------
        dictionary
            Maps each ID to a list of ``{"start": float, "end": float,
            "emotion": label}`` dictionaries, where adjacent segments with
            the same emotion have been merged. An utterance with no frames
            maps to an empty list.
        """
        # Frame j starts at ``stride * 0.02 * j`` and lasts
        # ``window_length * 0.02``; 0.02 is presumably the encoder frame
        # shift in seconds — confirm against the hparams file. The products
        # are hoisted with the same evaluation order as the per-frame form.
        step = self.hparams.stride * 0.02
        span = self.hparams.window_length * 0.02

        results = {}
        for i, pred in enumerate(prediction):
            utt_id = batch_id[i]
            lol = []
            for j, label in enumerate(pred):
                start = round(step * j, 2)
                end = round(start + span, 2)
                lol.append([utt_id, start, end, label])
            lol = self.merge_ssegs_same_emotion_adjacent(lol)
            results[utt_id] = [
                {"start": k[1], "end": k[2], "emotion": k[3]} for k in lol
            ]
        return results

    def forward(self, wavs, wav_lens, batch_id):
        """Get emotion diarization for a batch of waveforms."""
        return self.diarize_batch(wavs, wav_lens, batch_id)

    def is_overlapped(self, end1, start2):
        """Returns True if segments are overlapping.

        Segments that merely touch (``start2 == end1``) count as
        overlapping.

        Arguments
        ---------
        end1 : float
            End time of the first segment.
        start2 : float
            Start time of the second segment.

        Returns
        -------
        overlapped : bool
            True if segments overlap, else False.

        Example
        -------
        >>> # Bare instance (bypasses __init__, so no model is loaded).
        >>> sed = Speech_Emotion_Diarization.__new__(Speech_Emotion_Diarization)
        >>> sed.is_overlapped(5.5, 3.4)
        True
        >>> sed.is_overlapped(5.5, 6.4)
        False
        """
        return not start2 > end1

    def merge_ssegs_same_emotion_adjacent(self, lol):
        """Merge adjacent sub-segs if they are the same emotion.

        The input list and its sub-segments are not modified; the result
        is built from copies.

        Arguments
        ---------
        lol : list of list
            Each list contains [utt_id, sseg_start, sseg_end, emo_label].

        Returns
        -------
        new_lol : list of list
            new_lol contains adjacent segments merged from the same emotion
            ID. An empty input gives an empty output.

        Example
        -------
        >>> # Bare instance (bypasses __init__, so no model is loaded).
        >>> sed = Speech_Emotion_Diarization.__new__(Speech_Emotion_Diarization)
        >>> lol=[['u1', 0.0, 7.0, 'a'],
        ... ['u1', 7.0, 9.0, 'a'],
        ... ['u1', 9.0, 11.0, 'n'],
        ... ['u1', 11.0, 13.0, 'n'],
        ... ['u1', 13.0, 15.0, 'n'],
        ... ['u1', 15.0, 16.0, 'a']]
        >>> sed.merge_ssegs_same_emotion_adjacent(lol)
        [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
        """
        if not lol:
            return []

        # new_lol[-1] is always the segment currently being extended.
        # Copies keep the caller's sub-segments untouched when end times grow.
        new_lol = [list(lol[0])]
        for next_sseg in lol[1:]:
            current = new_lol[-1]
            # IF sub-segments overlap AND have the same emotion THEN merge
            if (
                self.is_overlapped(current[2], next_sseg[1])
                and current[3] == next_sseg[3]
            ):
                current[2] = next_sseg[2]  # just update the end time
            else:
                new_lol.append(list(next_sseg))
        return new_lol