speechbrain.utils.streaming module
Utilities to assist with designing and training streaming models.
Authors: Sylvain de Langen 2023
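Summary
Functions:
- infer_dependency_matrix: Randomizes parts of the input sequence several times in order to detect dependencies between input frames and output frames, i.e. whether a given output frame depends on a given input frame.
- plot_dependency_matrix: Returns a matplotlib figure of a dependency matrix generated by infer_dependency_matrix.
- split_fixed_chunks: Split an input tensor x into a list of chunk tensors of size chunk_size along dimension dim.
- split_wav_lens: Converts a single wav_lens tensor into a list of chunk_count tensors, typically useful when chunking signals with split_fixed_chunks.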
Reference
- speechbrain.utils.streaming.split_fixed_chunks(x, chunk_size, dim=-1)
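Split an input tensor x into a list of chunk tensors of size chunk_size along dimension dim. Useful for splitting up sequences into chunks of fixed size.
If dimension dim cannot be evenly split by chunk_size, then the last chunk will be smaller than chunk_size.
- Parameters:
x (torch.Tensor): The input tensor to split.
chunk_size (int): Size of each chunk along dimension dim.
dim (int): Dimension along which to split. Defaults to -1.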
- Returns:
A list of chunk tensors (see description and example). Each chunk is guaranteed to satisfy .size(dim) <= chunk_size.
- Return type:
List[Tensor]
Example
>>> import torch
>>> from speechbrain.utils.streaming import split_fixed_chunks
>>> x = torch.zeros((16, 10000, 80))
>>> chunks = split_fixed_chunks(x, 128, dim=1)
>>> len(chunks)
79
>>> chunks[0].shape
torch.Size([16, 128, 80])
>>> chunks[-1].shape
torch.Size([16, 16, 80])
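As a rough illustration of the size arithmetic only, the helper below is a hypothetical pure-Python sketch (fixed_chunk_sizes is not part of SpeechBrain) that computes the chunk sizes one would expect along dim for a sequence of seq_len frames, assuming the behavior described above: full chunks of chunk_size, followed by one shorter trailing chunk when the length is not a multiple of chunk_size.

```python
from typing import List


def fixed_chunk_sizes(seq_len: int, chunk_size: int) -> List[int]:
    """Hypothetical helper (not SpeechBrain's API): sizes of the chunks
    obtained when splitting seq_len frames into chunks of chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    # Integer ceiling division: a trailing partial chunk still counts.
    num_chunks = -(-seq_len // chunk_size)
    sizes = [chunk_size] * num_chunks
    if num_chunks > 0:
        # The last chunk holds whatever frames remain (may be < chunk_size).
        sizes[-1] = seq_len - (num_chunks - 1) * chunk_size
    return sizes


sizes = fixed_chunk_sizes(10000, 128)
print(len(sizes), sizes[0], sizes[-1])  # 79 128 16
```

With seq_len=10000 and chunk_size=128 this agrees with the 79 chunks and the 16-frame last chunk in the example above.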
- speechbrain.utils.streaming.split_wav_lens(chunk_lens, wav_lens)
Converts a single wav_lens tensor into a list of chunk_count tensors, typically useful when chunking signals with split_fixed_chunks.
wav_lens represents the relative length of each audio within a batch, which is typically used for masking. This function computes the corresponding relative lengths at chunk level.
- Parameters:
chunk_lens (List[int]): Length of the sequence of every chunk. For example, if chunks was returned from split_fixed_chunks(x, chunk_size, dim=1), then this should be [chk.size(1) for chk in chunks].
wav_lens (torch.Tensor): Relative lengths of audio within a batch. For example, for an input signal of 100 frames and a batch of 3 elements, (1.0, 0.5, 0.25) would mean the batch holds audio of 100 frames, 50 frames and 25 frames respectively.
- Returns:
A list of chunked wav_lens, see description and example.
- Return type:
List[Tensor]
Example
>>> import torch
>>> from speechbrain.utils.streaming import split_wav_lens, split_fixed_chunks
>>> x = torch.zeros((3, 20, 80))
>>> chunks = split_fixed_chunks(x, 8, dim=1)
>>> len(chunks)
3
>>> # 20 frames, 13 frames, 17 frames
>>> wav_lens = torch.tensor([1.0, 0.65, 0.85])
>>> chunked_wav_lens = split_wav_lens([c.size(1) for c in chunks], wav_lens)
>>> chunked_wav_lens
[tensor([1., 1., 1.]), tensor([1.0000, 0.6250, 1.0000]), tensor([1.0000, 0.0000, 0.2500])]
>>> # wav 1 covers 62.5% (5/8) of the second chunk's frames
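One way to read the per-chunk values: convert each relative length to an absolute frame count, count how many valid frames fall inside each chunk, and divide by that chunk's length. The hypothetical split_relative_lens below is a plain-Python sketch of that computation on lists instead of tensors; it is an assumption about how the values are derived, not SpeechBrain's implementation, although it reproduces the numbers in the example above.

```python
from typing import List


def split_relative_lens(
    chunk_lens: List[int], wav_lens: List[float]
) -> List[List[float]]:
    """Hypothetical sketch (not SpeechBrain's code): per-chunk relative
    lengths, returned as one list per chunk with one entry per utterance."""
    seq_len = sum(chunk_lens)
    # Absolute number of valid frames for each utterance in the batch.
    abs_lens = [rel * seq_len for rel in wav_lens]
    result = []
    chunk_start = 0
    for chunk_len in chunk_lens:
        chunk_rel = []
        for abs_len in abs_lens:
            # Valid frames inside [chunk_start, chunk_start + chunk_len),
            # clamped to the range [0, chunk_len].
            valid = min(max(abs_len - chunk_start, 0.0), chunk_len)
            chunk_rel.append(valid / chunk_len)
        result.append(chunk_rel)
        chunk_start += chunk_len
    return result


rel = split_relative_lens([8, 8, 4], [1.0, 0.65, 0.85])
print([[round(v, 4) for v in chunk] for chunk in rel])
# [[1.0, 1.0, 1.0], [1.0, 0.625, 1.0], [1.0, 0.0, 0.25]]
```

Clamping to [0, chunk_len] is what yields 0.0 for an utterance that ends before a chunk starts and 1.0 for one that spans the whole chunk.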
- speechbrain.utils.streaming.infer_dependency_matrix(model: Callable, seq_shape: tuple, in_stride: int = 1)
Randomizes parts of the input sequence several times in order to detect dependencies between input frames and output frames, i.e. whether a given output frame depends on a given input frame.
This can prove useful to check whether a model behaves correctly in a streaming context and does not contain accidental dependencies on future frames that couldn't be known in a streaming scenario.
Note that this can get very computationally expensive for very long sequences.
Furthermore, this expects inference to be fully deterministic; otherwise, false dependencies may be found. This also means that the model must be in eval mode, to disable layers such as dropout.
- Parameters:
model (Callable): Can be a model or a function (potentially emulating streaming functionality). It does not need to be a trained model; random weights should usually suffice.
seq_shape (tuple): Shape of the input sequence whose parts are randomized in order to detect unwanted dependencies. The shape is expected to look like [batch_size, seq_len, num_feats], where batch_size may be 1.
in_stride (int): Consider only every N-th input frame, for when the input sequences are very long (e.g. raw audio) and the output is shorter (subsampled, filtered, etc.). Defaults to 1.
- Returns:
dependencies: Matrix representing whether an output is dependent on an input; index using [in_frame_idx, out_frame_idx]. True indicates a detected dependency.
- Return type:
BoolTensor
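A minimal sketch of the idea, assuming a deterministic model: perturb one input frame at a time, rerun the model, and mark every output frame whose value changed. The hypothetical infer_deps_sketch below works on a 1-D list of floats rather than a [batch_size, seq_len, num_feats] tensor, and its trials and seed parameters are illustrative assumptions; it is not the SpeechBrain implementation. Like the function documented above, it returns a matrix indexed [in_frame_idx][out_frame_idx].

```python
import random
from typing import Callable, List


def infer_deps_sketch(
    model: Callable[[List[float]], List[float]],
    seq_len: int,
    in_stride: int = 1,
    trials: int = 5,
    seed: int = 0,
) -> List[List[bool]]:
    """Hypothetical sketch, not SpeechBrain's implementation.

    Returns deps[in_frame_idx][out_frame_idx]; assumes `model` is
    deterministic and maps a 1-D sequence to a 1-D output sequence.
    """
    rng = random.Random(seed)
    base_in = [rng.gauss(0.0, 1.0) for _ in range(seq_len)]
    base_out = model(base_in)
    deps = [[False] * len(base_out) for _ in range(seq_len)]
    # Only probe every `in_stride`-th input frame.
    for in_idx in range(0, seq_len, in_stride):
        for _ in range(trials):
            perturbed = list(base_in)
            perturbed[in_idx] = rng.gauss(0.0, 1.0)
            out = model(perturbed)
            for out_idx, (a, b) in enumerate(zip(base_out, out)):
                # A changed output frame means it depends on input in_idx.
                if a != b:
                    deps[in_idx][out_idx] = True
    return deps


def causal_cumsum(xs: List[float]) -> List[float]:
    # Toy causal "model": output t is the running sum of inputs 0..t.
    out, acc = [], 0.0
    for x in xs:
        acc += x
        out.append(acc)
    return out


def one_frame_lookahead(xs: List[float]) -> List[float]:
    # Toy non-causal "model": output t also peeks at input t + 1.
    return [x + (xs[t + 1] if t + 1 < len(xs) else 0.0) for t, x in enumerate(xs)]


deps = infer_deps_sketch(one_frame_lookahead, seq_len=5)
# Output frame 2 depends on the future input frame 3, but not on frame 4.
print(deps[3][2], deps[4][2])  # True False
```

A causal model such as causal_cumsum should only yield True where in_frame_idx <= out_frame_idx; any True with in_frame_idx > out_frame_idx, as with one_frame_lookahead, is the kind of accidental future dependency this check is meant to catch.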
- speechbrain.utils.streaming.plot_dependency_matrix(deps)
Returns a matplotlib figure of a dependency matrix generated by infer_dependency_matrix.
A red square indicates that a given output frame (y-axis) was found to depend on a given input frame (x-axis).
For example, a fully red image means that every output frame depends on the whole input sequence, as could be the case with a bidirectional RNN or a transformer model.
- Parameters:
deps (BoolTensor): Matrix returned by infer_dependency_matrix, or one in a compatible format.
- Return type:
matplotlib figure of a dependency matrix.
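To make the axis convention concrete without matplotlib, the hypothetical render_deps helper below (not part of SpeechBrain) prints a dependency matrix as text, with one row per output frame (y-axis) and one column per input frame (x-axis), using # where a red square would appear. It assumes the [in_frame_idx, out_frame_idx] indexing documented for infer_dependency_matrix, so it transposes while rendering; the vertical orientation of the actual figure may differ.

```python
from typing import List


def render_deps(deps: List[List[bool]]) -> str:
    """Hypothetical text stand-in for a dependency plot (not SpeechBrain's API).

    `deps` is indexed deps[in_frame_idx][out_frame_idx]. Each printed row is
    an output frame, each column an input frame; '#' marks a dependency.
    """
    num_in = len(deps)
    num_out = len(deps[0]) if deps else 0
    rows = []
    for out_idx in range(num_out):
        # Transpose on the fly: walk input frames for a fixed output frame.
        rows.append(
            "".join("#" if deps[in_idx][out_idx] else "." for in_idx in range(num_in))
        )
    return "\n".join(rows)


# Causal pattern: output frame t depends only on input frames 0..t.
causal = [[i <= j for j in range(4)] for i in range(4)]
print(render_deps(causal))
# #...
# ##..
# ###.
# ####
```

A causal model gives this lower-triangular pattern, whereas a fully dependent model (the "fully red" case) would render as all #.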