Source code for speechbrain.integrations.audio_tokenizers.kmeans

"""K-means implementation.

Authors
* Luca Della Libera 2024
"""

import joblib
import torch


class MiniBatchKMeansSklearn(torch.nn.Module):
    """A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.

    See https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html.

    The wrapped estimator is stored in `self.kmeans`. The current cluster
    centers are mirrored into the non-persistent buffer `cluster_centers`,
    which is refreshed after every `fit`, `partial_fit` and `load` call.

    Arguments
    ---------
    *args : tuple
        Positional arguments passed to scikit-learn `MiniBatchKMeans`.
    **kwargs : dict
        Keyword arguments passed to scikit-learn `MiniBatchKMeans`.

    Example
    -------
    >>> import torch
    >>> device = "cpu"
    >>> n_clusters = 20
    >>> batch_size = 8
    >>> seq_length = 100
    >>> hidden_size = 256
    >>> model = MiniBatchKMeansSklearn(n_clusters).to(device)
    >>> input = torch.randn(batch_size, seq_length, hidden_size, device=device)
    >>> model.partial_fit(input)
    >>> labels = model(input)
    >>> labels.shape
    torch.Size([8, 100])
    >>> centers = model.cluster_centers
    >>> centers.shape
    torch.Size([20, 256])
    >>> len(list(model.buffers()))
    1
    >>> model.n_steps
    1
    >>> inertia = model.inertia(input)
    """

    def __init__(self, *args, **kwargs):
        # scikit-learn is an optional dependency, so it is imported lazily.
        try:
            from sklearn.cluster import MiniBatchKMeans
        except ImportError as err:
            err_msg = "The optional dependency `scikit-learn` must be installed to use this module.\n"
            err_msg += "Install using `pip install scikit-learn`.\n"
            raise ImportError(err_msg) from err
        super().__init__()
        self.kmeans = MiniBatchKMeans(*args, **kwargs)
        self.device = torch.device("cpu")
        # Non-persistent: the buffer follows `.to(...)` but is excluded from
        # the state dict (the estimator itself is saved via `save`).
        self.register_buffer(
            "cluster_centers", self.cluster_centers_, persistent=False
        )

    def to(self, device=None, **kwargs):
        """See documentation of `torch.nn.Module.to`.

        Arguments
        ---------
        device : torch.device, str, torch.dtype, torch.Tensor or None
            First positional argument of `torch.nn.Module.to`.
        **kwargs : dict
            Additional keyword arguments forwarded to `torch.nn.Module.to`.

        Returns
        -------
        MiniBatchKMeansSklearn
            This module.
        """
        module = super().to(device, **kwargs)
        # The first argument may be `None`, a dtype or a tensor, so it cannot
        # be stored as the device directly. Read the device back from the
        # registered buffer, and only after the move has succeeded.
        self.device = self.cluster_centers.device
        return module

    def save(self, path):
        """Saves the model to the specified file.

        Arguments
        ---------
        path : str
            The file path to save the model.
        """
        joblib.dump(self.kmeans, path)

    def load(self, path, end_of_epoch):
        """Loads the model from the specified file.

        Arguments
        ---------
        path : str
            The file path from which to load the model.
        end_of_epoch : bool
            Indicates if this load is triggered at the end of an epoch.
            Not used by this method.
        """
        # NOTE: `joblib.load` is pickle-based and can execute arbitrary code;
        # only load files that come from a trusted source.
        self.kmeans = joblib.load(path)
        self.cluster_centers = self.cluster_centers_

    def fit(self, input):
        """Fits the model to the input data.

        Arguments
        ---------
        input : torch.Tensor
            The input data tensor of shape (..., n_features).
        """
        self.kmeans.fit(self._to_numpy(input))
        self.cluster_centers = self.cluster_centers_

    def partial_fit(self, input):
        """Performs an incremental fit of the model on the input data.

        Arguments
        ---------
        input : torch.Tensor
            The input data tensor of shape (..., n_features).
        """
        self.kmeans.partial_fit(self._to_numpy(input))
        self.cluster_centers = self.cluster_centers_

    def forward(self, input):
        """Predicts cluster indices for the input data.

        Arguments
        ---------
        input : torch.Tensor
            The input data tensor of shape (..., n_features).

        Returns
        -------
        torch.Tensor
            Predicted cluster indices of shape (...,).
        """
        cluster_idxes = self.kmeans.predict(self._to_numpy(input))
        cluster_idxes = torch.tensor(cluster_idxes, device=self.device).long()
        # Restore the leading dimensions that were flattened for scikit-learn.
        return cluster_idxes.reshape(input.shape[:-1])

    def inertia(self, input):
        """Returns the inertia of the clustering.

        Arguments
        ---------
        input : torch.Tensor
            The input data tensor of shape (..., n_features).

        Returns
        -------
        torch.Tensor
            Inertia (sum of squared distances to the cluster centers).
        """
        score = self.kmeans.score(self._to_numpy(input))
        # The estimator's score is negated to obtain the inertia.
        return -torch.tensor(score, device=self.device).float()

    @property
    def n_steps(self):
        """Returns the number of minibatches processed.

        Reads `n_steps_` from the wrapped estimator.

        Returns
        -------
        int
            Number of minibatches processed.
        """
        return self.kmeans.n_steps_

    @property
    def cluster_centers_(self):
        """Returns the cluster centers.

        Returns
        -------
        torch.Tensor
            Cluster centers of shape (n_clusters, n_features), or a scalar
            zero tensor if the wrapped estimator has no `cluster_centers_`
            attribute yet.
        """
        if hasattr(self.kmeans, "cluster_centers_"):
            return torch.tensor(
                self.kmeans.cluster_centers_, device=self.device
            ).float()
        # Placeholder so that the buffer can be registered before fitting.
        return torch.tensor(0.0, device=self.device)

    @staticmethod
    def _to_numpy(input):
        """Flattens all leading dimensions and converts the tensor to a CPU NumPy array."""
        return input.detach().flatten(end_dim=-2).cpu().numpy()