# Source code for speechbrain.utils.autocast

"""This module implements utilities and abstractions for use with
`torch.autocast`, i.e. Automatic Mixed Precision.

Authors
 * Sylvain de Langen 2023
 * Adel Moumen 2025
"""

import functools
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class AMPConfig:
    """Configuration for automatic mixed precision (AMP).

    Arguments
    ---------
    dtype : torch.dtype
        The dtype to use for AMP.
    """

    dtype: torch.dtype

    @classmethod
    def from_name(cls, name):
        """Create an AMPConfig from a string name.

        Arguments
        ---------
        name : str
            The name of the AMPConfig to create.  Must be one of `fp32`,
            `fp16`, or `bf16`.  `None` is treated like `fp32`.

        Returns
        -------
        AMPConfig
            The AMPConfig corresponding to the name.  When called on a
            subclass, an instance of that subclass is returned.

        Raises
        ------
        ValueError
            If `name` is not one of the supported names.
        """
        # Use `cls` rather than the hard-coded base class so that this
        # alternate constructor works correctly for subclasses.
        if name is None or name == "fp32":
            return cls(torch.float32)
        if name == "fp16":
            return cls(torch.float16)
        if name == "bf16":
            return cls(torch.bfloat16)
        raise ValueError(
            f"Specified autocast mode ({name}) incorrect, expected one of `fp32`, `fp16`, `bf16`."
        )
class TorchAutocast:
    """
    A context manager that conditionally enables ``torch.autocast``.

    This manager wraps ``torch.autocast`` and only activates it when the
    requested data type is something other than ``torch.float32``. If the
    requested data type is float32, or no dtype is given at all, autocasting
    is bypassed and the context manager behaves as a no-op. The device that
    autocasting applies to is whatever ``device_type`` is forwarded to
    ``torch.autocast``; this class performs no device check of its own.

    Parameters
    ----------
    *args : tuple
        Positional arguments forwarded to `torch.autocast`. A dtype passed
        positionally (the second argument, after ``device_type``) is taken
        into account exactly like ``dtype=``.
        See the PyTorch documentation:
        https://pytorch.org/docs/stable/amp.html#torch.autocast
    **kwargs : dict
        Keyword arguments forwarded to `torch.autocast`. Typically includes the
        `dtype` argument to specify the desired precision.
        See the PyTorch documentation for more details.
    """

    def __init__(self, *args, **kwargs):
        # `torch.autocast(device_type, dtype=None, ...)` also accepts dtype as
        # its second positional argument. Without this lookup,
        # `TorchAutocast("cuda", torch.float16)` would silently be a no-op.
        if "dtype" in kwargs:
            dtype = kwargs["dtype"]
        elif len(args) > 1:
            dtype = args[1]
        else:
            # No dtype requested anywhere: treat as float32 (autocast off).
            dtype = torch.float32

        if dtype != torch.float32:
            self.context = torch.autocast(*args, **kwargs)
        else:
            self.context = nullcontext()  # no-op context manager

    def __enter__(self):
        """
        Enter the autocast context.

        Returns
        -------
        context
            The result of entering the underlying autocast context manager.

        Raises
        ------
        RuntimeError
            If an error occurs while entering the autocast context and the
            context provides 'device' and 'fast_dtype' attributes, a
            RuntimeError is raised with additional diagnostic information.
        """
        try:
            return self.context.__enter__()
        except RuntimeError as e:
            # Only a real torch.autocast context exposes these attributes; a
            # nullcontext does not, in which case the error is re-raised as-is.
            if hasattr(self.context, "device") and hasattr(
                self.context, "fast_dtype"
            ):
                device = self.context.device
                dtype = self.context.fast_dtype
                raise RuntimeError(
                    f"Error during autocasting with dtype={dtype} on device={device}.\n"
                ) from e
            else:
                raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the autocast context.

        Parameters
        ----------
        exc_type : type
            Exception type if an exception occurred, otherwise None.
        exc_val : Exception
            Exception instance if an exception occurred, otherwise None.
        exc_tb : traceback
            Traceback object if an exception occurred, otherwise None.

        Returns
        -------
        bool or None
            The result of exiting the underlying autocast context manager.
        """
        return self.context.__exit__(exc_type, exc_val, exc_tb)
def _infer_device_type(*args, **kwargs): """Infer device type from the input tensors. This function returns the device type of the first tensor found in the arguments or keyword arguments. It assumes all tensors are on the same device, which is typically the case in PyTorch operations. Arguments --------- *args: tuple Arguments that may contain tensors **kwargs: dict Keyword arguments that may contain tensors Returns ------- str Device type ('cuda', 'cpu', 'mps', etc.) """ # Check args for tensors for arg in args: if isinstance(arg, torch.Tensor): return arg.device.type # Check kwargs for tensors for value in kwargs.values(): if isinstance(value, torch.Tensor): return value.device.type # Default to cpu if no tensors found return "cpu"
def fwd_default_precision(
    fwd: Optional[Callable] = None,
    cast_inputs: Optional[torch.dtype] = torch.float32,
):
    """Decorator for forward methods which, by default, *disables* autocast
    and casts any floating-point tensor parameters into the specified dtype
    (much like `torch.amp.custom_fwd`).

    The *wrapped forward* will gain an additional `force_allow_autocast` keyword
    parameter.
    When set to `True`, the function will ignore `cast_inputs` and will not
    disable autocast, as if this decorator was not specified.
    (Thus, modules can specify a default recommended precision, and users can
    override that behavior when desired.)

    The `device_type` given to `torch.amp.custom_fwd` is not fixed when the
    decorator is applied. It is inferred on every call from the first tensor
    found among the positional arguments, then the keyword arguments. Tensors
    nested inside lists/tuples/dicts are not inspected, and ``"cpu"`` is used
    when no tensor is found. One `custom_fwd` wrapper is built lazily per
    device type and reused afterwards.

    When autocast is *not* active, this decorator does not change any
    behavior.

    Arguments
    ---------
    fwd: Optional[Callable]
        The function to wrap. If omitted, returns a partial application of the
        decorator, e.g. allowing
        `new_decorator = fwd_default_precision(cast_inputs=torch.float32)`.

        Reminder: If you are decorating a function directly, this argument is
        already specified implicitly.

    cast_inputs: Optional[torch.dtype]
        If not `None` (the default being `torch.float32`), then any
        floating-point inputs to the wrapped function will be cast to the
        specified type.

        Note: When autocasting is enabled, output tensors of autocast-compatible
        operations may be of the autocast data type.
        Disabling autocast *without* casting inputs will not change this fact,
        so lower precision operations can happen even inside of an
        autocast-disabled region, which this argument helps avoid if desired.

    Returns
    -------
    The wrapped function, or, when `fwd` is omitted, a decorator that produces
    it.
    """
    if fwd is None:
        return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)

    # Cache of `torch.amp.custom_fwd` wrappers keyed by device type string.
    # They are built lazily because the device is only known once inputs
    # arrive.
    wrapped_cache = {}

    def get_wrapped_fwd(device_type):
        """Get or create the `custom_fwd` wrapper for the given device type."""
        if device_type not in wrapped_cache:
            # NOTE(review): torch.amp.custom_fwd is designed for
            # autograd.Function.forward and may record state on the first
            # positional argument (here typically the module `self`). Confirm
            # against the PyTorch version in use.
            wrapped_cache[device_type] = torch.amp.custom_fwd(
                fwd, device_type=device_type, cast_inputs=cast_inputs
            )
        return wrapped_cache[device_type]

    @functools.wraps(fwd)
    def wrapper(*args, force_allow_autocast: bool = False, **kwargs):
        """Wrapped forward function from fwd_default_precision.

        Arguments
        ---------
        *args: tuple
            Arguments to be forwarded to the unwrapped function.
        force_allow_autocast: bool
            When `True`, the wrapped function will be executed directly with no
            change to the autocast context and no input casting.
        **kwargs: dict
            Arguments to be forwarded to the unwrapped function.

        Returns
        -------
        Whatever the original `fwd` returns for these arguments. It is called
        directly when `force_allow_autocast` is `True`, otherwise through the
        `torch.amp.custom_fwd` wrapper for the inferred device type.
        """
        if force_allow_autocast:
            return fwd(*args, **kwargs)

        # Infer device type from input tensors
        device_type = _infer_device_type(*args, **kwargs)
        wrapped_fwd = get_wrapped_fwd(device_type)
        return wrapped_fwd(*args, **kwargs)

    return wrapper