Source code for speechbrain.dataio.preprocess

"""Preprocessors for audio"""

import torch

from speechbrain.augment.time_domain import Resample


class AudioNormalizer:
    """Bring audio to a common sample rate and channel layout.

    Arguments
    ---------
    sample_rate : int
        Target sampling rate; every incoming signal is resampled to it.
    mix : {"avg-to-mono", "keep"}
        "avg-to-mono" - average over the channel axis (the sum of all
        channels divided by the channel count). This drops the channel
        dimension and yields a [time] tensor; input that is already [time]
        is passed through unchanged.
        "keep" - leave the channel layout untouched.

    Example
    -------
    >>> import torchaudio
    >>> example_file = 'tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac'
    >>> signal, sr = torchaudio.load(example_file, channels_first = False)
    >>> normalizer = AudioNormalizer(sample_rate=8000)
    >>> normalized = normalizer(signal, sr)
    >>> signal.shape
    torch.Size([160000, 4])
    >>> normalized.shape
    torch.Size([80000])

    NOTE
    ----
    Signals recorded below ``sample_rate`` are upsampled as well. Upsampling
    cannot create real content in the bandwidth it adds, so models that were
    not specifically trained on upsampled data will generally perform poorly
    on it.
    """

    def __init__(self, sample_rate=16000, mix="avg-to-mono"):
        self.sample_rate = sample_rate
        # Validate eagerly so a bad configuration fails at construction time
        # instead of on the first call.
        valid_modes = ("avg-to-mono", "keep")
        if mix not in valid_modes:
            raise ValueError(f"Unexpected mixing configuration {mix}")
        self.mix = mix
        # Memoized Resample objects, one per source sample rate seen so far.
        self._cached_resamplers = {}

    def __call__(self, audio, sample_rate):
        """Resample and channel-mix a single waveform.

        Arguments
        ---------
        audio : torch.Tensor
            Waveform laid out as [time, channels] or [time].
        sample_rate : int
            Sampling rate of ``audio``.

        Returns
        -------
        audio : torch.Tensor
            The waveform at ``self.sample_rate``, with channels handled
            according to ``self.mix``.
        """
        resampler = self._cached_resamplers.get(sample_rate)
        if resampler is None:
            # First time this source rate is seen: build a converter to the
            # target rate and keep it for later calls.
            resampler = Resample(sample_rate, self.sample_rate)
            self._cached_resamplers[sample_rate] = resampler
        # Add a leading singleton axis for the resampler (presumably the batch
        # axis Resample expects) and drop it again afterwards.
        batched = audio.unsqueeze(0)
        resampled = resampler(batched).squeeze(0)
        return self._mix(resampled)

    def _mix(self, audio):
        """Apply the configured channel mixing to ``audio``.

        Returns ``audio`` itself for "keep", the channel mean (axis 1) for
        "avg-to-mono" on 2-D input, and the 1-D input unchanged otherwise.
        """
        is_flat = audio.dim() == 1
        if self.mix == "keep":
            return audio
        if self.mix == "avg-to-mono":
            # A 1-D signal has no channel axis left to average over.
            return audio if is_flat else torch.mean(audio, dim=1)
        return None