speechbrain.utils.streaming module

Utilities to assist with designing and training streaming models.

Authors
  • Sylvain de Langen 2023

Summary

Functions:

infer_dependency_matrix

Randomizes parts of the input sequence several times in order to detect dependencies between input frames and output frames, i.e. whether a given output frame depends on a given input frame.

plot_dependency_matrix

Returns a matplotlib figure of a dependency matrix generated by infer_dependency_matrix.

split_fixed_chunks

Split an input tensor x into a list of chunk tensors of size chunk_size along dimension dim.

split_wav_lens

Converts a single wav_lens tensor into a list of chunk_count tensors, typically useful when chunking signals with split_fixed_chunks.

Reference

speechbrain.utils.streaming.split_fixed_chunks(x: Tensor, chunk_size: int, dim: int = -1) → List[Tensor]

Split an input tensor x into a list of chunk tensors of size chunk_size along dimension dim. Useful for splitting up sequences into chunks of fixed sizes.

If dimension dim cannot be evenly split by chunk_size, then the last chunk will be smaller than chunk_size.

Parameters:
  • x (torch.Tensor) – The tensor to split into chunks, typically a sequence or audio signal.

  • chunk_size (int) – The size of each chunk, i.e. the maximum size of each chunk along dimension dim.

  • dim (int) – Dimension to split along, typically the time dimension.

Returns:

A list of chunk tensors; see the description and example. Each chunk is guaranteed to satisfy .size(dim) <= chunk_size.

Return type:

List[torch.Tensor]

Example

>>> import torch
>>> from speechbrain.utils.streaming import split_fixed_chunks
>>> x = torch.zeros((16, 10000, 80))
>>> chunks = split_fixed_chunks(x, 128, dim=1)
>>> len(chunks)
79
>>> chunks[0].shape
torch.Size([16, 128, 80])
>>> chunks[-1].shape
torch.Size([16, 16, 80])
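
Since the chunks cover the input exactly once and in order, concatenating them along the same dimension recovers the original tensor; a quick sanity check continuing the example above:

>>> torch.equal(torch.cat(chunks, dim=1), x)
True
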
speechbrain.utils.streaming.split_wav_lens(chunk_lens: List[int], wav_lens: Tensor) → List[Tensor]

Converts a single wav_lens tensor into a list of chunk_count tensors, typically useful when chunking signals with split_fixed_chunks.

wav_lens represents the relative length of each audio within a batch and is typically used for masking. This function computes the equivalent relative lengths at the chunk level.

Parameters:
  • chunk_lens (List[int]) – Length of the sequence of every chunk. For example, if chunks was returned from split_fixed_chunks(x, chunk_size, dim=1), then this should be [chk.size(1) for chk in chunks].

  • wav_lens (torch.Tensor) – Relative lengths of audio within a batch. For example, for an input signal of 100 frames and a batch of 3 elements, (1.0, 0.5, 0.25) would mean the batch holds audio of 100 frames, 50 frames and 25 frames respectively.

Returns:

A list of chunked wav_lens; see the description and example.

Return type:

List[torch.Tensor]

Example

>>> import torch
>>> from speechbrain.utils.streaming import split_wav_lens, split_fixed_chunks
>>> x = torch.zeros((3, 20, 80))
>>> chunks = split_fixed_chunks(x, 8, dim=1)
>>> len(chunks)
3
>>> # 20 frames, 13 frames, 17 frames
>>> wav_lens = torch.tensor([1.0, 0.65, 0.85])
>>> chunked_wav_lens = split_wav_lens([c.size(1) for c in chunks], wav_lens)
>>> chunked_wav_lens
[tensor([1., 1., 1.]), tensor([1.0000, 0.6250, 1.0000]), tensor([1.0000, 0.0000, 0.2500])]
>>> # wav 1 covers 62.5% (5/8) of the second chunk's frames
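
If absolute frame counts per chunk are needed, the relative lengths can be multiplied back by each chunk's size; a small sketch continuing the example above:

>>> [wl * c.size(1) for wl, c in zip(chunked_wav_lens, chunks)]
[tensor([8., 8., 8.]), tensor([8., 5., 8.]), tensor([4., 0., 1.])]
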
speechbrain.utils.streaming.infer_dependency_matrix(model: Callable, seq_shape: tuple, in_stride: int = 1)

Randomizes parts of the input sequence several times in order to detect dependencies between input frames and output frames, i.e. whether a given output frame depends on a given input frame.

This can prove useful to check whether a model behaves correctly in a streaming context and does not contain accidental dependencies on future frames that could not be known in a streaming scenario.

Note that this can get very computationally expensive for very long sequences.

Furthermore, this expects inference to be fully deterministic, or false dependencies may be found. This also means that the model must be in eval mode, e.g. to disable nondeterministic layers such as dropout.

Parameters:
  • model (Callable) – Can be a model or a function (potentially emulating streaming functionality). It does not need to be a trained model; random weights usually suffice.

  • seq_shape (tuple) – Shape of the input sequence, parts of which the function randomizes in order to detect unwanted dependencies. Expected to look like [batch_size, seq_len, num_feats], where batch_size may be 1.

  • in_stride (int) – Consider only every N-th input frame, useful when the input sequences are very long (e.g. raw audio) and the output is shorter (e.g. due to subsampling or filters).

Returns:

dependencies – Matrix representing whether an output is dependent on an input; index using [in_frame_idx, out_frame_idx]. True indicates a detected dependency.

Return type:

torch.BoolTensor
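
Example

A minimal sketch of checking a model for causality. The toy unidirectional LSTM here is an assumption, standing in for a real streaming model; the lambda discards the LSTM's hidden state so the callable returns a single output tensor. The model is put in eval mode since inference must be deterministic.

>>> import torch
>>> from speechbrain.utils.streaming import infer_dependency_matrix
>>> model = torch.nn.LSTM(input_size=80, hidden_size=32, batch_first=True).eval()
>>> deps = infer_dependency_matrix(lambda x: model(x)[0], seq_shape=(1, 16, 80))
>>> deps.shape
torch.Size([16, 16])
>>> bool(deps[8, 4])  # output frame 4 must not depend on future input frame 8
False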

speechbrain.utils.streaming.plot_dependency_matrix(deps)

Returns a matplotlib figure of a dependency matrix generated by infer_dependency_matrix.

At a given point, a red square indicates that a given output frame (y-axis) was found to depend on a given input frame (x-axis).

For example, a fully red image means that every output frame was dependent on every input frame, including future ones. This could be the case for a bidirectional RNN or an unmasked transformer model.

Parameters:

deps (torch.BoolTensor) – Matrix returned by infer_dependency_matrix or one in a compatible format.
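
Example

A brief sketch continuing the infer_dependency_matrix example above; the returned figure is a regular matplotlib figure, so the filename passed to savefig is an arbitrary choice for illustration.

>>> from speechbrain.utils.streaming import plot_dependency_matrix
>>> fig = plot_dependency_matrix(deps)
>>> fig.savefig("dependency_matrix.png")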