"""
Utilities for training kmeans model.
Author
* Pooneh Mousavi 2023
"""
import os
import logging
from tqdm.contrib import tqdm
try:
from sklearn.cluster import MiniBatchKMeans
except ImportError:
err_msg = "The optional dependency sklearn is needed to use this module\n"
err_msg += "Cannot import sklearn.cluster.MiniBatchKMeans to use KMeans/\n"
err_msg += "Please follow the instructions below\n"
err_msg += "=============================\n"
err_msg += "pip install -U scikit-learn\n"
raise ImportError(err_msg)
import joblib
logger = logging.getLogger(__name__)
[docs]
def accumulate_and_extract_features(
batch, features_list, ssl_model, ssl_layer_num, device
):
""" Extract features (output of SSL model) and acculamte them on cpu to be used for clustering.
Arguments
---------
batch: tensor
Single batch of data.
features_list : list
accumulate features list.
ssl_model
SSL-model used to extract features used for clustering.
ssl_layer_num: int
specify output of which layer of the ssl_model should be used.
device
CPU or GPU.
"""
batch = batch.to(device)
wavs, wav_lens = batch.sig
wavs, wav_lens = (
wavs.to(device),
wav_lens.to(device),
)
feats = ssl_model(wavs, wav_lens)[ssl_layer_num].flatten(end_dim=-2)
features_list.extend(feats.to("cpu").detach().numpy())
[docs]
def fetch_kmeans_model(
n_clusters,
init,
max_iter,
batch_size,
tol,
max_no_improvement,
n_init,
reassignment_ratio,
random_state,
checkpoint_path,
):
"""Return a k-means clustering model with specified parameters.
Arguments
---------
n_clusters : MiniBatchKMeans
The number of clusters to form as well as the number of centroids to generate.
init : int
Method for initialization: {'k-means++'', ''random''}
max_iter : int
Maximum number of iterations over the complete dataset before stopping independently of any early stopping criterion heuristics.
batch_size : int
Size of the mini batches.
tol : float
Control early stopping based on the relative center changes as measured by a smoothed, variance-normalized of the mean center squared position changes.
max_no_improvement :int
Control early stopping based on the consecutive number of mini batches that does not yield an improvement on the smoothed inertia.
n_init : int
Number of random initializations that are tried
reassignment_ratio : float
Control the fraction of the maximum number of counts for a center to be reassigned.
random_state :int
Determines random number generation for centroid initialization and random reassignment.
compute_labels : bool
Compute label assignment and inertia for the complete dataset once the minibatch optimization has converged in fit.
init_size : int
Number of samples to randomly sample for speeding up the initialization.
checkpoint_path : str
Path to saved model.
Returns
---------
MiniBatchKMeans
a k-means clustering model with specified parameters.
"""
if os.path.exists(checkpoint_path):
logger.info(f"The checkpoint is loaded from {checkpoint_path}.")
return joblib.load(checkpoint_path)
logger.info(
f"No checkpoint is found at {checkpoint_path}. New model is initialized for training."
)
return MiniBatchKMeans(
n_clusters=n_clusters,
init=init,
max_iter=max_iter,
batch_size=batch_size,
tol=tol,
max_no_improvement=max_no_improvement,
n_init=n_init,
reassignment_ratio=reassignment_ratio,
random_state=random_state,
verbose=1,
compute_labels=True,
init_size=None,
)
[docs]
def train(
model,
train_set,
ssl_model,
ssl_layer_num,
kmeans_batch_size=1000,
device="cpu",
):
"""Train a Kmeans model .
Arguments
---------
model : MiniBatchKMeans
The initial kmeans model for training.
train_set : Dataloader
Batches of tarining data.
ssl_model
SSL-model used to extract features used for clustering.
ssl_layer_num : int
Specify output of which layer of the ssl_model should be used.
device
CPU or GPU.
kmeans_batch_size : int
Size of the mini batches.
"""
logger.info("Start training kmeans model.")
features_list = []
with tqdm(train_set, dynamic_ncols=True,) as t:
for batch in t:
# train a kmeans model on a single batch if features_list reaches the kmeans_batch_size.
if len(features_list) >= kmeans_batch_size:
model = model.fit(features_list)
features_list = []
# extract features from the SSL model
accumulate_and_extract_features(
batch, features_list, ssl_model, ssl_layer_num, device
)
[docs]
def save_model(model, checkpoint_path):
"""Save a Kmeans model .
Arguments
---------
model : MiniBatchKMeans
The kmeans model to be saved.
checkpoint_path : str)
Path to save the model..
"""
joblib.dump(model, open(checkpoint_path, "wb"))