speechbrain.utils.autocast module
This module implements utilities and abstractions for use with torch.autocast, i.e. Automatic Mixed Precision.
- Authors
Sylvain de Langen 2023
Adel Moumen 2025
Summary
Classes:
AMPConfig – Configuration for automatic mixed precision (AMP).
TorchAutocast – A context manager that conditionally enables torch.autocast for GPU operations.
Functions:
fwd_default_precision – Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).
Reference
- class speechbrain.utils.autocast.AMPConfig(dtype: dtype)[source]
Bases: object
Configuration for automatic mixed precision (AMP).
- Parameters:
dtype (torch.dtype) – The dtype to use for AMP.
- dtype: dtype
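Example (a minimal sketch of constructing the config; the choice of torch.bfloat16 is purely illustrative):

    import torch

    from speechbrain.utils.autocast import AMPConfig

    # Pick the low-precision dtype that AMP should use.
    amp_config = AMPConfig(dtype=torch.bfloat16)
    print(amp_config.dtype)  # torch.bfloat16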
- class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)[source]
Bases: object
A context manager that conditionally enables torch.autocast for GPU operations.
This manager wraps around torch.autocast to automatically enable autocasting when running on a GPU and a data type other than float32 is specified. If the desired data type is float32, autocasting is bypassed and the context manager behaves as a no-op.
- Parameters:
*args (tuple) – Positional arguments forwarded to torch.autocast. See the PyTorch documentation: https://pytorch.org/docs/stable/amp.html#torch.autocast
**kwargs (dict) – Keyword arguments forwarded to torch.autocast. Typically includes the dtype argument to specify the desired precision. See the PyTorch documentation for more details.
- __enter__()[source]
Enter the autocast context.
- Returns:
The result of entering the underlying autocast context manager.
- Return type:
context
- Raises:
RuntimeError – If an error occurs while entering the autocast context and the context provides "device" and "fast_dtype" attributes, a RuntimeError is raised with additional diagnostic information.
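Example (a minimal usage sketch; the float16 GPU path assumes a CUDA device is available, while the float32 path shows the documented no-op behavior):

    import torch

    from speechbrain.utils.autocast import TorchAutocast

    model = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)

    if torch.cuda.is_available():
        model, x = model.cuda(), x.cuda()
        # Arguments are forwarded to torch.autocast; float16 enables mixed precision.
        with TorchAutocast(device_type="cuda", dtype=torch.float16):
            y = model(x)
    else:
        # With dtype=torch.float32 the manager bypasses autocast entirely,
        # so this behaves as a plain forward pass.
        with TorchAutocast(device_type="cpu", dtype=torch.float32):
            y = model(x)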
- speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)[source]
Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).
The wrapped forward will gain an additional force_allow_autocast keyword parameter. When set to True, the function will ignore cast_inputs and will not disable autocast, as if this decorator was not specified. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)
Note that as of PyTorch 2.1.1, this will only affect CUDA AMP. Non-CUDA AMP will be unaffected and no input tensors will be cast! This use case may be supported by this function in the future.
When autocast is not active, this decorator does not change any behavior.
- Parameters:
fwd (Optional[Callable]) – The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32).
Reminder: If you are decorating a function directly, this argument is already specified implicitly.
cast_inputs (Optional[torch.dtype]) – If not None (the default being torch.float32), then any floating-point inputs to the wrapped function will be cast to the specified type.
Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the autocast data type. Disabling autocast without casting inputs will not change this fact, so lower precision operations can happen even inside of an autocast-disabled region, which this argument helps avoid if desired.
- Return type:
The wrapped function
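Example (a brief sketch of decorating a module's forward; MyModule is a hypothetical module, not part of SpeechBrain):

    import torch

    from speechbrain.utils.autocast import fwd_default_precision


    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        # By default, floating-point inputs are cast to float32 and autocast
        # is disabled inside forward (CUDA AMP only, per the note above).
        @fwd_default_precision(cast_inputs=torch.float32)
        def forward(self, x):
            return self.linear(x)


    module = MyModule()
    x = torch.randn(2, 8)

    # Outside of an autocast region, the decorator changes nothing.
    y = module(x)

    # Under CUDA autocast, callers can opt back into mixed precision:
    # with torch.autocast(device_type="cuda", dtype=torch.float16):
    #     y_default = module(x)                            # forward runs in float32
    #     y_amp = module(x, force_allow_autocast=True)     # autocast left enabled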