speechbrain.utils.autocast module

This module implements utilities and abstractions for use with torch.autocast, i.e. Automatic Mixed Precision.

Authors
  • Sylvain de Langen 2023

  • Adel Moumen 2025

Summary

Classes:

AMPConfig

Configuration for automatic mixed precision (AMP).

TorchAutocast

A context manager that conditionally enables torch.autocast for GPU operations.

Functions:

fwd_default_precision

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.amp.custom_fwd).

Reference

class speechbrain.utils.autocast.AMPConfig(dtype: dtype)

Bases: object

Configuration for automatic mixed precision (AMP).

Parameters:

dtype (torch.dtype) – The dtype to use for AMP.

dtype: dtype
classmethod from_name(name)

Create an AMPConfig from a string name.

Parameters:

name (str) – The name of the AMPConfig to create. Must be one of fp32, fp16, or bf16.

Returns:

The AMPConfig corresponding to the name.

Return type:

AMPConfig
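The name-to-dtype lookup that from_name performs can be sketched as follows. This is a minimal, hypothetical stand-in (AMPConfigSketch), not SpeechBrain's source. It assumes the three documented names map to torch.float32, torch.float16, and torch.bfloat16, and that an unknown name raises ValueError (the real error type may differ).

```python
from dataclasses import dataclass

import torch

# Assumed mapping for the three names documented for AMPConfig.from_name.
_AMP_DTYPES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


@dataclass
class AMPConfigSketch:
    """Hypothetical stand-in for AMPConfig: holds the dtype used for AMP."""

    dtype: torch.dtype

    @classmethod
    def from_name(cls, name: str) -> "AMPConfigSketch":
        # Rejecting unknown names with ValueError is an assumption of this sketch.
        if name not in _AMP_DTYPES:
            raise ValueError(
                f"Unknown AMP name {name!r}; expected one of {sorted(_AMP_DTYPES)}"
            )
        return cls(dtype=_AMP_DTYPES[name])


config = AMPConfigSketch.from_name("bf16")
print(config.dtype)  # torch.bfloat16
```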

class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)

Bases: object

A context manager that conditionally enables torch.autocast for GPU operations.

This manager wraps torch.autocast and enables autocasting when running on a GPU with a data type other than float32. If the desired data type is float32, autocasting is bypassed and the context manager behaves as a no-op.

Parameters:
  • *args (tuple) – Positional arguments forwarded to torch.autocast. See the PyTorch documentation: https://pytorch.org/docs/stable/amp.html#torch.autocast

  • **kwargs (dict) – Keyword arguments forwarded to torch.autocast. Typically includes the dtype argument to specify the desired precision. See the PyTorch documentation for more details.
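The dtype-based switch described above can be sketched as a small delegating context manager. ConditionalAutocastSketch is a hypothetical name, not the library class; it only inspects the requested dtype (it does not check for a GPU) and assumes that a missing dtype means float32. The example uses device_type="cpu" with bfloat16 only so that it runs without a GPU.

```python
import contextlib

import torch


class ConditionalAutocastSketch:
    """Hypothetical sketch: torch.autocast unless the requested dtype is float32."""

    def __init__(self, *args, **kwargs):
        # Assumption: a missing dtype is treated like float32 (autocast off).
        enabled = kwargs.get("dtype", torch.float32) != torch.float32
        # For float32, fall back to a context manager that does nothing.
        self.context = (
            torch.autocast(*args, **kwargs) if enabled else contextlib.nullcontext()
        )

    def __enter__(self):
        return self.context.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.context.__exit__(exc_type, exc_val, exc_tb)


a = torch.randn(8, 8)
b = torch.randn(8, 8)

# bfloat16 requested: the matmul runs under CPU autocast and returns bfloat16.
with ConditionalAutocastSketch(device_type="cpu", dtype=torch.bfloat16):
    low = torch.mm(a, b)

# float32 requested: the manager is a no-op, so the matmul stays float32.
with ConditionalAutocastSketch(device_type="cpu", dtype=torch.float32):
    full = torch.mm(a, b)

print(low.dtype, full.dtype)  # torch.bfloat16 torch.float32
```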

__enter__()

Enter the autocast context.

Returns:

The result of entering the underlying autocast context manager.

Return type:

context

Raises:

RuntimeError – If an error occurs while entering the autocast context and the context provides ‘device’ and ‘fast_dtype’ attributes, a RuntimeError is raised with additional diagnostic information.

__exit__(exc_type, exc_val, exc_tb)

Exit the autocast context.

Parameters:
  • exc_type (type) – Exception type if an exception occurred, otherwise None.

  • exc_val (Exception) – Exception instance if an exception occurred, otherwise None.

  • exc_tb (traceback) – Traceback object if an exception occurred, otherwise None.

Returns:

The result of exiting the underlying autocast context manager.

Return type:

bool or None
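Forwarding the underlying manager's return value matters because a truthy value from __exit__ would swallow an exception raised inside the with block. The check below drives the enter/exit protocol by hand on the two kinds of context this wrapper can hold, torch.autocast and contextlib.nullcontext; exit_suppresses is a hypothetical helper. Both return a falsy value, so exceptions propagate normally.

```python
import contextlib

import torch


def exit_suppresses(context) -> bool:
    """Run enter/exit by hand and report whether an exception would be suppressed."""
    context.__enter__()
    error = ValueError("raised inside the block")
    # A truthy return value from __exit__ would swallow the exception.
    return bool(context.__exit__(type(error), error, None))


contexts = (
    torch.autocast(device_type="cpu", dtype=torch.bfloat16),
    contextlib.nullcontext(),
)
suppressed = [exit_suppresses(ctx) for ctx in contexts]
print(suppressed)  # [False, False]
```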

speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.amp.custom_fwd).

The wrapped forward will gain an additional force_allow_autocast keyword parameter. When set to True, the function will ignore cast_inputs and will not disable autocast, as if this decorator were not specified. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)

This decorator now supports both CPU and CUDA by using torch.amp.custom_fwd with the device_type inferred from input tensors at runtime.

When autocast is not active, this decorator does not change any behavior.
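The behavior described above can be sketched by hand. fwd_default_precision_sketch and its helpers are hypothetical names, and this simplified version does not call torch.amp.custom_fwd; it infers the device type from the first tensor argument (falling back to "cpu"), checks whether autocast is active for that device, casts floating-point tensor inputs, and runs the wrapped function with autocast disabled. It assumes PyTorch 2.4 or newer for torch.is_autocast_enabled(device_type), with a fallback for older releases.

```python
import functools
from typing import Callable, Optional

import torch


def _autocast_active(device_type: str) -> bool:
    try:
        return torch.is_autocast_enabled(device_type)  # PyTorch >= 2.4
    except TypeError:
        # Older PyTorch: CPU has its own query; the no-argument form checks CUDA.
        if device_type == "cpu":
            return torch.is_autocast_cpu_enabled()
        return torch.is_autocast_enabled()


def _device_type(args, kwargs) -> str:
    # Infer the device type from the first tensor argument; default to "cpu".
    for value in (*args, *kwargs.values()):
        if isinstance(value, torch.Tensor):
            return value.device.type
    return "cpu"


def _cast_float(value, dtype):
    # Only floating-point tensors are cast; everything else passes through.
    if isinstance(value, torch.Tensor) and value.is_floating_point():
        return value.to(dtype)
    return value


def fwd_default_precision_sketch(
    fwd: Optional[Callable] = None,
    cast_inputs: Optional[torch.dtype] = torch.float32,
):
    if fwd is None:
        # Used as fwd_default_precision_sketch(cast_inputs=...): return a decorator.
        return functools.partial(fwd_default_precision_sketch, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def wrapper(*args, force_allow_autocast: bool = False, **kwargs):
        device_type = _device_type(args, kwargs)
        # Opt-out requested, or autocast inactive: behave as if undecorated.
        if force_allow_autocast or not _autocast_active(device_type):
            return fwd(*args, **kwargs)
        if cast_inputs is not None:
            args = tuple(_cast_float(a, cast_inputs) for a in args)
            kwargs = {k: _cast_float(v, cast_inputs) for k, v in kwargs.items()}
        # Run the wrapped forward with autocast disabled for this device type.
        with torch.autocast(device_type=device_type, enabled=False):
            return fwd(*args, **kwargs)

    return wrapper


@fwd_default_precision_sketch
def matmul(a, b):
    return torch.mm(a, b)


# Partial application: build a reusable decorator with a chosen cast dtype.
keep_fp32 = fwd_default_precision_sketch(cast_inputs=torch.float32)


@keep_fp32
def matmul_fp32(a, b):
    return torch.mm(a, b)


x = torch.randn(4, 4)
y = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    default = matmul(x, y)  # autocast disabled inside: float32
    forced = matmul(x, y, force_allow_autocast=True)  # autocast applies: bfloat16
    cast = matmul_fp32(x.bfloat16(), y.bfloat16())  # inputs cast to float32

outside = matmul(x.bfloat16(), y.bfloat16())  # autocast inactive: unchanged
print(default.dtype, forced.dtype, cast.dtype, outside.dtype)
# torch.float32 torch.bfloat16 torch.float32 torch.bfloat16
```

Falling back to "cpu" when no tensor argument is present is a choice made for this sketch; a function without tensor inputs has nothing to cast anyway.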

Parameters:
  • fwd (Optional[Callable]) –

    The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32).

    Reminder: If you are decorating a function directly, this argument is already specified implicitly.

  • cast_inputs (Optional[torch.dtype]) –

    If not None (the default being torch.float32), then any floating-point inputs to the wrapped function will be cast to the specified type.

    Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the autocast data type. Disabling autocast without casting inputs will not change this fact, so lower precision operations can happen even inside of an autocast-disabled region, which this argument helps avoid if desired.
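The effect described in this note can be reproduced with plain PyTorch (CPU and bfloat16 are used only so that it runs without a GPU). An autocast-eligible matmul produces a bfloat16 tensor; feeding that tensor into an autocast-disabled region still computes in bfloat16 unless it is cast back explicitly, which is the job cast_inputs performs.

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Autocast-eligible op: the output is in the autocast dtype.
    hidden = torch.mm(a, b)

    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is off, but the input is already bfloat16, so the op runs in bfloat16.
        uncast = torch.mm(hidden, hidden)
        # Casting the input first restores float32 computation.
        recast = torch.mm(hidden.float(), hidden.float())

print(hidden.dtype, uncast.dtype, recast.dtype)
# torch.bfloat16 torch.bfloat16 torch.float32
```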

Returns:

The wrapped function.