"""This library gathers utilities for data io operation.

Authors
 * Mirco Ravanelli 2020
 * Aku Rouhe 2020
 * Samuele Cornell 2020
"""

import os
import shutil
import urllib.request
import collections.abc
import torch
import tqdm
import pathlib
import speechbrain as sb
import re


def undo_padding(batch, lengths):
    """Produces Python lists given a batch of sentences with their
    corresponding relative lengths.

    Arguments
    ---------
    batch : tensor
        Batch of sentences gathered in a batch.
    lengths : tensor
        Relative length of each sentence in the batch.

    Example
    -------
    >>> batch=torch.rand([4,100])
    >>> lengths=torch.tensor([0.5,0.6,0.7,1.0])
    >>> snt_list=undo_padding(batch, lengths)
    >>> len(snt_list)
    4
    """
    batch_max_len = batch.shape[1]
    as_list = []
    for seq, seq_length in zip(batch, lengths):
        actual_size = int(torch.round(seq_length * batch_max_len))
        seq_true = seq.narrow(0, 0, actual_size)
        as_list.append(seq_true.tolist())
    return as_list


def get_all_files(
    dirName, match_and=None, match_or=None, exclude_and=None, exclude_or=None
):
    """Returns a list of files found within a folder.

    Different options can be used to restrict the search to some specific
    patterns.

    Arguments
    ---------
    dirName : str
        The directory to search.
    match_and : list
        A list that contains patterns to match. The file is returned
        only if it matches all the entries in `match_and`.
    match_or : list
        A list that contains patterns to match. The file is returned
        if it matches one or more of the entries in `match_or`.
    exclude_and : list
        A list that contains patterns to match. The file is returned
        unless it matches all the entries in `exclude_and`.
    exclude_or : list
        A list that contains patterns to match. The file is returned
        only if it matches none of the entries in `exclude_or`.

    Example
    -------
    >>> get_all_files('samples/rir_samples', match_and=['3.wav'])
    ['samples/rir_samples/rir3.wav']
    """
    # Match/exclude variable initialization
    match_and_entry = True
    match_or_entry = True
    exclude_or_entry = False
    exclude_and_entry = False

    # Create a list of files and subdirectories
    listOfFile = os.listdir(dirName)
    allFiles = list()

    # Iterate over all the entries
    for entry in listOfFile:

        # Create full path
        fullPath = os.path.join(dirName, entry)

        # If entry is a directory then get the list of files in this directory
        if os.path.isdir(fullPath):
            allFiles = allFiles + get_all_files(
                fullPath,
                match_and=match_and,
                match_or=match_or,
                exclude_and=exclude_and,
                exclude_or=exclude_or,
            )
        else:

            # Check match_and case
            if match_and is not None:
                match_and_entry = False
                match_found = 0

                for ele in match_and:
                    if ele in fullPath:
                        match_found = match_found + 1
                if match_found == len(match_and):
                    match_and_entry = True

            # Check match_or case
            if match_or is not None:
                match_or_entry = False
                for ele in match_or:
                    if ele in fullPath:
                        match_or_entry = True
                        break

            # Check exclude_and case
            if exclude_and is not None:
                # Reset the flag for each file (a sticky True would wrongly
                # exclude every file after the first full match).
                exclude_and_entry = False
                match_found = 0

                for ele in exclude_and:
                    if ele in fullPath:
                        match_found = match_found + 1
                if match_found == len(exclude_and):
                    exclude_and_entry = True

            # Check exclude_or case
            if exclude_or is not None:
                exclude_or_entry = False
                for ele in exclude_or:
                    if ele in fullPath:
                        exclude_or_entry = True
                        break

            # If needed, append the current file to the output list
            if (
                match_and_entry
                and match_or_entry
                and not exclude_and_entry
                and not exclude_or_entry
            ):
                allFiles.append(fullPath)

    return allFiles


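# A short illustrative sketch of combining the filters above. The folder
# contents are an assumption (the same 'samples/rir_samples' data as in the
# docstring example); the membership check below holds regardless of the
# exact files present.
# >>> files = get_all_files(
# ...     "samples/rir_samples", match_and=[".wav"], exclude_or=["rir3"]
# ... )
# >>> any("rir3" in f for f in files)
# False

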
def split_list(seq, num):
    """Returns a list of splits in the sequence.

    Arguments
    ---------
    seq : iterable
        The input list, to be split.
    num : int
        The number of chunks to produce.

    Example
    -------
    >>> split_list([1, 2, 3, 4, 5, 6, 7, 8, 9], 4)
    [[1, 2], [3, 4], [5, 6], [7, 8, 9]]
    """
    # Average length of the chunk
    avg = len(seq) / float(num)
    out = []
    last = 0.0

    # Creating the chunks
    while last < len(seq):
        out.append(seq[int(last) : int(last + avg)])
        last += avg

    return out


def recursive_items(dictionary):
    """Yields each (key, value) of a nested dictionary.

    Arguments
    ---------
    dictionary : dict
        The nested dictionary to list.

    Yields
    ------
    `(key, value)` tuples from the dictionary.

    Example
    -------
    >>> rec_dict={'lev1': {'lev2': {'lev3': 'current_val'}}}
    >>> [item for item in recursive_items(rec_dict)]
    [('lev3', 'current_val')]
    """
    for key, value in dictionary.items():
        if type(value) is dict:
            yield from recursive_items(value)
        else:
            yield (key, value)


def recursive_update(d, u, must_match=False):
    """Similar function to `dict.update`, but for a nested `dict`.

    From: https://stackoverflow.com/a/3233356

    If you have a nested mapping structure, for example:

        {"a": 1, "b": {"c": 2}}

    Say you want to update the above structure with:

        {"b": {"d": 3}}

    This function will produce:

        {"a": 1, "b": {"c": 2, "d": 3}}

    Instead of:

        {"a": 1, "b": {"d": 3}}

    Arguments
    ---------
    d : dict
        Mapping to be updated.
    u : dict
        Mapping to update with.
    must_match : bool
        Whether to throw an error if a key in `u` does not exist in `d`.

    Example
    -------
    >>> d = {'a': 1, 'b': {'c': 2}}
    >>> recursive_update(d, {'b': {'d': 3}})
    >>> d
    {'a': 1, 'b': {'c': 2, 'd': 3}}
    """
    # TODO: Consider cases where u has branch off k, but d does not.
    # e.g. d = {"a":1}, u = {"a": {"b": 2 }}
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping) and k in d:
            # Propagate must_match so that nested overrides are checked, too.
            recursive_update(d.get(k, {}), v, must_match=must_match)
        elif must_match and k not in d:
            raise KeyError(
                f"Override '{k}' not found in: {[key for key in d.keys()]}"
            )
        else:
            d[k] = v


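# A minimal sketch of the must_match flag: with must_match=True, an override
# whose key is absent from the target mapping raises a KeyError instead of
# being silently inserted.
# >>> d = {"a": 1}
# >>> recursive_update(d, {"z": 2}, must_match=True)
# Traceback (most recent call last):
# ...
# KeyError: "Override 'z' not found in: ['a']"

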
def download_file(
    source, dest, unpack=False, dest_unpack=None, replace_existing=False
):
    """Downloads the file from the given source and saves it in the given
    destination path.

    Arguments
    ---------
    source : path or url
        Path of the source file. If the source is an URL, it downloads it
        from the web.
    dest : path
        Destination path.
    unpack : bool
        If True, it unpacks the data in the `dest_unpack` folder.
    dest_unpack : path
        Folder where the data is unpacked. If None, the parent folder of
        `dest` is used.
    replace_existing : bool
        If True, replaces the existing files.
    """
    try:
        # Download (and unpack) only on the main process; the other
        # processes wait at the barrier below.
        if sb.utils.distributed.if_main_process():

            class DownloadProgressBar(tqdm.tqdm):
                """Reports download progress to a tqdm bar."""

                def update_to(self, b=1, bsize=1, tsize=None):
                    if tsize is not None:
                        self.total = tsize
                    self.update(b * bsize - self.n)

            # Create the destination directory if it doesn't exist
            dest_dir = pathlib.Path(dest).resolve().parent
            dest_dir.mkdir(parents=True, exist_ok=True)

            if "http" not in source:
                shutil.copyfile(source, dest)

            elif not os.path.isfile(dest) or (
                os.path.isfile(dest) and replace_existing
            ):
                print(f"Downloading {source} to {dest}")
                with DownloadProgressBar(
                    unit="B",
                    unit_scale=True,
                    miniters=1,
                    desc=source.split("/")[-1],
                ) as t:
                    urllib.request.urlretrieve(
                        source, filename=dest, reporthook=t.update_to
                    )
            else:
                print(f"{dest} exists. Skipping download")

            # Unpack if necessary
            if unpack:
                if dest_unpack is None:
                    dest_unpack = os.path.dirname(dest)
                print(f"Extracting {dest} to {dest_unpack}")
                shutil.unpack_archive(dest, dest_unpack)
    finally:
        sb.utils.distributed.ddp_barrier()


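# A hedged usage sketch: the URL and destination below are placeholders, not
# assets shipped with SpeechBrain. The archive is fetched once (unless
# replace_existing=True) and extracted next to the download.
# >>> download_file(
# ...     "https://example.com/data/noise.tar.gz",  # hypothetical URL
# ...     "local_data/noise.tar.gz",
# ...     unpack=True,
# ... )  # doctest: +SKIP

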
def pad_right_to(
    tensor: torch.Tensor, target_shape: (list, tuple), mode="constant", value=0,
):
    """This function takes a torch tensor of arbitrary shape and pads it to
    target shape by appending values on the right.

    Parameters
    ----------
    tensor : input torch tensor
        Input tensor whose dimensions we need to pad.
    target_shape : (list, tuple)
        Target shape we want for the padded tensor. Its length must be equal
        to tensor.ndim.
    mode : str
        Pad mode, please refer to torch.nn.functional.pad documentation.
    value : float
        Pad value, please refer to torch.nn.functional.pad documentation.

    Returns
    -------
    tensor : torch.Tensor
        Padded tensor.
    valid_vals : list
        List containing proportion for each dimension of original, non-padded
        values.
    """
    assert len(target_shape) == tensor.ndim
    pads = []  # this contains the abs length of the padding for each dimension.
    valid_vals = []  # this contains the relative lengths for each dimension.
    i = len(target_shape) - 1  # iterating over target_shape ndims
    j = 0
    while i >= 0:
        assert (
            target_shape[i] >= tensor.shape[i]
        ), "Target shape must be >= original shape for every dim"
        # torch.nn.functional.pad expects pads in (last dim, ..., first dim)
        # order, while valid_vals is kept in natural dimension order.
        pads.extend([0, target_shape[i] - tensor.shape[i]])
        valid_vals.append(tensor.shape[j] / target_shape[j])
        i -= 1
        j += 1

    tensor = torch.nn.functional.pad(tensor, pads, mode=mode, value=value)
    return tensor, valid_vals


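# A small worked example: padding a (2, 10) tensor to (2, 15) appends values
# on the right of the last dimension, and valid_vals reports, per dimension,
# the fraction that is real (non-padded) data.
# >>> x = torch.ones(2, 10)
# >>> padded, valid = pad_right_to(x, (2, 15))
# >>> padded.shape
# torch.Size([2, 15])
# >>> valid
# [1.0, 0.6666666666666666]

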
def batch_pad_right(tensors: list, mode="constant", value=0):
    """Given a list of torch tensors, batches them together by padding them
    on the right so that every dimension has the same length for all tensors.

    Parameters
    ----------
    tensors : list
        List of tensors we wish to pad together.
    mode : str
        Padding mode, see torch.nn.functional.pad documentation.
    value : float
        Padding value, see torch.nn.functional.pad documentation.

    Returns
    -------
    tensor : torch.Tensor
        Padded tensor.
    valid_vals : list
        List containing proportion for each dimension of original, non-padded
        values.
    """
    if not len(tensors):
        raise IndexError("Tensors list must not be empty")

    if len(tensors) == 1:
        # if there is only one tensor in the batch we simply unsqueeze it.
        return tensors[0].unsqueeze(0), torch.tensor([1.0])

    if not all(
        [tensors[i].ndim == tensors[0].ndim for i in range(1, len(tensors))]
    ):
        raise IndexError("All tensors must have same number of dimensions")

    # FIXME we limit the support here: we allow padding of only the last
    # dimension; this needs to be removed when feature extraction is updated
    # to handle multichannel data.
    max_shape = []
    for dim in range(tensors[0].ndim):
        if dim != (tensors[0].ndim - 1):
            if not all(
                [x.shape[dim] == tensors[0].shape[dim] for x in tensors[1:]]
            ):
                raise EnvironmentError(
                    "Tensors should have same dimensions except for last one"
                )
        max_shape.append(max([x.shape[dim] for x in tensors]))

    batched = []
    valid = []
    for t in tensors:
        # for each tensor we apply pad_right_to
        padded, valid_percent = pad_right_to(
            t, max_shape, mode=mode, value=value
        )
        batched.append(padded)
        # Only the last dimension may be padded, so only its relative
        # length is returned.
        valid.append(valid_percent[-1])

    batched = torch.stack(batched)

    return batched, torch.tensor(valid)


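# A small worked example with two 1-D tensors of different lengths: the
# shorter one is right-padded to the longest, and the returned relative
# lengths follow the `lengths` convention used elsewhere in SpeechBrain.
# >>> a, b = torch.ones(5), torch.ones(3)
# >>> batched, valid = batch_pad_right([a, b])
# >>> batched.shape
# torch.Size([2, 5])
# >>> valid
# tensor([1.0000, 0.6000])

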
def split_by_whitespace(text):
    """A very basic functional version of str.split"""
    return text.split()


def recursive_to(data, *args, **kwargs):
    """Moves data to device, or other type, and handles containers.

    Very similar to torch.utils.data._utils.pin_memory.pin_memory,
    but applies .to() instead.
    """
    if isinstance(data, torch.Tensor):
        return data.to(*args, **kwargs)
    elif isinstance(data, collections.abc.Mapping):
        return {
            k: recursive_to(sample, *args, **kwargs)
            for k, sample in data.items()
        }
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(
            *(recursive_to(sample, *args, **kwargs) for sample in data)
        )
    elif isinstance(data, str):
        # Strings are Sequences, too; return them as-is to avoid recursing
        # forever into single characters (as pin_memory also does).
        return data
    elif isinstance(data, collections.abc.Sequence):
        return [recursive_to(sample, *args, **kwargs) for sample in data]
    elif hasattr(data, "to"):
        return data.to(*args, **kwargs)
    # What should be done with unknown data?
    # For now, just return them as they are
    else:
        return data


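# A minimal sketch: nested containers are traversed and every tensor gets the
# same .to() arguments, while non-tensor leaves pass through unchanged.
# >>> nested = {"feats": torch.zeros(2, 3), "ids": ["utt1", "utt2"]}
# >>> moved = recursive_to(nested, dtype=torch.float64)
# >>> moved["feats"].dtype
# torch.float64
# >>> moved["ids"]
# ['utt1', 'utt2']

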
# Matches numpy dtype kinds for (byte) strings and objects, which cannot be
# converted to tensors by mod_default_collate below.
np_str_obj_array_pattern = re.compile(r"[SaUO]")


def mod_default_collate(batch):
    r"""Makes a tensor from list of batch values.

    Note that this doesn't need to zip(*) values together
    as PaddedBatch connects them already (by key).

    Here the idea is not to error out.

    This is modified from:
    https://github.com/pytorch/pytorch/blob/c0deb231db76dbea8a9d326401417f7d1ce96ed5/torch/utils/data/_utils/collate.py#L42
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        try:
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        except RuntimeError:  # Unequal size:
            return batch
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        try:
            if (
                elem_type.__name__ == "ndarray"
                or elem_type.__name__ == "memmap"
            ):
                # array of string classes and object
                if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    return batch
                return mod_default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        except RuntimeError:  # Unequal size
            return batch
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    else:
        return batch


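# A small sketch of the "don't error out" behavior: equally-sized tensors are
# stacked, while unequally-sized ones are returned as the original list.
# >>> mod_default_collate([torch.tensor([1, 2]), torch.tensor([3, 4])])
# tensor([[1, 2],
#         [3, 4]])
# >>> mod_default_collate([torch.tensor([1, 2]), torch.tensor([3])])
# [tensor([1, 2]), tensor([3])]

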
def split_path(path):
    """Splits a path to source and filename

    This also handles URLs and Huggingface hub paths, in addition to
    regular paths.

    Arguments
    ---------
    path : str
        The path to split.

    Returns
    -------
    str
        Source
    str
        Filename
    """
    if "/" in path:
        # Return a tuple here as well, for a consistent return type.
        source, filename = path.rsplit("/", maxsplit=1)
        return source, filename
    else:
        # Interpret as path to file in current directory.
        return "./", path


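# A minimal sketch: everything before the last '/' is treated as the source
# (a directory, URL, or HuggingFace hub id; the id below is illustrative),
# and bare filenames resolve to the current directory.
# >>> split_path("speechbrain/asr-crdnn-rnnlm-librispeech/asr.ckpt")
# ('speechbrain/asr-crdnn-rnnlm-librispeech', 'asr.ckpt')
# >>> split_path("asr.ckpt")
# ('./', 'asr.ckpt')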