speechbrain.utils.autocast module
This module implements utilities and abstractions for use with torch.autocast, i.e. Automatic Mixed Precision (AMP).
Authors
- Sylvain de Langen 2023
Summary
Functions:
fwd_default_precision: Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor arguments into the specified dtype (much like torch.cuda.amp.custom_fwd).
Reference
- speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)
Decorator for forward methods which, by default, disables autocast and casts any floating-point tensor arguments into the specified dtype (much like torch.cuda.amp.custom_fwd).
The wrapped forward gains an additional force_allow_autocast keyword parameter. When it is set to True, the function ignores cast_inputs and does not disable autocast, as if this decorator were not specified. (Thus, modules can specify a default recommended precision, and users can override that behavior when desired.)
Note that as of PyTorch 2.1.1, this only affects CUDA AMP. Non-CUDA AMP (e.g. CPU autocast) is unaffected, and no input tensors are cast in that case. This use case may be supported by this function in the future.
When autocast is not active, this decorator does not change any behavior.
- Parameters:
fwd (Optional[Callable]) –
The function to wrap. If omitted, returns a partial application of the decorator, e.g. allowing new_decorator = fwd_default_precision(cast_inputs=torch.float32).
Reminder: If you are decorating a function directly, this argument is already specified implicitly.
cast_inputs (Optional[torch.dtype]) –
If not None (the default being torch.float32), then any floating-point inputs to the wrapped function are cast to the specified type.
Note: When autocasting is enabled, output tensors of autocast-compatible operations may be of the autocast data type. Disabling autocast without casting inputs does not change this fact, so lower-precision operations can still happen inside an autocast-disabled region; this argument helps avoid that if desired.
- Return type:
The wrapped function
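The note on cast_inputs can be observed directly with plain PyTorch. The snippet below uses CPU autocast with bfloat16 only so that it runs without a GPU (fwd_default_precision itself only acts on CUDA AMP); the dtype behavior it illustrates is the one the note describes. It does not call any SpeechBrain code.

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # mm is autocast-eligible, so its output is produced in bfloat16.
    y = torch.mm(a, b)
    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is off here, but y is still bfloat16,
        # so this mm still runs in bfloat16.
        z_uncast = torch.mm(y, y)
        # Casting the inputs first (what cast_inputs does) restores float32 compute.
        z_cast = torch.mm(y.float(), y.float())

print(y.dtype, z_uncast.dtype, z_cast.dtype)
# torch.bfloat16 torch.bfloat16 torch.float32
```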