speechbrain.lobes.models.kmeans module
K-means implementation.
Authors
 * Luca Della Libera 2024
Summary
Classes:
MiniBatchKMeansSklearn: A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.
Reference
- class speechbrain.lobes.models.kmeans.MiniBatchKMeansSklearn(*args, **kwargs)
Bases: torch.nn.Module
A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.
See https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html.
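- Parameters:
*args (tuple) – Positional arguments passed to scikit-learn MiniBatchKMeans (e.g. n_clusters, as in the example below).
**kwargs (dict) – Keyword arguments passed to scikit-learn MiniBatchKMeans.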
Example
>>> import torch
>>> device = "cpu"
>>> n_clusters = 20
>>> batch_size = 8
>>> seq_length = 100
>>> hidden_size = 256
>>> model = MiniBatchKMeansSklearn(n_clusters).to(device)
>>> input = torch.randn(batch_size, seq_length, hidden_size, device=device)
>>> model.partial_fit(input)
>>> labels = model(input)
>>> labels.shape
torch.Size([8, 100])
>>> centers = model.cluster_centers
>>> centers.shape
torch.Size([20, 256])
>>> len(list(model.buffers()))
1
>>> model.n_steps
1
>>> inertia = model.inertia(input)
- save(path)
Saves the model to the specified file.
- Parameters:
path (str) – The path of the file to save the model to.
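This page does not specify the on-disk format that save writes, nor does it document a matching load method. As a minimal sketch, assuming the state to persist (for instance the fitted scikit-learn estimator) is picklable, a save/restore round trip could look like the following; save_object, load_object, and the kmeans.pkl file name are hypothetical and not part of SpeechBrain.

```python
import os
import pickle
import tempfile


# Hypothetical helpers: pickle is an assumed format, not necessarily what save() writes.
def save_object(obj, path):
    """Serialize obj to the file at path."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_object(path):
    """Deserialize and return the object stored at path."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip through a temporary directory that is deleted afterwards.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "kmeans.pkl")  # hypothetical file name
    save_object({"n_clusters": 20, "n_steps": 1}, path)
    restored = load_object(path)
```

If pickle is used, only load files from trusted sources, since pickle.load can execute arbitrary code.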
- fit(input)
Fits the model to the input data.
- Parameters:
input (torch.Tensor) – The input data tensor of shape (…, n_features).
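fit, partial_fit, forward, and inertia all take input of shape (…, n_features), while scikit-learn's MiniBatchKMeans expects a 2-D (n_samples, n_features) array. A reasonable assumption, consistent with the example above where an (8, 100, 256) input yields (8, 100) labels, is that every leading dimension is treated as a sample axis and collapsed into one. The sketch below shows that collapse on plain nested lists; with a tensor the equivalent would typically be input.reshape(-1, input.shape[-1]). The to_rows helper is hypothetical.

```python
def to_rows(x):
    """Flatten a nested list of shape (..., n_features) into a list of feature rows."""
    if x and not isinstance(x[0], list):
        # Innermost level: x is already a single feature vector.
        return [x]
    rows = []
    for item in x:
        # Recurse into each leading dimension and concatenate the resulting rows.
        rows.extend(to_rows(item))
    return rows


x = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # shape (2, 2, 2)
rows = to_rows(x)  # shape (4, 2)
```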
- partial_fit(input)
Performs an incremental fit of the model on the input data.
- Parameters:
input (torch.Tensor) – The input data tensor of shape (…, n_features).
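partial_fit updates the centers from a single mini-batch, which matches model.n_steps reading 1 after one call in the example above. The sketch below shows the per-sample update from Sculley's mini-batch k-means (2010): each center keeps a count, and every sample moves its nearest center toward itself with step size 1/count, so each center tracks the running mean of the samples assigned to it. This is a textbook version in pure Python; scikit-learn's MiniBatchKMeans uses its own variant of the update (plus its own initialization and reassignment logic), so its results will differ. nearest_center and minibatch_step are hypothetical names.

```python
from typing import List

Vector = List[float]


def nearest_center(x: Vector, centers: List[Vector]) -> int:
    """Index of the center closest to x under squared Euclidean distance."""
    return min(
        range(len(centers)),
        key=lambda k: sum((a - b) ** 2 for a, b in zip(x, centers[k])),
    )


def minibatch_step(batch: List[Vector], centers: List[Vector], counts: List[int]) -> List[int]:
    """One mini-batch k-means step in the style of Sculley (2010).

    Updates centers and counts in place and returns the batch's cluster labels.
    """
    # Assign every sample using the centers as they were before this step.
    labels = [nearest_center(x, centers) for x in batch]
    for x, k in zip(batch, labels):
        counts[k] += 1
        eta = 1.0 / counts[k]  # per-center step size shrinks as the center absorbs more samples
        # Move center k toward x; this keeps it equal to the mean of its assigned samples.
        centers[k] = [(1.0 - eta) * c + eta * xi for c, xi in zip(centers[k], x)]
    return labels


# Usage: two centers, one three-sample mini-batch.
centers = [[0.0, 0.0], [10.0, 10.0]]
counts = [0, 0]
labels = minibatch_step([[1.0, 1.0], [9.0, 9.0], [3.0, 3.0]], centers, counts)
```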
- forward(input)
Predicts cluster indices for the input data.
- Parameters:
input (torch.Tensor) – The input data tensor of shape (…, n_features).
- Returns:
Predicted cluster indices of shape (…,).
- Return type:
torch.Tensor
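forward assigns each feature vector to its closest cluster center, the same rule as scikit-learn's MiniBatchKMeans.predict, and drops only the last (feature) axis: in the example above, an (8, 100, 256) input gives (8, 100) labels. A minimal pure-Python sketch of that mapping follows, assuming squared Euclidean distance; predict here is a hypothetical helper, not the wrapper's code.

```python
from typing import List


def predict(x, centers: List[List[float]]):
    """Map a nested list of shape (..., n_features) to center indices of shape (...)."""
    if x and not isinstance(x[0], list):
        # Innermost level: x is one feature vector, so return its nearest center.
        return min(
            range(len(centers)),
            key=lambda k: sum((a - b) ** 2 for a, b in zip(x, centers[k])),
        )
    # Otherwise recurse, preserving every leading dimension.
    return [predict(item, centers) for item in x]


centers = [[0.0, 0.0], [10.0, 10.0]]
x = [[[1.0, 0.0], [9.0, 8.0]], [[6.0, 6.0], [4.0, 4.0]]]  # shape (2, 2, 2)
labels = predict(x, centers)  # shape (2, 2)
```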
- inertia(input)
Computes the inertia of the input data with respect to the current cluster centers.
- Parameters:
input (torch.Tensor) – The input data tensor of shape (…, n_features).
- Returns:
Inertia (sum of squared distances from each sample to its closest cluster center).
- Return type:
torch.Tensor
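Inertia is the k-means objective: for every sample, the squared Euclidean distance to its closest center, summed over all samples. The pure-Python sketch below computes it on rows already flattened to (n_samples, n_features); the inertia helper is hypothetical. For reference, scikit-learn's MiniBatchKMeans.score(X) returns the opposite of this value, so a wrapper could plausibly obtain it as -score(X), although this page does not say how the method is implemented.

```python
from typing import List


def inertia(rows: List[List[float]], centers: List[List[float]]) -> float:
    """Sum over samples of the squared distance to the closest center."""
    total = 0.0
    for x in rows:
        # The closest center is the one a nearest-center assignment would pick for x.
        total += min(sum((a - b) ** 2 for a, b in zip(x, c)) for c in centers)
    return total


value = inertia([[1.0, 1.0], [9.0, 9.0], [3.0, 3.0]], [[2.0, 2.0], [9.0, 9.0]])
```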
- property n_steps
Returns the number of minibatches processed.
- Returns:
Number of minibatches processed.
- Return type:
int
- property cluster_centers_
Returns the cluster centers.
- Returns:
Cluster centers of shape (n_clusters, n_features).
- Return type:
torch.Tensor