speechbrain.dataio.dataset module

Dataset examples for loading individual data points

Authors
  • Aku Rouhe 2020

  • Samuele Cornell 2020

Summary

Classes:

DynamicItemDataset

Dataset that reads, wrangles, and produces dicts.

FilteredSortedDynamicItemDataset

Possibly filtered, possibly sorted DynamicItemDataset.

Functions:

add_dynamic_item

Helper for adding the same item to multiple datasets.

set_output_keys

Helper for setting the same output keys on multiple datasets.

Reference

class speechbrain.dataio.dataset.DynamicItemDataset(data, dynamic_items=[], output_keys=[])[source]

Bases: torch.utils.data.Dataset

Dataset that reads, wrangles, and produces dicts.

Each data point dict provides some items (by key), for example, a path to a wav file with the key “wav_file”. When a data point is fetched from this Dataset, more items are produced dynamically, based on pre-existing items and other dynamically created items. For example, a dynamic item could take the wav file path and load the audio from disk.

The dynamic items can depend on other dynamic items: a suitable evaluation order is used automatically, as long as there are no circular dependencies.

A specified list of keys is collected into the output dict. These can be items in the original data or dynamic items. If a dynamic item is neither requested nor depended on by another requested item, it won’t be computed. So, for example, if a user simply wants to iterate over the text, the time-consuming audio loading can be skipped.

About the format: Takes a dict of dicts as the collection of data points to read/wrangle. The top level keys are data point IDs. Each data point (example) dict should have the same keys, corresponding to different items in that data point.

Altogether the data collection could look like this:

>>> data = {
...  "spk1utt1": {
...      "wav_file": "/path/to/spk1utt1.wav",
...      "text": "hello world",
...      "speaker": "spk1",
...      },
...  "spk1utt2": {
...      "wav_file": "/path/to/spk1utt2.wav",
...      "text": "how are you world",
...      "speaker": "spk1",
...      }
... }

Note

The top-level key, the data point id, is implicitly added as an item in the data point, with the key “id”.

Each dynamic item is configured by three things: a key, a func, and a list of argkeys. The key should be unique among all the items (dynamic or not) in each data point. The func is any callable, and it returns the dynamic item’s value. The callable is called with the values of other items as specified by the argkeys list (as positional args, passed in the order specified by argkeys).

The dynamic_items configuration could look like this:

>>> import torch
>>> dynamic_items = [
...     {"func": lambda l: torch.Tensor(l),
...     "takes": ["wav_loaded"],
...     "provides": "wav"},
...     {"func": lambda path: [ord(c)/100 for c in path],  # Fake "loading"
...     "takes": ["wav_file"],
...     "provides": "wav_loaded"},
...     {"func": lambda t: t.split(),
...     "takes": ["text"],
...     "provides": "words"}]

With these, different views of the data can be loaded:

>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.dataio.batch import PaddedBatch
>>> dataset = DynamicItemDataset(data, dynamic_items)
>>> dataloader = SaveableDataLoader(dataset, collate_fn=PaddedBatch,
...     batch_size=2)
>>> # First, create encoding for words:
>>> dataset.set_output_keys(["words"])
>>> encoding = {}
>>> next_id = 1
>>> for batch in dataloader:
...     for sent in batch.words:
...         for word in sent:
...             if word not in encoding:
...                 encoding[word] = next_id
...                 next_id += 1
>>> # Next, add an encoded words_tensor dynamic item:
>>> dataset.add_dynamic_item(
...     func = lambda ws: torch.tensor([encoding[w] for w in ws],
...             dtype=torch.long),
...     takes = ["words"],
...     provides = "words_encoded")
>>> # Now we can get word and audio tensors:
>>> dataset.set_output_keys(["id", "wav", "words_encoded"])
>>> batch = next(iter(dataloader))
>>> batch.id
['spk1utt1', 'spk1utt2']
>>> batch.wav  # doctest: +ELLIPSIS
PaddedData(data=tensor([[0.4700, 1.1200, ...
>>> batch.words_encoded
PaddedData(data=tensor([[1, 2, 0, 0],
        [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))

Output keys can also be a map:

>>> dataset.set_output_keys({"id":"id", "signal": "wav", "words": "words_encoded"})
>>> batch = next(iter(dataloader))
>>> batch.words
PaddedData(data=tensor([[1, 2, 0, 0],
        [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))

Parameters
  • data (dict) – Dictionary containing single data points (e.g. utterances).

  • dynamic_items (list, optional) –

    Configuration for the dynamic items produced when fetching an example. List of DynamicItems or dicts with the format:

func: <callable>  # To be called
takes: <key or list of keys>  # Arguments resolved and passed to func, in order
provides: <key or list of keys>  # Key(s) under which the output is provided
    

  • output_keys (dict, list, optional) –

    List of keys (either directly available in data or dynamic items) to include in the output dict when data points are fetched.

If a dict is given, it is used to map internal keys to output keys: in each key:value pair of the output_keys dict, the key is the name that appears in the output and the value is the internal key.

add_dynamic_item(func, takes=None, provides=None)[source]

Makes a new dynamic item available on the dataset.

There are two calling conventions. For DynamicItem objects, use add_dynamic_item(dynamic_item). Otherwise, use add_dynamic_item(func, takes, provides); a minimal sketch follows the parameter list below.

See speechbrain.utils.data_pipeline.

Parameters
  • func (callable, DynamicItem) – If a DynamicItem is given, adds that directly. Otherwise a DynamicItem is created, and this specifies the callable to use. If a generator function is given, a GeneratorDynamicItem is created; otherwise a normal DynamicItem is created.

  • takes (list, str) – List of keys. When func is called, each key is resolved to either an entry in the data or the output of another dynamic_item. The func is then called with these as positional arguments, in the same order as specified here. A single arg can be given directly.

  • provides (str) – Unique key or keys that this provides.
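
A minimal sketch of the second calling convention on toy data (note that takes also accepts a single key directly):

>>> dataset = DynamicItemDataset({"a": {"text": "hello world"}})
>>> dataset.add_dynamic_item(lambda t: t.split(), takes="text", provides="words")
>>> dataset.set_output_keys(["words"])
>>> dataset[0]
{'words': ['hello', 'world']}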

set_output_keys(keys)[source]

Use this to change the output keys.

These are the keys that are actually evaluated when a data point is fetched from the dataset.

Parameters

keys (dict, list) –

List of keys (str) to produce in output.

If a dict is given, it is used to map internal keys to output keys: in each key:value pair, the key is the name that appears in the output and the value is the internal key.
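
For instance, the dict form can rename an internal key on output (toy data):

>>> dataset = DynamicItemDataset({"a": {"x": 1, "y": 2}})
>>> dataset.set_output_keys({"renamed": "x"})
>>> dataset[0]
{'renamed': 1}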

output_keys_as(keys)[source]

Context manager to temporarily set output keys.

Example

>>> dataset = DynamicItemDataset({"a":{"x":1,"y":2},"b":{"x":3,"y":4}},
...     output_keys = ["x"])
>>> with dataset.output_keys_as(["y"]):
...     print(dataset[0])
{'y': 2}
>>> print(dataset[0])
{'x': 1}

Note

Not thread-safe. While inside this context manager, the output keys are changed for every fetch from the dataset.

filtered_sorted(key_min_value={}, key_max_value={}, key_test={}, sort_key=None, reverse=False, select_n=None)[source]

Get a filtered and/or sorted version of this dataset; the static data is shared.

The reason to implement these operations in the same method is that computing some dynamic items may be expensive, and this way the filtering and sorting steps don’t need to compute the dynamic items twice.

Parameters
  • key_min_value (dict) – Map from key (in data or in dynamic items) to limit, will only keep data_point if data_point[key] >= limit

  • key_max_value (dict) – Map from key (in data or in dynamic items) to limit, will only keep data_point if data_point[key] <= limit

  • key_test (dict) – Map from key (in data or in dynamic items) to func, will only keep data_point if bool(func(data_point[key])) == True

  • sort_key (None, str) – If not None, sort by data_point[sort_key]. Default is ascending order.

  • reverse (bool) – If True, sort in descending order.

  • select_n (None, int) – If not None, only keep (at most) the first n filtered data_points. The possible sorting is applied, but only on the first n data points found. Meant for debugging.

Returns

Shares the static data, but has its own output keys and dynamic items (initially deep-copied from this dataset, so the same dynamic items are available).

Return type

FilteredSortedDynamicItemDataset

Note

Temporarily changes the output keys!
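
Example

A sketch of combined filtering and sorting on toy data (the ids and durations here are illustrative):

>>> dataset = DynamicItemDataset(
...     {"ex1": {"duration": 1.0},
...      "ex2": {"duration": 3.0},
...      "ex3": {"duration": 2.0}},
...     output_keys=["id", "duration"])
>>> subset = dataset.filtered_sorted(
...     key_max_value={"duration": 2.5}, sort_key="duration", reverse=True)
>>> [data_point["id"] for data_point in subset]
['ex3', 'ex1']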

classmethod from_json(json_path, replacements={}, dynamic_items=[], output_keys=[])[source]

Load a data prep JSON file and create a Dataset based on it.
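
Assuming the JSON file mirrors the dict-of-dicts format described above (top-level keys are the data point ids), a minimal sketch:

>>> import json, tempfile
>>> data = {
...     "spk1utt1": {"wav_file": "/path/to/spk1utt1.wav", "text": "hello world"}
... }
>>> with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
...     json.dump(data, f)
...     json_path = f.name
>>> dataset = DynamicItemDataset.from_json(json_path, output_keys=["id", "text"])
>>> dataset[0]
{'id': 'spk1utt1', 'text': 'hello world'}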

classmethod from_csv(csv_path, replacements={}, dynamic_items=[], output_keys=[])[source]

Load a data prep CSV file and create a Dataset based on it.
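
A similar sketch for CSV, assuming the data prep CSV uses an ID column for the data point ids:

>>> import csv, tempfile
>>> with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as f:
...     writer = csv.writer(f)
...     _ = writer.writerow(["ID", "text"])
...     _ = writer.writerow(["spk1utt1", "hello world"])
...     csv_path = f.name
>>> dataset = DynamicItemDataset.from_csv(csv_path, output_keys=["id", "text"])
>>> dataset[0]
{'id': 'spk1utt1', 'text': 'hello world'}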

classmethod from_arrow_dataset(dataset, replacements={}, dynamic_items=[], output_keys=[])[source]

Load a prepared Hugging Face dataset and create a Dataset based on it.

class speechbrain.dataio.dataset.FilteredSortedDynamicItemDataset(from_dataset, data_ids)[source]

Bases: speechbrain.dataio.dataset.DynamicItemDataset

Possibly filtered, possibly sorted DynamicItemDataset.

Shares the static data (reference). Has its own dynamic_items and output_keys (deepcopy).

classmethod from_json(json_path, replacements={}, dynamic_items=None, output_keys=None)[source]

classmethod from_csv(csv_path, replacements={}, dynamic_items=None, output_keys=None)[source]

speechbrain.dataio.dataset.add_dynamic_item(datasets, func, takes=None, provides=None)[source]

Helper for adding the same item to multiple datasets.

speechbrain.dataio.dataset.set_output_keys(datasets, output_keys)[source]

Helper for setting the same output keys on multiple datasets.
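
A sketch of using these module-level helpers on toy train/valid splits; they apply the same configuration to each dataset in the list:

>>> train = DynamicItemDataset({"a": {"text": "hello world"}})
>>> valid = DynamicItemDataset({"b": {"text": "how are you"}})
>>> add_dynamic_item([train, valid], func=lambda t: t.split(),
...     takes=["text"], provides="words")
>>> set_output_keys([train, valid], ["id", "words"])
>>> valid[0]
{'id': 'b', 'words': ['how', 'are', 'you']}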