speechbrain.utils.autocast module
This module implements utilities and abstractions for use with
torch.autocast, i.e. Automatic Mixed Precision.
- Authors
Sylvain de Langen 2023
Adel Moumen 2025
Summary
Classes:
AMPConfig | Configuration for automatic mixed precision (AMP).
TorchAutocast | A context manager that conditionally enables torch.autocast for GPU operations.
Functions:
fwd_default_precision | Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.amp.custom_fwd).
Reference
- class speechbrain.utils.autocast.AMPConfig(dtype: torch.dtype)
Bases: object
Configuration for automatic mixed precision (AMP).
- Parameters:
dtype (torch.dtype): The dtype to use for AMP.
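As documented, AMPConfig simply carries the dtype that mixed-precision code should use. The sketch below does not import SpeechBrain; it uses a hypothetical stand-in dataclass, AMPConfigSketch, that holds a single dtype field like the documented signature, and hands that dtype to torch.autocast. Running on CPU with bfloat16 is an assumption chosen so the example works without a GPU.

```python
from dataclasses import dataclass

import torch


@dataclass
class AMPConfigSketch:
    """Hypothetical stand-in mirroring AMPConfig(dtype): holds one torch.dtype."""

    dtype: torch.dtype


# Assumption: bfloat16 on CPU, which CPU autocast supports.
cfg = AMPConfigSketch(dtype=torch.bfloat16)

a = torch.randn(4, 4)
b = torch.randn(4, 4)

# The configured dtype is what gets passed on to torch.autocast.
with torch.autocast(device_type="cpu", dtype=cfg.dtype):
    c = a @ b  # matmul is autocast-eligible, so it runs in cfg.dtype

print(c.dtype)
```

Keeping the precision choice in a small config object lets training code pass one value around instead of repeating dtype arguments at every autocast call site.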
- class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)
Bases: object
A context manager that conditionally enables torch.autocast for GPU operations.
This manager wraps torch.autocast and enables autocasting when running on a GPU with a data type other than float32. If the requested data type is float32, autocasting is bypassed and the context manager behaves as a no-op.
- Parameters:
*args (tuple): Positional arguments forwarded to torch.autocast. See the PyTorch documentation: https://pytorch.org/docs/stable/amp.html#torch.autocast
**kwargs (dict): Keyword arguments forwarded to torch.autocast. Typically includes the dtype argument to specify the desired precision. See the PyTorch documentation for more details.
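The dtype check described above can be sketched as a thin wrapper around torch.autocast. ConditionalAutocastSketch is a hypothetical stand-in, not SpeechBrain's implementation. It assumes that an omitted dtype keyword means float32, and it uses the CPU device with bfloat16 only so the example runs without a GPU.

```python
from contextlib import nullcontext

import torch


class ConditionalAutocastSketch:
    """Hypothetical stand-in: enable torch.autocast only for non-float32 dtypes."""

    def __init__(self, *args, **kwargs):
        # Assumption: an omitted dtype is treated as float32 (no autocast).
        enabled = kwargs.get("dtype", torch.float32) != torch.float32
        # float32 requests get a no-op context instead of torch.autocast.
        self.context = torch.autocast(*args, **kwargs) if enabled else nullcontext()

    def __enter__(self):
        return self.context.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.context.__exit__(exc_type, exc_val, exc_tb)


a = torch.randn(4, 4)
b = torch.randn(4, 4)

with ConditionalAutocastSketch("cpu", dtype=torch.float32):
    full = a @ b  # bypassed: ordinary float32 matmul

with ConditionalAutocastSketch("cpu", dtype=torch.bfloat16):
    low = a @ b  # autocast active: matmul runs in bfloat16
```

Delegating to nullcontext keeps call sites identical regardless of the configured precision, so a single `with` statement works for both the full-precision and mixed-precision paths.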
- __enter__()
Enter the autocast context.
- Returns:
The result of entering the underlying autocast context manager.
- Return type:
context
- Raises:
RuntimeError: If an error occurs while entering the autocast context and the context provides "device" and "fast_dtype" attributes, a RuntimeError is raised with additional diagnostic information.
- speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: torch.dtype | None = torch.float32)
Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.amp.custom_fwd).
The wrapped forward gains an additional force_allow_autocast keyword parameter. When set to True, the function ignores cast_inputs and does not disable autocast, as if this decorator were not specified. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)
This decorator supports both CPU and CUDA by using torch.amp.custom_fwd with the device_type inferred from the input tensors at runtime.
When autocast is not active, this decorator does not change any behavior.
- Parameters:
fwd (Optional[Callable]):
The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32). Reminder: if you are decorating a function directly, this argument is already specified implicitly.
cast_inputs (Optional[torch.dtype]):
If not None (the default is torch.float32), any floating-point inputs to the wrapped function are cast to the specified type.
Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the autocast data type. Disabling autocast without casting inputs does not change this, so lower-precision operations can still happen inside an autocast-disabled region; this argument helps avoid that if desired.
- Return type:
The wrapped function
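The note on cast_inputs can be checked with plain PyTorch: a tensor produced under autocast keeps its lower-precision dtype, so merely disabling autocast does not bring later operations back to float32, while an explicit cast does. This example does not call SpeechBrain and uses CPU bfloat16 only so it runs without a GPU.

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Autocast-compatible op: the result comes out in bfloat16.
    h = a @ b

    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is off, but h is already bfloat16, so this matmul is too.
        still_low = h @ h
        # Casting the input first (as cast_inputs does) restores float32 math.
        restored = h.float() @ h.float()

print(h.dtype, still_low.dtype, restored.dtype)
```

This is the situation that casting floating-point inputs before running the wrapped forward is meant to prevent.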