"""Library for Downloading and Preparing Datasets for Data Augmentation,
This library provides functions for downloading datasets from the web and
preparing the necessary CSV data manifest files for use by data augmenters.
Authors:
* Mirco Ravanelli 2023
"""
import os
import torchaudio
from speechbrain.utils.data_utils import download_file, get_all_files
from speechbrain.utils.distributed import main_process_only
from speechbrain.utils.logger import get_logger
# Logger init
logger = get_logger(__name__)
[docs]
@main_process_only
def prepare_dataset_from_URL(URL, dest_folder, ext, csv_file, max_length=None):
    """Fetch a dataset of recordings (e.g., noise sequences) from a URL and
    build the CSV manifest consumed by the noise augmenter.

    The archive is stored as ``data.zip`` inside ``dest_folder``. The CSV
    manifest is only generated when ``csv_file`` does not exist yet.

    Arguments
    ---------
    URL : str
        The URL of the dataset to download.
    dest_folder : str
        The local folder where the noisy dataset will be downloaded.
    ext : str
        File extension (without the leading dot) to search for within the
        downloaded dataset.
    csv_file : str
        The path to store the prepared noise CSV file.
    max_length : float
        The maximum length in seconds.
        Recordings longer than this will be automatically cut into pieces.
    """
    archive_path = os.path.join(dest_folder, "data.zip")

    # Unpacking is only requested when the destination folder is not there yet.
    if os.path.isdir(dest_folder):
        download_file(URL, archive_path)
    else:
        download_file(URL, archive_path, unpack=True)

    # An already existing manifest is left untouched.
    if os.path.isfile(csv_file):
        return

    matching_files = get_all_files(dest_folder, match_and=["." + ext])
    prepare_csv(matching_files, csv_file, max_length)
[docs]
@main_process_only
def prepare_csv(filelist, csv_file, max_length=None):
    """Write the CSV manifest for a list of audio files, cleaning up on failure.

    Any exception raised while writing is logged and not re-raised; a
    partially written ``csv_file`` is removed in that case.

    Arguments
    ---------
    filelist : list of str
        A list containing the paths of files of interest.
    csv_file : str
        The path to store the prepared noise CSV file.
    max_length : float
        The maximum length in seconds.
        Recordings longer than this will be automatically cut into pieces.
    """
    try:
        write_csv(filelist, csv_file, max_length)
    except Exception as err:
        # Best effort: report the problem instead of propagating it.
        logger.error("Exception:", exc_info=err)

        # Do not leave a half-written manifest behind.
        leftover_present = os.path.exists(csv_file)
        if leftover_present:
            os.remove(csv_file)
[docs]
@main_process_only
def write_csv(filelist, csv_file, max_length=None):
    """Create the CSV manifest describing every audio file in ``filelist``.

    The file is (re)created with a fixed header line, followed by the rows
    produced for each audio file in list order.

    Arguments
    ---------
    filelist : list of str
        A list containing the paths of audio files of interest.
    csv_file : str
        The path where to store the prepared noise CSV file.
    max_length : float (optional)
        The maximum recording length in seconds.
        Recordings longer than this will be automatically cut into pieces.
    """
    header = "ID,duration,wav,wav_format,wav_opts\n"
    with open(csv_file, "w", encoding="utf-8") as csv_handle:
        csv_handle.write(header)
        # The position in the list is forwarded so that IDs stay unique.
        for position, audio_path in enumerate(filelist):
            _write_csv_row(csv_handle, audio_path, position, max_length)
def _write_csv_row(w, filename, index, max_length):
    """
    Write the CSV row(s) for one audio file.

    The file is loaded, reduced to a single channel if needed, and then either
    written as one row or, when it is longer than ``max_length``, split into
    several pieces with one row each.

    Arguments
    ---------
    w : file
        The open CSV file for writing.
    filename : str
        The path to the audio file.
    index : int
        The index of the audio file in the list.
    max_length : float (optional)
        The maximum recording length in seconds. ``None`` disables splitting.
    """
    signal, rate = torchaudio.load(filename)
    signal = _ensure_single_channel(signal, filename, rate)

    # Split the base name on its LAST dot only. Unpacking a plain
    # ``split(".")`` into two names raises ValueError for names such as
    # "noise.v1.wav" or "noise", which would abort the whole manifest.
    basename = os.path.basename(filename)
    ID, dot, ext = basename.rpartition(".")
    if not dot:
        # No dot at all: the whole name is the ID and there is no extension.
        ID, ext = basename, ""

    # Duration in seconds: number of samples divided by the sampling rate.
    duration = signal.shape[1] / rate

    if max_length is not None and duration > max_length:
        _handle_long_waveform(
            w, filename, ID, ext, signal, rate, duration, max_length, index
        )
    else:
        _write_short_waveform_csv(w, ID, ext, duration, filename, index)
def _ensure_single_channel(signal, filename, rate):
"""
Ensure that the audio signal has only one channel.
Arguments
---------
signal : torch.Tensor
The audio signal.
filename : str
The path to the audio file.
rate : int
The sampling frequency of the signal.
Returns
-------
signal : Torch.Tensor
The audio signal with a single channel.
"""
if signal.shape[0] > 1:
signal = signal[0].unsqueeze(0)
torchaudio.save(filename, signal, rate)
return signal
def _handle_long_waveform(
    w, filename, ID, ext, signal, rate, duration, max_length, index
):
    """
    Cut a long recording into pieces of ``max_length`` seconds, save each piece
    next to the original file, and write one CSV row per piece.

    Piece ``j`` is saved as ``<path without extension>_<j><extension>``. Only
    ``int(duration / max_length)`` full-length pieces are produced, so a
    trailing remainder shorter than ``max_length`` is not written. The original
    file is deleted once all pieces have been saved.

    Arguments
    ---------
    w : file
        The open CSV file for writing.
    filename : str
        The path to the audio file.
    ID : str
        The unique identifier for the audio.
    ext : str
        The audio file extension (without the leading dot); written to the CSV.
    signal : torch.Tensor
        The audio signal, indexed as ``signal[:, start:stop]``.
    rate : int
        The audio sample rate.
    duration : float
        The duration of the audio in seconds.
    max_length : float
        The maximum recording length in seconds.
    index : int
        The index of the audio file in the list.
    """
    # Split off the extension from the END of the path. Parsing with
    # ``filename.split(".")[1]`` breaks on paths such as "./data/x.wav" or
    # "/home/user.name/x.wav", and ``str.replace`` would rewrite every
    # occurrence of the extension inside the path.
    root, suffix = os.path.splitext(filename)

    num_pieces = int(duration / max_length)
    for j in range(num_pieces):
        # Sample offsets of piece j.
        start = int(max_length * j * rate)
        stop = int(min(max_length * (j + 1), duration) * rate)

        new_filename = f"{root}_{j}{suffix}"
        torchaudio.save(new_filename, signal[:, start:stop], rate)

        # The trailing "\n" element yields a final comma, i.e. an empty
        # last column in the row.
        csv_row = (
            f"{ID}_{index}_{j}",
            str((stop - start) / rate),
            new_filename,
            ext,
            "\n",
        )
        w.write(",".join(csv_row))

    # Delete the original only after every piece is on disk, so a failed save
    # cannot lose the source recording.
    os.remove(filename)
def _write_short_waveform_csv(w, ID, ext, duration, filename, index):
"""
Write a CSV row for a short audio waveform.
Arguments
---------
w : file
The open CSV file for writing.
ID : str
The unique identifier for the audio.
ext : str
The audio file extension.
duration : float
The duration of the audio in seconds.
filename : str
The path to the audio file.
index : int
The index of the audio file in the list.
"""
w.write(",".join((f"{ID}_{index}", str(duration), filename, ext, "\n")))