speechbrain.lobes.models.kmeans module

K-means implementation.

Authors: Luca Della Libera (2024)

Summary

Classes:

MiniBatchKMeansSklearn

A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.

Reference

class speechbrain.lobes.models.kmeans.MiniBatchKMeansSklearn(*args, **kwargs)[source]

Bases: Module

A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.

See https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html.

Parameters:
  • *args (tuple) – Positional arguments passed to scikit-learn MiniBatchKMeans.

  • **kwargs (dict) – Keyword arguments passed to scikit-learn MiniBatchKMeans.

Example

>>> import torch
>>> from speechbrain.lobes.models.kmeans import MiniBatchKMeansSklearn
>>> device = "cpu"
>>> n_clusters = 20
>>> batch_size = 8
>>> seq_length = 100
>>> hidden_size = 256
>>> model = MiniBatchKMeansSklearn(n_clusters).to(device)
>>> input = torch.randn(batch_size, seq_length, hidden_size, device=device)
>>> model.partial_fit(input)
>>> labels = model(input)
>>> labels.shape
torch.Size([8, 100])
>>> centers = model.cluster_centers
>>> centers.shape
torch.Size([20, 256])
>>> len(list(model.buffers()))
1
>>> model.n_steps
1
>>> inertia = model.inertia(input)

to(device=None, **kwargs)[source]

See the documentation of torch.nn.Module.to.

save(path)[source]

Saves the model to the specified file.

Parameters:

path (str) – The file path to save the model.

load(path, end_of_epoch)[source]

Loads the model from the specified file.

Parameters:
  • path (str) – The file path from which to load the model.

  • end_of_epoch (bool) – Indicates if this load is triggered at the end of an epoch.

fit(input)[source]

Fits the model to the input data.

Parameters:

input (torch.Tensor) – The input data tensor of shape (…, n_features).

partial_fit(input)[source]

Performs an incremental fit of the model on the input data.

Parameters:

input (torch.Tensor) – The input data tensor of shape (…, n_features).
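
The incremental behaviour can be illustrated with a simplified, pure-Python sketch of a mini-batch k-means update. This is neither the scikit-learn implementation nor the wrapper's code: the helper `minibatch_update` and its per-center count bookkeeping are hypothetical, chosen only to show how the centers can be refined one batch at a time.

```python
def minibatch_update(centers, counts, batch):
    """Update `centers` and `counts` in place from one minibatch of points.

    centers: list of lists (n_clusters x n_features)
    counts:  list of ints, number of points assigned to each center so far
    batch:   list of lists (n_points x n_features)
    """
    for point in batch:
        # Assign the point to its nearest center (squared Euclidean distance).
        dists = [
            sum((p - c) ** 2 for p, c in zip(point, center))
            for center in centers
        ]
        k = dists.index(min(dists))
        # Move that center toward the point with a per-center step size 1 / count.
        counts[k] += 1
        lr = 1.0 / counts[k]
        centers[k] = [c + lr * (p - c) for p, c in zip(point, centers[k])]
    return centers, counts


centers = [[0.0, 0.0], [10.0, 10.0]]
counts = [0, 0]
# Two minibatches processed one after the other, much like repeated partial_fit calls.
minibatch_update(centers, counts, [[1.0, 1.0], [9.0, 9.0]])
minibatch_update(centers, counts, [[3.0, 3.0], [11.0, 11.0]])
```

Because each call only needs the current batch plus the running state, the full dataset never has to be held in memory at once.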

forward(input)[source]

Predicts cluster indices for the input data.

Parameters:

input (torch.Tensor) – The input data tensor of shape (…, n_features).

Returns:

Predicted cluster indices of shape (…,).

Return type:

torch.Tensor
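
The shape contract can be sketched in plain Python: every vector along the last dimension is mapped to the index of its nearest center, so the last dimension disappears and the leading ones are preserved. The helpers `nearest_center` and `predict_nested` and the toy centers are hypothetical; the wrapper itself operates on tensors and delegates prediction to scikit-learn.

```python
def nearest_center(vector, centers):
    """Return the index of the center closest to `vector` (squared Euclidean distance)."""
    dists = [sum((v - c) ** 2 for v, c in zip(vector, center)) for center in centers]
    return dists.index(min(dists))


def predict_nested(data, centers):
    """Label nested lists of shape (..., n_features), returning nested lists of shape (...)."""
    # A feature vector is a flat list of numbers; anything else is an outer dimension.
    if not isinstance(data[0], list):
        return nearest_center(data, centers)
    return [predict_nested(item, centers) for item in data]


centers = [[0.0, 0.0], [5.0, 5.0]]
# Shape (batch=2, seq=3, n_features=2).
batch = [
    [[0.1, -0.2], [4.9, 5.3], [0.4, 0.0]],
    [[6.0, 5.0], [5.5, 4.5], [-1.0, 1.0]],
]
labels = predict_nested(batch, centers)  # shape (2, 3)
```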

inertia(input)[source]

Returns the inertia of the clustering on the input data.

Parameters:

input (torch.Tensor) – The input data tensor of shape (…, n_features).

Returns:

Inertia (sum of squared distances from each input vector to its closest cluster center).

Return type:

torch.Tensor
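
As a worked illustration of the quantity described above, the following plain-Python sketch computes inertia as the sum, over all points, of the squared distance to the nearest center. The function name `inertia_of` and the toy data are hypothetical and are not part of the wrapper.

```python
def inertia_of(points, centers):
    """Sum of squared Euclidean distances from each point to its nearest center."""
    total = 0.0
    for point in points:
        total += min(
            sum((p - c) ** 2 for p, c in zip(point, center))
            for center in centers
        )
    return total


centers = [[0.0, 0.0], [4.0, 4.0]]
points = [[1.0, 0.0], [0.0, 2.0], [4.0, 3.0]]
# Nearest-center squared distances are 1, 4 and 1 respectively.
value = inertia_of(points, centers)
```

Lower values indicate that the points lie closer to their assigned centers, which is why inertia is often tracked while fitting.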

property n_steps

Returns the number of minibatches processed.

Returns:

Number of minibatches processed.

Return type:

int

property cluster_centers

Returns the cluster centers.

Returns:

Cluster centers of shape (n_clusters, n_features).

Return type:

torch.Tensor