# Source code for speechbrain.utils.fetching

"""Downloads or otherwise fetches pretrained models

Authors:
 * Aku Rouhe 2021
 * Samuele Cornell 2021
 * Andreas Nautsch 2022, 2023
"""
import urllib.request
import urllib.error
import pathlib
import logging
from enum import Enum
import huggingface_hub
from collections import namedtuple
from requests.exceptions import HTTPError

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def _missing_ok_unlink(path):
    """Removes the file or symlink at ``path``; a missing path is not an error.

    Stand-in for ``path.unlink(missing_ok=True)``, whose ``missing_ok``
    keyword only exists since Python 3.8.

    Arguments
    ---------
    path : pathlib.Path
        Path object whose ``unlink()`` method is called.
    """
    try:
        path.unlink()
    except FileNotFoundError:
        # Already absent, which is exactly the state the caller wants.
        return


class FetchFrom(Enum):
    """Designator where to fetch models/audios from.

    Note: HuggingFace repository sources and local folder sources may be
    confused if their source type is undefined.
    """

    # Source is a directory on the local file system.
    LOCAL = 1
    # Source is a HuggingFace Hub repository ID.
    HUGGING_FACE = 2
    # Source is a web address.
    URI = 3
# For easier use: a lightweight (FetchFrom, path) pair that ``fetch`` unpacks
# to learn both where the source lives and how it should be fetched.
FetchSource = namedtuple("FetchSource", ["FetchFrom", "path"])
FetchSource.__doc__ = (
    """NamedTuple describing a source path and how to fetch it"""
)
# Hash on the path only; the fetch designator does not take part in the hash.
FetchSource.__hash__ = lambda self: hash(self.path)
# Give the tuple a str-like ``encode``: the bytes of "<path>_<FetchFrom>".
# Any positional/keyword arguments are forwarded to ``str.encode``.
FetchSource.encode = lambda self, *args, **kwargs: "_".join(
    (str(self.path), str(self.FetchFrom))
).encode(*args, **kwargs)
# FetchSource.__str__ = lambda self: str(self.path)
def fetch(
    filename,
    source,
    savedir="./pretrained_model_checkpoints",
    overwrite=False,
    save_filename=None,
    use_auth_token=False,
    revision=None,
    huggingface_cache_dir=None,
):
    """Ensures you have a local copy of the file, returns its path

    In case the source is an external location, downloads the file.  In case
    the source is already accessible on the filesystem, creates a symlink in
    the savedir. Thus, the side effects of this function always look similar:
    savedir/save_filename can be used to access the file. And save_filename
    defaults to the filename arg.

    Arguments
    ---------
    filename : str
        Name of the file including extensions.
    source : str or FetchSource
        Where to look for the file. This is interpreted in special ways:
        First, if the source is an existing directory path, a symlink to
        ``source/filename`` is created (this check is skipped when a
        FetchSource explicitly designates FetchFrom.HUGGING_FACE or
        FetchFrom.URI). Second, if the source begins with "http:" or
        "https:" (or a FetchSource designates FetchFrom.URI), it is
        interpreted as a web address and ``source/filename`` is downloaded.
        Otherwise, the source is interpreted as a Huggingface model hub ID,
        and the file is downloaded from there.
    savedir : str
        Path where to save downloads/symlinks. Created if it does not exist.
    overwrite : bool
        If True, always overwrite existing savedir/filename file and download
        or recreate the link. If False (as by default), if savedir/filename
        exists, assume it is correct and don't download/relink. Note that
        Huggingface local cache is always used - with overwrite=True we just
        relink from the local cache.
    save_filename : str
        The filename to use for saving this file. Defaults to filename if not
        given.
    use_auth_token : bool (default: False)
        If true Huggingface's auth_token will be used to load private models
        from the HuggingFace Hub, default is False because majority of models
        are public.
    revision : str
        The model revision corresponding to the HuggingFace Hub model
        revision. This is particularly useful if you wish to pin your code to
        a particular version of a model hosted at HuggingFace.
    huggingface_cache_dir: str
        Path to HuggingFace cache; if None -> "~/.cache/huggingface"
        (default: None)

    Returns
    -------
    pathlib.Path
        Path to file on local file system.

    Raises
    ------
    ValueError
        If the source was interpreted as a web address but could not be
        downloaded, or if the file is not found on the HuggingFace Hub.
    """
    if save_filename is None:
        save_filename = filename
    savedir = pathlib.Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # A FetchSource carries an explicit designator next to the actual path.
    fetch_from = None
    if isinstance(source, FetchSource):
        fetch_from, source = source
    sourcefile = f"{source}/{filename}"
    destination = savedir / save_filename

    if destination.exists() and not overwrite:
        MSG = f"Fetch {filename}: Using existing file/symlink in {str(destination)}."
        logger.info(MSG)
        return destination

    if pathlib.Path(source).is_dir() and fetch_from not in [
        FetchFrom.HUGGING_FACE,
        FetchFrom.URI,
    ]:
        # Interpret source as local directory path & create a link and return
        # it as destination
        sourcepath = pathlib.Path(sourcefile).absolute()
        _missing_ok_unlink(destination)
        destination.symlink_to(sourcepath)
        MSG = f"Destination {filename}: local file in {str(sourcepath)}."
        logger.info(MSG)
        return destination

    if (
        str(source).startswith(("http:", "https:"))
        or fetch_from is FetchFrom.URI
    ):
        # Interpret source as web address.
        MSG = (
            f"Fetch {filename}: Downloading from normal URL {str(sourcefile)} ."
        )
        logger.info(MSG)
        # A previous fetch may have left a symlink here (into a local source
        # directory or the Huggingface cache). Writing the download "through"
        # that link would overwrite the link's target, so drop the link first.
        if destination.is_symlink():
            _missing_ok_unlink(destination)
        # Download
        try:
            urllib.request.urlretrieve(sourcefile, destination)
        except urllib.error.URLError as err:
            raise ValueError(
                f"Interpreted {source} as web address, but could not download."
            ) from err
    else:
        # FetchFrom.HUGGING_FACE check is spared (no other option right now)
        # Interpret source as huggingface hub ID
        # Use huggingface hub's fancy cached download.
        MSG = f"Fetch {filename}: Delegating to Huggingface hub, source {str(source)}."
        logger.info(MSG)
        try:
            fetched_file = huggingface_hub.hf_hub_download(
                repo_id=source,
                filename=filename,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=huggingface_cache_dir,
            )
            logger.info(f"HF fetch: {fetched_file}")
        except HTTPError as e:
            if "404 Client Error" in str(e):
                raise ValueError("File not found on HF hub") from e
            else:
                raise

        # Huggingface hub downloads to etag filename, symlink to the expected one:
        sourcepath = pathlib.Path(fetched_file).absolute()
        _missing_ok_unlink(destination)
        destination.symlink_to(sourcepath)
    return destination