"""Configuration and utility classes for Dynamic Chunk Training, as
often used for the training of streaming-capable models in speech recognition.
The definition of Dynamic Chunk Training is based on that of the following
paper, though a lot of the literature refers to the same definition:
https://arxiv.org/abs/2012.05481
Authors
* Sylvain de Langen 2023
"""
import speechbrain as sb
from dataclasses import dataclass
from typing import Optional
import torch
# NOTE: this configuration object is intended to be relatively specific to
# Dynamic Chunk Training; if you want to implement a different type of
# chunking, you should consider using a different object.
@dataclass
class DynChunkTrainConfig:
    """Dynamic Chunk Training configuration object for use with transformers,
    often in ASR for streaming.

    This object may be used both to configure masking at training time and for
    run-time configuration of DynChunkTrain-ready models."""

    chunk_size: int
    """Size in frames of a single chunk, always `>0`.
    If chunkwise streaming should be disabled at some point, pass an optional
    streaming config parameter."""

    left_context_size: Optional[int] = None
    """Number of *chunks* (not frames) visible to the left, always `>=0`.
    If zero, then chunks can never attend to any past chunk.
    If `None`, the left context is infinite (but use
    `.is_infinite_left_context()` for such a check)."""

    def is_infinite_left_context(self) -> bool:
        """Returns true if the left context is infinite (i.e. any chunk can
        attend to any past frame).

        Returns
        -------
        bool
            ``True`` when ``left_context_size`` is ``None``."""
        return self.left_context_size is None

    def left_context_size_frames(self) -> Optional[int]:
        """Returns the number of left context *frames* (not chunks).
        If ``None``, the left context is infinite.
        See also the ``left_context_size`` field.

        Returns
        -------
        int or None
            ``chunk_size * left_context_size``, or ``None`` when the left
            context is infinite."""
        if self.is_infinite_left_context():
            return None
        # The left context is configured in chunks; convert it to frames.
        return self.chunk_size * self.left_context_size
@dataclass
class DynChunkTrainConfigRandomSampler:
    """Helper class to generate a DynChunkTrainConfig at runtime depending on the current
    stage.

    Example
    -------
    >>> from speechbrain.core import Stage
    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfigRandomSampler
    >>> # for the purpose of this example, we test a scenario with a 100%
    >>> # chance of the (24, None) scenario to occur
    >>> sampler = DynChunkTrainConfigRandomSampler(
    ...     chunkwise_prob=1.0,
    ...     chunk_size_min=24,
    ...     chunk_size_max=24,
    ...     limited_left_context_prob=0.0,
    ...     left_context_chunks_min=16,
    ...     left_context_chunks_max=16,
    ...     test_config=DynChunkTrainConfig(32, 16),
    ...     valid_config=None
    ... )
    >>> one_train_config = sampler(Stage.TRAIN)
    >>> one_train_config
    DynChunkTrainConfig(chunk_size=24, left_context_size=None)
    >>> one_train_config.is_infinite_left_context()
    True
    >>> sampler(Stage.TEST)
    DynChunkTrainConfig(chunk_size=32, left_context_size=16)"""

    chunkwise_prob: float
    """When sampling (during `Stage.TRAIN`), the probability that a finite chunk
    size will be used.
    In the other case, any chunk can attend to the full past and future
    context."""

    chunk_size_min: int
    """When sampling a random chunk size, the minimum chunk size that can be
    picked."""

    chunk_size_max: int
    """When sampling a random chunk size, the maximum chunk size that can be
    picked."""

    limited_left_context_prob: float
    """When sampling a random chunk size, the probability that the left context
    will be limited.
    In the other case, any chunk can attend to the full past context."""

    left_context_chunks_min: int
    """When sampling a random left context size, the minimum number of left
    context chunks that can be picked."""

    left_context_chunks_max: int
    """When sampling a random left context size, the maximum number of left
    context chunks that can be picked."""

    test_config: Optional[DynChunkTrainConfig] = None
    """The configuration that should be used for `Stage.TEST`.
    When `None`, evaluation is done with full context (i.e. non-streaming)."""

    valid_config: Optional[DynChunkTrainConfig] = None
    """The configuration that should be used for `Stage.VALID`.
    When `None`, evaluation is done with full context (i.e. non-streaming)."""

    def _sample_bool(self, prob: float) -> bool:
        """Samples a random boolean with a probability, in a way that depends on
        PyTorch's RNG seed.

        Arguments
        ---------
        prob : float
            Probability (0..1) to return True (False otherwise).

        Returns
        -------
        bool
            ``True`` with probability ``prob``."""
        # `torch.rand` samples uniformly from [0, 1), so `prob=1.0` always
        # yields True and `prob=0.0` always yields False.
        return torch.rand((1,)).item() < prob

    def __call__(
        self, stage: "sb.core.Stage"
    ) -> Optional[DynChunkTrainConfig]:
        """In training stage, samples a random DynChunkTrain configuration.
        During validation or testing, returns the relevant configuration.

        Arguments
        ---------
        stage : speechbrain.core.Stage
            Current stage of training or evaluation.
            In training mode, a random DynChunkTrainConfig will be sampled
            according to the specified probabilities and ranges.
            During evaluation, the relevant DynChunkTrainConfig attribute will
            be picked.

        Returns
        -------
        DynChunkTrainConfig or None
            The configuration to use. ``None`` means full context
            (non-streaming): either a training batch that was not sampled as
            chunkwise, or a ``test_config``/``valid_config`` left as ``None``.

        Raises
        ------
        AttributeError
            If ``stage`` is not one of TRAIN, VALID or TEST.
        """
        if stage == sb.core.Stage.TRAIN:
            # When training for streaming, for each batch, we have a
            # `chunkwise_prob` probability of sampling a chunk size between
            # `chunk_size_min` and `chunk_size_max` (both inclusive, hence the
            # `+ 1`, as the upper bound of `torch.randint` is exclusive);
            # otherwise we return `None` and output frames can attend to the
            # full past and future context.
            if not self._sample_bool(self.chunkwise_prob):
                return None

            chunk_size = torch.randint(
                self.chunk_size_min, self.chunk_size_max + 1, (1,),
            ).item()

            # Independently, limit the left context with probability
            # `limited_left_context_prob`; `None` means infinite left context.
            left_context_chunks = None
            if self._sample_bool(self.limited_left_context_prob):
                left_context_chunks = torch.randint(
                    self.left_context_chunks_min,
                    self.left_context_chunks_max + 1,
                    (1,),
                ).item()

            return DynChunkTrainConfig(chunk_size, left_context_chunks)
        elif stage == sb.core.Stage.TEST:
            return self.test_config
        elif stage == sb.core.Stage.VALID:
            return self.valid_config
        else:
            # NOTE: AttributeError is kept (rather than ValueError) for
            # backward compatibility with callers that may catch it.
            raise AttributeError(f"Unsupported stage found {stage}")