Source code for speechbrain.dataio.legacy

"""SpeechBrain Extended CSV Compatibility."""
from speechbrain.dataio.dataset import DynamicItemDataset
import collections
import csv
import pickle
import logging
import torch
import torchaudio
import re

logger = logging.getLogger(__name__)


TORCHAUDIO_FORMATS = ["wav", "flac", "aac", "ogg", "mp3"]
ITEM_POSTFIX = "_data"

CSVItem = collections.namedtuple("CSVItem", ["data", "format", "opts"])
CSVItem.__doc__ = """The Legacy Extended CSV Data item triplet"""
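
# Illustrative example (not part of the original module): each data name
# in a CSV row becomes one CSVItem triplet.
# >>> item = CSVItem("$data_folder/utt1.wav", "wav", "")
# >>> item.format
# 'wav'
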


class ExtendedCSVDataset(DynamicItemDataset):
    """Extended CSV compatibility for DynamicItemDataset.

    Uses the SpeechBrain Extended CSV data format, where the CSV must have
    'ID' and 'duration' fields.

    The rest of the fields come in triplets:
    ``<name>, <name>_format, <name>_opts``.

    These add a ``<name>_data`` item to the data point dict. Additionally,
    a basic DynamicItem (see DynamicItemDataset) is created, which loads
    the ``_data`` item.

    Bash-like string replacements with $to_replace are supported.

    NOTE
    ----
    Mapping from legacy interface:

    - csv_file -> csvpath
    - sentence_sorting -> sorting, and "random" is not supported; use e.g.
      ``make_dataloader(..., shuffle=(sorting == "random"))``
    - avoid_if_shorter_than -> min_duration
    - avoid_if_longer_than -> max_duration
    - csv_read -> output_keys, and if you want IDs, add "id" as a key

    Arguments
    ---------
    csvpath : str, path
        Path to extended CSV.
    replacements : dict
        Used for Bash-like $-prefixed substitution, e.g.
        ``{"data_folder": "/home/speechbrain/data"}``, which would
        transform ``$data_folder/utt1.wav`` into
        ``/home/speechbrain/data/utt1.wav``.
    sorting : {"original", "ascending", "descending"}
        Keep CSV order, or sort ascending or descending by duration.
    min_duration : float, int
        Minimum duration in seconds. Shorter entries are discarded.
    max_duration : float, int
        Maximum duration in seconds. Longer entries are discarded.
    dynamic_items : list
        Configuration for extra dynamic items produced when fetching an
        example. List of DynamicItems or dicts with the keys::

            func: <callable> # To be called
            takes: <list> # key or list of keys of args this takes
            provides: key # key or list of keys that this provides

        NOTE: A dynamic item is automatically added for each CSV
        data-triplet.
    output_keys : list, None
        The list of output keys to produce. You can refer to the names of
        the CSV data-triplets, e.g. if the CSV has
        ``wav,wav_format,wav_opts``, then the Dataset has a dynamic item
        output available with key ``"wav"``.
        NOTE: If None, read all existing keys.
    """

    def __init__(
        self,
        csvpath,
        replacements={},
        sorting="original",
        min_duration=0,
        max_duration=36000,
        dynamic_items=[],
        output_keys=[],
    ):
        if sorting not in ["original", "ascending", "descending"]:
            clsname = self.__class__.__name__
            raise ValueError(f"{clsname} doesn't support {sorting} sorting")
        # Load the CSV, init class
        data, di_to_add, data_names = load_sb_extended_csv(
            csvpath, replacements
        )
        super().__init__(data, dynamic_items, output_keys)
        self.pipeline.add_dynamic_items(di_to_add)
        # Handle filtering, sorting:
        reverse = False
        sort_key = None
        if sorting in ("ascending", "descending"):
            sort_key = "duration"
        if sorting == "descending":
            reverse = True
        filtered_sorted_ids = self._filtered_sorted_ids(
            key_min_value={"duration": min_duration},
            key_max_value={"duration": max_duration},
            sort_key=sort_key,
            reverse=reverse,
        )
        self.data_ids = filtered_sorted_ids
        # Handle None output_keys (differently than Base)
        if not output_keys:
            self.set_output_keys(data_names)
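
# Usage sketch (illustrative; "files.csv" and the replacement path are
# assumptions, not part of this module). Given a CSV with the header
# ``ID,duration,wav,wav_format,wav_opts``:
#
# >>> dataset = ExtendedCSVDataset(
# ...     csvpath="files.csv",
# ...     replacements={"data_folder": "/data"},
# ...     sorting="ascending",
# ... )
# >>> dataset.set_output_keys(["id", "wav"])
# >>> first = dataset[0]  # dict with keys "id" and "wav"
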
def load_sb_extended_csv(csv_path, replacements={}):
    """Loads SB Extended CSV and formats string values.

    Uses the SpeechBrain Extended CSV data format, where the CSV must have
    'ID' and 'duration' fields.

    The rest of the fields come in triplets:
    ``<name>, <name>_format, <name>_opts``.

    These add a ``<name>_data`` item to the data point dict. Additionally,
    a basic DynamicItem (see DynamicItemDataset) is created, which loads
    the ``_data`` item.

    Bash-like string replacements with $to_replace are supported.

    This format has its restrictions, but it allows some tasks to have
    their loading specified by the CSV.

    Arguments
    ---------
    csv_path : str
        Path to the CSV file.
    replacements : dict
        Optional dict, e.g. ``{"data_folder": "/home/speechbrain/data"}``.
        This is used to recursively format all string values in the data.

    Returns
    -------
    dict
        CSV data with replacements applied.
    list
        List of DynamicItems to add in DynamicItemDataset.
    list
        List of the data-triplet names found in the CSV.
    """
    with open(csv_path, newline="") as csvfile:
        result = {}
        reader = csv.DictReader(csvfile, skipinitialspace=True)
        variable_finder = re.compile(r"\$([\w.]+)")
        if not reader.fieldnames[0] == "ID":
            raise KeyError(
                "CSV has to have an 'ID' field, with unique ids"
                " for all data points"
            )
        if not reader.fieldnames[1] == "duration":
            raise KeyError(
                "CSV has to have a 'duration' field, "
                "with the length of the data point in seconds."
            )
        if not len(reader.fieldnames[2:]) % 3 == 0:
            raise ValueError(
                "All named fields must have 3 entries: "
                "<name>, <name>_format, <name>_opts"
            )
        names = reader.fieldnames[2::3]
        for row in reader:
            # Make a triplet for each name
            data_point = {}
            # ID:
            data_id = row["ID"]
            del row["ID"]  # This is used as a key in result, instead.
            # Duration:
            data_point["duration"] = float(row["duration"])
            del row["duration"]  # This is handled specially.
            if data_id in result:
                raise ValueError(f"Duplicate id: {data_id}")
            # Replacements:
            # Only need to run these on the actual data,
            # not on _opts, _format
            for key, value in list(row.items())[::3]:
                try:
                    row[key] = variable_finder.sub(
                        lambda match: replacements[match[1]], value
                    )
                except KeyError:
                    raise KeyError(
                        f"The item {value} requires replacements "
                        "which were not supplied."
                    )
            for i, name in enumerate(names):
                triplet = CSVItem(*list(row.values())[i * 3 : i * 3 + 3])
                data_point[name + ITEM_POSTFIX] = triplet
            result[data_id] = data_point
        # Make a DynamicItem for each CSV entry;
        # _read_csv_item delegates reading to the relevant function.
        dynamic_items_to_add = []
        for name in names:
            di = {
                "func": _read_csv_item,
                "takes": name + ITEM_POSTFIX,
                "provides": name,
            }
            dynamic_items_to_add.append(di)
        return result, dynamic_items_to_add, names
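
# Example of the on-disk format load_sb_extended_csv expects
# (illustrative file contents, not part of this module):
#
#   ID,duration,wav,wav_format,wav_opts
#   utt1,1.45,$data_folder/utt1.wav,wav,
#   utt2,2.00,$data_folder/utt2.wav,wav,
#
# With replacements={"data_folder": "/data"}, the returned dict looks like
# {"utt1": {"duration": 1.45, "wav_data": CSVItem(...)}, ...}, and one
# DynamicItem per triplet is returned, providing the key "wav".
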
def _read_csv_item(item):
    """Reads the different formats supported in SB Extended CSV.

    Delegates to the relevant functions.
    """
    opts = _parse_csv_item_opts(item.opts)
    if item.format in TORCHAUDIO_FORMATS:
        audio, _ = torchaudio.load(item.data)
        return audio.squeeze(0)
    elif item.format == "pkl":
        return read_pkl(item.data, opts)
    elif item.format == "string":
        # Just implement string reading here.
        # NOTE: No longer supporting
        # lab2ind mapping like before.
        # Try decoding string
        string = item.data
        try:
            string = string.decode("utf-8")
        except AttributeError:
            pass
        # Splitting elements with ' '
        string = string.split(" ")
        return string
    else:
        raise TypeError(f"Don't know how to read {item.format}")


def _parse_csv_item_opts(entry):
    """Parse the _opts field in a SB Extended CSV item."""
    # Accepting even slightly weirdly formatted entries:
    entry = entry.strip()
    if len(entry) == 0:
        return {}
    opts = {}
    for opt in entry.split(" "):
        opt_name, opt_val = opt.split(":")
        opts[opt_name] = opt_val
    return opts
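
# Example of the _opts syntax accepted by _parse_csv_item_opts
# (illustrative option names; values stay strings):
# >>> _parse_csv_item_opts("max_len:120 norm:True")
# {'max_len': '120', 'norm': 'True'}
# >>> _parse_csv_item_opts("  ")
# {}
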
def read_pkl(file, data_options={}, lab2ind=None):
    """This function reads tensors stored in pkl format.

    Arguments
    ---------
    file : str
        The path to the file to read.
    data_options : dict, optional
        A dictionary containing options for the reader.
    lab2ind : dict, optional
        Mapping from label to integer indices.

    Returns
    -------
    torch.Tensor or numpy.ndarray
        The read signal: lists are converted to torch tensors; any other
        pickled object (e.g. a numpy array) is returned as loaded.
    """
    # Trying to read data
    try:
        with open(file, "rb") as f:
            pkl_element = pickle.load(f)
    except pickle.UnpicklingError:
        err_msg = "cannot read the pkl file %s" % file
        raise ValueError(err_msg)

    type_ok = False
    if isinstance(pkl_element, list):
        if isinstance(pkl_element[0], float):
            tensor = torch.FloatTensor(pkl_element)
            type_ok = True
        if isinstance(pkl_element[0], int):
            tensor = torch.LongTensor(pkl_element)
            type_ok = True
        if isinstance(pkl_element[0], str):
            # Convert strings to integers as specified in lab2ind
            if lab2ind is not None:
                for index, val in enumerate(pkl_element):
                    pkl_element[index] = lab2ind[val]
            tensor = torch.LongTensor(pkl_element)
            type_ok = True
        if not type_ok:
            err_msg = (
                "The pkl file %s can only contain list of integers, "
                "floats, or strings. Got %s"
            ) % (file, type(pkl_element[0]))
            raise ValueError(err_msg)
    else:
        tensor = pkl_element

    tensor_type = tensor.dtype

    # Conversion to 32 bit (if needed). The dtype-to-string comparison
    # matches numpy dtypes, so this applies when the pickle held a numpy
    # array; torch dtypes don't compare equal to these strings.
    if tensor_type == "float64":
        tensor = tensor.astype("float32")
    if tensor_type == "int64":
        tensor = tensor.astype("int32")

    return tensor
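
# Round-trip sketch for read_pkl (hypothetical file name): a pickled list
# of ints comes back as a torch.LongTensor.
# >>> import pickle
# >>> with open("labels.pkl", "wb") as f:
# ...     pickle.dump([1, 2, 3], f)
# >>> read_pkl("labels.pkl")
# tensor([1, 2, 3])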