speechbrain.utils.autocast module

This module implements utilities and abstractions for use with torch.autocast, i.e. Automatic Mixed Precision (AMP).

Authors
  • Sylvain de Langen 2023

  • Adel Moumen 2025

Summary

Classes:

AMPConfig

Configuration for automatic mixed precision (AMP).

TorchAutocast

A context manager that conditionally enables torch.autocast for GPU operations.

Functions:

fwd_default_precision

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).

Reference

class speechbrain.utils.autocast.AMPConfig(dtype: dtype)[source]

Bases: object

Configuration for automatic mixed precision (AMP).

Parameters:

dtype (torch.dtype) – The dtype to use for AMP.

dtype: dtype
classmethod from_name(name)[source]

Create an AMPConfig from a string name.

Parameters:

name (str) – The name of the AMPConfig to create. Must be one of fp32, fp16, or bf16.

Returns:

The AMPConfig corresponding to the name.

Return type:

AMPConfig
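
Example (a minimal sketch; the name-to-dtype mapping shown assumes the names listed above):

    import torch
    from speechbrain.utils.autocast import AMPConfig

    amp = AMPConfig.from_name("fp16")   # one of "fp32", "fp16", "bf16"
    print(amp.dtype)                    # expected: torch.float16

    # An AMPConfig can also be constructed directly from a dtype.
    amp_bf16 = AMPConfig(dtype=torch.bfloat16)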

class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)[source]

Bases: object

A context manager that conditionally enables torch.autocast for GPU operations.

This manager wraps around torch.autocast to automatically enable autocasting when running on a GPU and a data type other than float32 is specified. If the desired data type is float32, autocasting is bypassed and the context manager behaves as a no-op.

Parameters:
  • *args (tuple) – Positional arguments forwarded to torch.autocast. See the PyTorch documentation: https://pytorch.org/docs/stable/amp.html#torch.autocast

  • **kwargs (dict) – Keyword arguments forwarded to torch.autocast. Typically includes the dtype argument to specify the desired precision. See the PyTorch documentation for more details.
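
Example (a minimal, hypothetical sketch; the module and inputs are placeholders, and the arguments are forwarded to torch.autocast, so device_type and dtype follow the torch.autocast API):

    import torch
    from speechbrain.utils.autocast import TorchAutocast

    model = torch.nn.Linear(80, 40).to("cuda")    # placeholder module
    feats = torch.randn(8, 80, device="cuda")     # placeholder input

    # With a non-float32 dtype, autocasting is enabled inside the block;
    # with dtype=torch.float32 the context manager is a no-op.
    with TorchAutocast(device_type="cuda", dtype=torch.float16):
        out = model(feats)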

__enter__()[source]

Enter the autocast context.

Returns:

The result of entering the underlying autocast context manager.

Return type:

context

Raises:

RuntimeError – If an error occurs while entering the autocast context and the context provides 'device' and 'fast_dtype' attributes, a RuntimeError is raised with additional diagnostic information.

__exit__(exc_type, exc_val, exc_tb)[source]

Exit the autocast context.

Parameters:
  • exc_type (type) – Exception type if an exception occurred, otherwise None.

  • exc_val (Exception) – Exception instance if an exception occurred, otherwise None.

  • exc_tb (traceback) – Traceback object if an exception occurred, otherwise None.

Returns:

The result of exiting the underlying autocast context manager.

Return type:

bool or None

speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)[source]

Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor parameters into the specified dtype (much like torch.cuda.amp.custom_fwd).

The wrapped forward will gain an additional force_allow_autocast keyword parameter. When set to True, the function will ignore cast_inputs and will not disable autocast, as if this decorator was not specified. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)

Note that as of PyTorch 2.1.1, this will only affect CUDA AMP. Non-CUDA AMP will be unaffected and no input tensors will be cast! This use case may be supported by this function in the future.

When autocast is not active, this decorator does not change any behavior.

Parameters:
  • fwd (Optional[Callable]) –

    The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32).

    Reminder: If you are decorating a function directly, this argument is already specified implicitly.

  • cast_inputs (Optional[torch.dtype]) –

    If not None (the default being torch.float32), then any floating-point inputs to the wrapped function will be cast to the specified type.

    Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the autocast data type. Disabling autocast without casting inputs will not change this fact, so lower precision operations can happen even inside of an autocast-disabled region, which this argument helps avoid if desired.

Return type:

The wrapped function
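
Example (a minimal, hypothetical sketch; StableHead and its layer sizes are placeholders, not part of the library):

    import torch
    from speechbrain.utils.autocast import fwd_default_precision

    class StableHead(torch.nn.Module):  # hypothetical module
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(80, 40)

        @fwd_default_precision(cast_inputs=torch.float32)
        def forward(self, x):
            # Under CUDA autocast, x is cast to float32 and autocast is
            # disabled inside this method unless the caller overrides it.
            return self.proj(x)

    module = StableHead().to("cuda")
    x = torch.randn(8, 80, device="cuda")

    with torch.autocast("cuda", dtype=torch.float16):
        y = module(x)                               # forward runs in float32
        y16 = module(x, force_allow_autocast=True)  # opt back into autocast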