Source code for speechbrain.utils.logger

"""Managing the logger, utilities

Author
 * Fang-Pen Lin 2012 https://fangpenlin.com/posts/2012/08/26/good-logging-practice-in-python/
 * Peter Plantinga 2020
 * Aku Rouhe 2020
"""

import functools
import logging
import logging.config
import math
import os
import sys

import torch
import tqdm
import yaml

from speechbrain.utils.data_utils import recursive_update
from speechbrain.utils.distributed import if_main_process
from speechbrain.utils.superpowers import run_shell

ORDERS_ABBREV = {
    -24: "y",
    -21: "z",
    -18: "a",
    -15: "f",
    -12: "p",
    -9: "n",
    -6: "µ",
    -3: "m",
    0: "",
    3: "k",
    6: "M",
    9: "G",
    12: "T",
    15: "P",
    18: "E",
    21: "Z",
    24: "Y",
}

# Short scale
# Negative powers of ten in lowercase, positive in uppercase
ORDERS_WORDS = {
    -24: "septillionths",
    -21: "sextillionths",
    -18: "quintillionths",
    -15: "quadrillionths",
    -12: "trillionths",
    -9: "billionths",
    -6: "millionths",
    -3: "thousandths",
    0: "",
    3: "Thousand",
    6: "Million",
    9: "Billion",
    12: "Trillion",
    15: "Quadrillion",
    18: "Quintillion",
    21: "Sextillion",
    24: "Septillion",
}


class MultiProcessLoggerAdapter(logging.LoggerAdapter):
    r"""
    Logger adapter that handles multi-process logging, ensuring logs are
    written only on the main process if specified.

    This class extends `logging.LoggerAdapter` and provides additional
    functionality for controlling logging in multi-process environments, with
    the option to limit logs to the main process only.

    This class is heavily inspired by the HuggingFace Accelerate toolkit:
    https://github.com/huggingface/accelerate/blob/85b1a03552cf8d58e036634e004220c189bfb247/src/accelerate/logging.py#L22
    """

    @staticmethod
    def _should_log(main_process_only: bool) -> bool:
        r"""
        Determines if logging should occur based on whether the code is
        running on the main process or not.

        Arguments
        ---------
        main_process_only : bool
            A flag indicating if logging should be restricted to the main process.

        Returns
        -------
        bool
            True if logging should be performed (based on the process and the
            flag), False otherwise.
        """
        return not main_process_only or (
            main_process_only and if_main_process()
        )

    def log(self, level: int, msg: str, *args: tuple, **kwargs: dict):
        r"""
        Logs a message with the specified log level, respecting the
        `main_process_only` flag to decide whether to log based on the
        current process.

        Arguments
        ---------
        level : int
            Logging level (e.g., logging.INFO, logging.WARNING).
        msg : str
            The message to log.
        *args : tuple
            Additional positional arguments passed to the logger.
        **kwargs : dict
            Additional keyword arguments passed to the logger, including:
            - main_process_only (bool): If True, log only from the main
              process (default: True).
            - stacklevel (int): The stack level to use when logging
              (default: 2).

        Notes
        -----
        If `main_process_only` is True, the log will only be written if the
        current process is the main process, as determined by
        `if_main_process()`.
        """
        main_process_only = kwargs.pop("main_process_only", True)
        kwargs.setdefault("stacklevel", 2)
        if self.isEnabledFor(level):
            if self._should_log(main_process_only):
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)

    @functools.lru_cache(None)
    def warning_once(self, *args: tuple, **kwargs: dict):
        r"""
        Logs a warning message only once by using caching to prevent
        duplicate warnings.

        Arguments
        ---------
        *args : tuple
            Positional arguments passed to the warning log.
        **kwargs : dict
            Keyword arguments passed to the warning log.

        Notes
        -----
        This method is decorated with `functools.lru_cache(None)`, ensuring
        that the warning message is logged only once regardless of how many
        times the method is called.
        """
        self.warning(*args, **kwargs)
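

# Illustrative sketch (not part of the original module): how the adapter's
# ``main_process_only`` flag and ``warning_once`` are typically used in a
# multi-process (e.g. DDP) run. The messages are made up; ``get_logger`` is
# defined just below.
def _example_multiprocess_logging():  # pragma: no cover - documentation sketch
    logger = get_logger(__name__)
    # Written only by the main process (the default behaviour):
    logger.info("Starting epoch 1")
    # Written by every process, e.g. for rank-specific diagnostics:
    logger.debug("Local batch size computed", main_process_only=False)
    # Deduplicated: repeated calls with identical arguments log only once:
    for _ in range(10):
        logger.warning_once("Deprecated option is ignored")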


def get_logger(name: str) -> MultiProcessLoggerAdapter:
    """
    Retrieves a logger with the specified name, applying the log level from
    the environment variable `SB_LOG_LEVEL` if set, or defaulting to the
    `INFO` level.

    If the environment variable `SB_LOG_LEVEL` is not defined, the level
    defaults to `INFO` and is stored in the environment for future use. The
    environment variable can be set manually or automatically in the `Brain`
    class following `setup_logging`.

    Arguments
    ---------
    name : str
        The name of the logger to retrieve.

    Returns
    -------
    MultiProcessLoggerAdapter
        An instance of `MultiProcessLoggerAdapter` wrapping the logger with
        the specified name.
    """
    logger = logging.getLogger(name)
    log_level = os.environ.get("SB_LOG_LEVEL", None)
    if log_level is None:
        log_level = "INFO"
        os.environ["SB_LOG_LEVEL"] = log_level
    logger.setLevel(log_level.upper())
    return MultiProcessLoggerAdapter(logger, {})
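

# Illustrative sketch (not part of the original module): the module-level
# logger pattern, plus lowering the verbosity through the SB_LOG_LEVEL
# environment variable. The logger name is hypothetical.
def _example_get_logger():  # pragma: no cover - documentation sketch
    os.environ["SB_LOG_LEVEL"] = "warning"  # case-insensitive thanks to .upper()
    logger = get_logger("my_recipe.train")
    logger.info("hidden, below the WARNING threshold")
    logger.warning("visible")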


def setup_logging(
    config_path="log-config.yaml",
    overrides={},
    default_level="DEBUG",
):
    """Sets up the logging configuration.

    Arguments
    ---------
    config_path : str
        The path to a logging config file.
    overrides : dict
        A dictionary of the same structure as the config dict with any
        updated values that need to be applied.
    default_level : str
        The log level to use if the config file is not found.
        Python logging allows ints or strings:
        https://docs.python.org/3/library/logging.html#logging.Logger.setLevel
        but strings are used here because environment variables have to be
        strings. The available levels are listed here:
        https://docs.python.org/3/library/logging.html#levels
    """
    if os.path.exists(config_path):
        with open(config_path, "rt", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        recursive_update(config, overrides)
        logging.config.dictConfig(config)
    else:
        logging.basicConfig(level=default_level)
    os.environ["SB_LOG_LEVEL"] = default_level
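

# Illustrative sketch (not part of the original module): the shape of a
# minimal config file consumed by setup_logging(), and an override applied
# on top of it. The handler/formatter names are assumptions, not the
# library's shipped defaults.
def _example_setup_logging():  # pragma: no cover - documentation sketch
    # A minimal ``log-config.yaml`` in the logging.config.dictConfig schema:
    #
    #   version: 1
    #   formatters:
    #     simple:
    #       format: "%(asctime)s %(levelname)s %(name)s - %(message)s"
    #   handlers:
    #     console:
    #       class: speechbrain.utils.logger.TqdmCompatibleStreamHandler
    #       formatter: simple
    #       stream: ext://sys.stdout
    #   root:
    #     level: INFO
    #     handlers: [console]
    #
    # Raise the root level at runtime without editing the YAML file itself:
    setup_logging(
        config_path="log-config.yaml",
        overrides={"root": {"level": "WARNING"}},
    )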


class TqdmCompatibleStreamHandler(logging.StreamHandler):
    """TQDM compatible StreamHandler.

    Writes and prints should be passed through tqdm.tqdm.write
    so that the tqdm progressbar doesn't get messed up.
    """

    def emit(self, record):
        """Emits the record through tqdm.tqdm.write so the progress bar stays intact."""
        try:
            msg = self.format(record)
            stream = self.stream
            tqdm.tqdm.write(msg, end=self.terminator, file=stream)
            self.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
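

# Illustrative sketch (not part of the original module): attaching the
# tqdm-aware handler programmatically so log lines do not tear a running
# progress bar. The logger name, loop, and message are made up.
def _example_tqdm_handler():  # pragma: no cover - documentation sketch
    logger = logging.getLogger("my_recipe")
    logger.addHandler(TqdmCompatibleStreamHandler(stream=sys.stdout))
    logger.setLevel(logging.INFO)
    for step in tqdm.tqdm(range(100)):
        if step % 25 == 0:
            logger.info("checkpoint reached")  # printed above the bar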


def format_order_of_magnitude(number, abbreviate=True):
    """Formats number to the appropriate order of magnitude for printing.

    Arguments
    ---------
    number : int, float
        The number to format.
    abbreviate : bool
        Whether to use abbreviations (k, M, G) or words
        (Thousand, Million, Billion). Numbers will be either like:
        "123.5k" or "123.5 Thousand".

    Returns
    -------
    str
        The formatted number. Note that the order of magnitude token is part
        of the string.

    Example
    -------
    >>> print(format_order_of_magnitude(123456))
    123.5k
    >>> print(format_order_of_magnitude(0.00000123, abbreviate=False))
    1.2 millionths
    >>> print(format_order_of_magnitude(5, abbreviate=False))
    5
    """
    style = ORDERS_ABBREV if abbreviate else ORDERS_WORDS
    precision = "{num:3.1f}"
    order = 3 * math.floor(math.log(math.fabs(number), 1000))
    # Fallback for orders outside the lookup tables (very large or very small numbers):
    while order not in style and order != 0:
        order = order - math.copysign(3, order)  # Bring 3 units towards 0
    order_token = style[order]
    if order != 0:
        formatted_number = precision.format(num=number / 10**order)
    else:
        if isinstance(number, int):
            formatted_number = str(number)
        else:
            formatted_number = precision.format(num=number)
    if abbreviate or not order_token:
        return formatted_number + order_token
    else:
        return formatted_number + " " + order_token
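

# Illustrative sketch (not part of the original module): the fallback loop
# above caps the order at the largest/smallest entry of the lookup tables,
# so a number beyond 10**24 is still expressed with the "Y" prefix.
def _example_order_fallback():  # pragma: no cover - documentation sketch
    assert format_order_of_magnitude(3.2e27) == "3200.0Y"
    assert format_order_of_magnitude(123456) == "123.5k"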


def get_environment_description():
    """Returns a string describing the current Python / SpeechBrain environment.

    Useful for making experiments as replicable as possible.

    Returns
    -------
    str
        The string is formatted ready to be written to a file.

    Example
    -------
    >>> get_environment_description().splitlines()[0]
    'SpeechBrain system description'
    """
    python_version_str = "Python version:\n" + sys.version + "\n"
    try:
        freezed, _, _ = run_shell("pip freeze")
        python_packages_str = "Installed Python packages:\n"
        python_packages_str += freezed.decode(errors="replace")
    except OSError:
        python_packages_str = "Could not list python packages with pip freeze"
    try:
        git_hash, _, _ = run_shell("git rev-parse --short HEAD")
        git_str = "Git revision:\n" + git_hash.decode(errors="replace")
    except OSError:
        git_str = "Could not get git revision"
    if torch.cuda.is_available():
        if torch.version.cuda is None:
            cuda_str = "ROCm version:\n" + torch.version.hip
        else:
            cuda_str = "CUDA version:\n" + torch.version.cuda
    else:
        cuda_str = "CUDA not available"
    result = "SpeechBrain system description\n"
    result += "==============================\n"
    result += python_version_str
    result += "==============================\n"
    result += python_packages_str
    result += "==============================\n"
    result += git_str
    result += "==============================\n"
    result += cuda_str
    return result
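

# Illustrative sketch (not part of the original module): persisting the
# environment description next to an experiment's outputs so a run can be
# reproduced later. The output path is hypothetical.
def _example_log_environment():  # pragma: no cover - documentation sketch
    with open("results/exp1/env.log", "w", encoding="utf-8") as fout:
        fout.write(get_environment_description())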