{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sb_auto_header",
"tags": [
"sb_auto_header"
]
},
"source": [
"\n",
"\n",
"\n",
"[
](https://colab.research.google.com/github/speechbrain/speechbrain/blob/develop/docs/tutorials/advanced/federated-speech-model-training-via-speechbrain-and-flower.ipynb)\n",
"to execute or view/download this notebook on\n",
"[GitHub](https://github.com/speechbrain/speechbrain/tree/develop/docs/tutorials/advanced/federated-speech-model-training-via-speechbrain-and-flower.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X4rk4cFW8x6o"
},
"source": [
"# Federated Speech Model Training via Flower and SpeechBrain\n",
"\n",
"Are you interested in both federated learning (FL) and speech, but worried about the proper tools to run experiments? Today you will get the answer. This tutorial introduces how to integrate [Flower](https://github.com/adap/flower) and [SpeechBrain](https://github.com/speechbrain/speechbrain) to achieve federated speech model training.\n",
"\n",
"**Important:** It is recommended to be familiar with SpeechBrain and Flower before jumping into this tutorial as some parts may involve some level of complexity. Tutorials are available for both toolkits on their respective website!\n",
"\n",
"For simplicity, we choose a popular speech task --- automatic speech recognition (ASR) as an example, and training will be done with a toy dataset which only contains 100 audio recordings. In a real case, you need much more training data (e.g 100 or even 1000 hours) to reach acceptable performance. Note that ASR is regarded as a case study, all other speech related tasks can be done similarly.\n",
"\n",
"Apart from running normal federated ASR model training, the code also provides three other features to speed up model converge and improve the performance.\n",
"\n",
"* Loading a centralised initial model before federated training starts.\n",
"\n",
"* Providing three aggregation weighting strategies --- standard FedAvg, Loss-based and WER-based aggregation based on [this paper](https://arxiv.org/abs/2104.14297).\n",
"\n",
"* Facilitating an additional training with a held-out dataset on the server side after aggregation.\n",
"\n",
"The details of them will be elaborated in the later sections.\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0TMVAXERs_sb"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FwGNEPVTPAwS"
},
"source": [
"To run the code fast enough, we suggest using a GPU (`Runtime => change runtime type => GPU`).\n",
"\n",
"\n",
"## Installation\n",
"Before starting, let's install Flower and SpeechBrain:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"executionInfo": {
"elapsed": 20306,
"status": "ok",
"timestamp": 1709075703966,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "saSuftgFBTHv"
},
"outputs": [],
"source": [
"%%capture\n",
"# Installing SpeechBrain via pip\n",
"BRANCH = 'develop'\n",
"!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"executionInfo": {
"elapsed": 21591,
"status": "ok",
"timestamp": 1709075731107,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "CpwJ-Of0tbWe"
},
"outputs": [],
"source": [
"%%capture\n",
"# For pip installation\n",
"!pip install flwr\n",
"\n",
"# update tqdm package to avoid an ImportError.\n",
"!pip install tqdm==4.50.2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CtMlAwdZwXT3"
},
"source": [
"Then, download Flower-SpeechBrain integration code and template dataset, which was released on ```github.com/yan-gao-GY/Flower-SpeechBrain```. This integration will be explained in more details later on."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 216,
"status": "ok",
"timestamp": 1709075733276,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "QJXNAg7fE6ld",
"outputId": "3614a5d1-a832-4a52-db9d-47e40c9fcbcc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content\n"
]
}
],
"source": [
"%cd /content\n",
"%rm -rf results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"executionInfo": {
"elapsed": 842,
"status": "ok",
"timestamp": 1709075734948,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "16hQcJa-h6XE"
},
"outputs": [],
"source": [
"%%capture\n",
"!git clone https://github.com/yan-gao-GY/Flower-SpeechBrain.git"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fh9WH4THxkgA"
},
"source": [
"## What steps are needed for your experiments?\n",
"\n",
"The steps needed to launch a federated speech model training are just as normal Flower experiments.\n",
"\n",
"1. **Prepare your data**. The goal of this step is to create the data manifest files (TSV format) to fit the input format of SpeechBrain. The data manifest files contains the location of the speech data and their corresponding text annotations. In this tutorial, we skip the data partitioning step and simulate different partitions using a small template dataset. But in practice, you might want to have different files per federated client or a more complex data partitioning scheme.\n",
"\n",
"Now let's uncompress our template dataset."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 240,
"status": "ok",
"timestamp": 1709075737691,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "dOXD6PjcjjSi",
"outputId": "91717178-c41b-4ea6-a667-10daa4285464"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/Flower-SpeechBrain/temp_dataset\n",
"/content\n"
]
}
],
"source": [
"%cd /content/Flower-SpeechBrain/temp_dataset/\n",
"import zipfile\n",
"import os\n",
"\n",
"# Uncompression function\n",
"def un_zip(file_name):\n",
" zip_file = zipfile.ZipFile(file_name)\n",
" for names in zip_file.namelist():\n",
" zip_file.extract(names)\n",
" zip_file.close()\n",
"\n",
"un_zip(\"temp_dataset.zip\")\n",
"\n",
"# Simulate partitions using template dataset.\n",
"%cp temp_dataset.tsv train_0.tsv\n",
"\n",
"# Go back to /content directory.\n",
"%cd /content"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jcUTmthGIp84"
},
"source": [
"2. **Specify server and clients**. As Colab notebooks only allow one cell to be run at a time, we simulate the server and the clients as background processes within this tutorial. The following cells create `server.sh` and `clients.sh` scripts that will launch the required processes. All arguments required for federated training are passed in from the scripts."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 95,
"status": "ok",
"timestamp": 1709075740241,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "FkrC6mTObM05",
"outputId": "b05e8e1a-c177-4ef2-bc7f-71b4c99a0c8e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing server.sh\n"
]
}
],
"source": [
"%%writefile server.sh\n",
"PYTHONUNBUFFERED=1 python3 /content/Flower-SpeechBrain/server.py \\\n",
" --data_path=\"/content/Flower-SpeechBrain/temp_dataset/\" \\\n",
" --config_path=\"/content/Flower-SpeechBrain/configs/\" \\\n",
" --tr_path=\"/content/Flower-SpeechBrain/temp_dataset/temp_dataset.tsv\" \\\n",
" --test_path=\"/content/Flower-SpeechBrain/temp_dataset/temp_dataset.tsv\" \\\n",
" --tr_add_path=\"/content/Flower-SpeechBrain/temp_dataset/temp_dataset.tsv\" \\\n",
" --config_file=\"template.yaml\" \\\n",
" --min_fit_clients=1 \\\n",
" --min_available_clients=1 \\\n",
" --rounds=1 \\\n",
" --local_epochs=1 \\\n",
" --server_address=\"localhost:24338\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 91,
"status": "ok",
"timestamp": 1709075741600,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "kUP683Skelc5",
"outputId": "83481e23-de95-4002-aa7b-4d85acfe662e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing clients.sh\n"
]
}
],
"source": [
"%%writefile clients.sh\n",
"export PYTHONUNBUFFERED=1\n",
"NUM_CLIENTS=1\n",
"\n",
"\n",
"echo \"Starting $NUM_CLIENTS clients.\"\n",
"for ((i = 0; i < $NUM_CLIENTS; i++))\n",
"do\n",
" echo \"Starting client(cid=$i) with partition $i out of $NUM_CLIENTS clients.\"\n",
" # Staggered loading of clients: clients are loaded 8s apart.\n",
" sleep 8s\n",
" python3 /content/Flower-SpeechBrain/client.py \\\n",
" --cid=$i \\\n",
" --data_path=\"/content/Flower-SpeechBrain/temp_dataset/\" \\\n",
" --tr_path=\"/content/Flower-SpeechBrain/temp_dataset/\" \\\n",
" --dev_path=\"/content/Flower-SpeechBrain/temp_dataset/temp_dataset.tsv\" \\\n",
" --config_path=\"/content/Flower-SpeechBrain/configs/\" \\\n",
" --config_file=\"template.yaml\" \\\n",
" --eval_device=\"cuda:0\" \\\n",
" --server_address=\"localhost:24338\" &\n",
"done\n",
"echo \"Started $NUM_CLIENTS clients.\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"executionInfo": {
"elapsed": 239,
"status": "ok",
"timestamp": 1709075799505,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "vLIzAe_0gEuI"
},
"outputs": [],
"source": [
"# Execute this after running any of the %%writefile cells above\n",
"!chmod +x clients.sh server.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OJtjZHr0N5G3"
},
"source": [
"3. **Launch federated training~!** The following single cell will start the server, wait 5 seconds for it to initialise, and then start all clients.\n",
"\n",
" ```\n",
" !((./server.sh & sleep 5s); ./clients.sh)\n",
" ```\n",
"\n",
" We suggest running it at the end of this tutorial.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IYSzHjeR_VSS"
},
"source": [
"## Integration details — coupling SpeechBrain to Flower\n",
"Let's first see some details of the integration process to better understand the code. There are only four main steps required:\n",
"\n",
"1. Define a Brain class ([SpeechBrain Brain Class tutorial](https://speechbrain.readthedocs.io/en/latest/tutorials/basics/brain-class.html)).\n",
"2. Initialise the Brain class and dataset ([SpeechBrain dataio tutorial](https://speechbrain.readthedocs.io/en/latest/tutorials/basics/data-loading-pipeline.html)).\n",
"3. Define a SpeechBrain Client ([Flower client documentation](https://flower.dev/docs/quickstart_pytorch.html#flower-client)).\n",
"4. Define a Flower Strategy on the server side ([Flower strategies](https://flower.dev/docs/strategies.html#strategies))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rFSxYCs6JKfp"
},
"source": [
"### Define a Brain class\n",
"\n",
"First, we define our customised Brain class as any normal SpeechBrain experiments. This override is necessary (while usually not needed on SpeechBrain) because Flower requires the number of processed samples to perform aggregation!\n",
"\n",
"```python\n",
"class ASR(sb.core.Brain):\n",
" def compute_forward(self, batch, stage):\n",
" \"\"\"Forward pass, to be overridden by sub-classes.\n",
"\n",
" Arguments\n",
" ---------\n",
" batch : torch.Tensor or tensors\n",
" An element from the dataloader, including inputs for processing.\n",
" stage : Stage\n",
" The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST\n",
"\n",
" Returns\n",
" -------\n",
" torch.Tensor or Tensors\n",
" The outputs after all processing is complete.\n",
" Directly passed to ``compute_objectives()``.\n",
" \"\"\"\n",
" [...]\n",
"\n",
" def compute_objectives(self, predictions, batch, stage):\n",
" \"\"\"Compute loss, to be overridden by sub-classes.\n",
"\n",
" Arguments\n",
" ---------\n",
" predictions : torch.Tensor or Tensors\n",
" The output tensor or tensors to evaluate.\n",
" Comes directly from ``compute_forward()``.\n",
" batch : torch.Tensor or tensors\n",
" An element from the dataloader, including targets for comparison.\n",
" stage : Stage\n",
" The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST\n",
"\n",
" Returns\n",
" -------\n",
" loss : torch.Tensor\n",
" A tensor with the computed loss.\n",
" \"\"\"\n",
" [...]\n",
"\n",
" def fit_batch(self, batch):\n",
" \"\"\"Fit one batch, override to do multiple updates.\n",
"\n",
" The default implementation depends on a few methods being defined\n",
" with a particular behavior:\n",
"\n",
" * ``compute_forward()``\n",
" * ``compute_objectives()``\n",
" * ``optimizers_step()``\n",
"\n",
" Also depends on having optimizers passed at initialization.\n",
"\n",
" Arguments\n",
" ---------\n",
" batch : list of torch.Tensors\n",
" Batch of data to use for training. Default implementation assumes\n",
" this batch has two elements: inputs and targets.\n",
"\n",
" Returns\n",
" -------\n",
" detached loss\n",
" \"\"\"\n",
" [...]\n",
"\n",
" def evaluate_batch(self, batch, stage):\n",
" \"\"\"Evaluate one batch, override for different procedure than train.\n",
"\n",
" The default implementation depends on two methods being defined\n",
" with a particular behavior:\n",
"\n",
" * ``compute_forward()``\n",
" * ``compute_objectives()``\n",
"\n",
" Arguments\n",
" ---------\n",
" batch : list of torch.Tensors\n",
" Batch of data to use for evaluation. Default implementation assumes\n",
" this batch has two elements: inputs and targets.\n",
" stage : Stage\n",
" The stage of the experiment: Stage.VALID, Stage.TEST\n",
"\n",
" Returns\n",
" -------\n",
" detached loss\n",
" \"\"\"\n",
" [...]\n",
"\n",
" def fit(\n",
" self,\n",
" epoch_counter,\n",
" train_set,\n",
" valid_set=None,\n",
" progressbar=None,\n",
" train_loader_kwargs={},\n",
" valid_loader_kwargs={},\n",
" ):\n",
" \"\"\"Iterate epochs and datasets to improve objective.\n",
"\n",
" Relies on the existence of multiple functions that can (or should) be\n",
" overridden. The following methods are used and expected to have a\n",
" certain behavior:\n",
"\n",
" * ``fit_batch()``\n",
" * ``evaluate_batch()``\n",
" * ``update_average()``\n",
"\n",
" If the initialization was done with distributed_count > 0 and the\n",
" distributed_backend is ddp, this will generally handle multiprocess\n",
" logic, like splitting the training data into subsets for each device and\n",
" only saving a checkpoint on the main process.\n",
"\n",
" Arguments\n",
" ---------\n",
" epoch_counter : iterable\n",
" Each call should return an integer indicating the epoch count.\n",
" train_set : Dataset, DataLoader\n",
" A set of data to use for training. If a Dataset is given, a\n",
" DataLoader is automatically created. If a DataLoader is given, it is\n",
" used directly.\n",
" valid_set : Dataset, DataLoader\n",
" A set of data to use for validation. If a Dataset is given, a\n",
" DataLoader is automatically created. If a DataLoader is given, it is\n",
" used directly.\n",
" train_loader_kwargs : dict\n",
" Kwargs passed to `make_dataloader()` for making the train_loader\n",
" (if train_set is a Dataset, not DataLoader).\n",
" E.G. batch_size, num_workers.\n",
" DataLoader kwargs are all valid.\n",
" valid_loader_kwargs : dict\n",
" Kwargs passed to `make_dataloader()` for making the valid_loader\n",
" (if valid_set is a Dataset, not DataLoader).\n",
" E.g., batch_size, num_workers.\n",
" DataLoader kwargs are all valid.\n",
" progressbar : bool\n",
" Whether to display the progress of each epoch in a progressbar.\n",
" \"\"\"\n",
" [...]\n",
"\n",
" def evaluate(\n",
" self,\n",
" test_set,\n",
" progressbar=None,\n",
" test_loader_kwargs={},\n",
" ):\n",
" \"\"\"Iterate test_set and evaluate brain performance. By default, loads\n",
" the best-performing checkpoint (as recorded using the checkpointer).\n",
"\n",
" Arguments\n",
" ---------\n",
" test_set : Dataset, DataLoader\n",
" If a DataLoader is given, it is iterated directly. Otherwise passed\n",
" to ``self.make_dataloader()``.\n",
" max_key : str\n",
" Key to use for finding best checkpoint, passed to\n",
" ``on_evaluate_start()``.\n",
" min_key : str\n",
" Key to use for finding best checkpoint, passed to\n",
" ``on_evaluate_start()``.\n",
" progressbar : bool\n",
" Whether to display the progress in a progressbar.\n",
" test_loader_kwargs : dict\n",
" Kwargs passed to ``make_dataloader()`` if ``test_set`` is not a\n",
" DataLoader. NOTE: ``loader_kwargs[\"ckpt_prefix\"]`` gets\n",
" automatically overwritten to ``None`` (so that the test DataLoader\n",
" is not added to the checkpointer).\n",
"\n",
" Returns\n",
" -------\n",
" average test loss\n",
" \"\"\"\n",
" [...]\n",
"```\n",
"\n",
"We override the `fit()` method, which calculates number of training examples, average training loss and average WER. In practice, the code is almost identical to the official SpeechBrain (copy and paste), as we just need to return the number of processed samples !\n",
"\n",
"```python\n",
" def fit(\n",
" self,\n",
" epoch_counter,\n",
" train_set,\n",
" valid_set=None,\n",
" progressbar=None,\n",
" train_loader_kwargs={},\n",
" valid_loader_kwargs={},\n",
" ):\n",
" if self.test_only:\n",
" return\n",
"\n",
" if not (\n",
" isinstance(train_set, DataLoader)\n",
" or isinstance(train_set, LoopedLoader)\n",
" ):\n",
" train_set = self.make_dataloader(\n",
" train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs\n",
" )\n",
" if valid_set is not None and not (\n",
" isinstance(valid_set, DataLoader)\n",
" or isinstance(valid_set, LoopedLoader)\n",
" ):\n",
" valid_set = self.make_dataloader(\n",
" valid_set,\n",
" stage=sb.Stage.VALID,\n",
" ckpt_prefix=None,\n",
" **valid_loader_kwargs,\n",
" )\n",
"\n",
" self.on_fit_start()\n",
"\n",
" if progressbar is None:\n",
" progressbar = not self.noprogressbar\n",
"\n",
" batch_count = 0\n",
" # Iterate epochs\n",
" for epoch in epoch_counter:\n",
"\n",
" # Training stage\n",
" self.on_stage_start(sb.Stage.TRAIN, epoch)\n",
" self.modules.train()\n",
" avg_wer = 0.0\n",
"\n",
" # Reset nonfinite count to 0 each epoch\n",
" self.nonfinite_count = 0\n",
"\n",
" if self.train_sampler is not None and hasattr(\n",
" self.train_sampler, \"set_epoch\"\n",
" ):\n",
" self.train_sampler.set_epoch(epoch)\n",
"\n",
" # Time since last intra-epoch checkpoint\n",
" last_ckpt_time = time.time()\n",
"\n",
" # Only show progressbar if requested and main_process\n",
" enable = progressbar and sb.utils.distributed.if_main_process()\n",
" with tqdm(\n",
" train_set,\n",
" initial=self.step,\n",
" dynamic_ncols=True,\n",
" disable=not enable,\n",
" ) as t:\n",
" for batch in t:\n",
" self.step += 1\n",
" loss, wer = self.fit_batch(batch)\n",
" _, wav_lens = batch.sig\n",
" batch_count += wav_lens.shape[0]\n",
"\n",
" self.avg_train_loss = self.update_average(\n",
" loss, self.avg_train_loss\n",
" )\n",
" avg_wer = self.update_average_wer(\n",
" wer, avg_wer\n",
" )\n",
" t.set_postfix(train_loss=self.avg_train_loss)\n",
"\n",
" # Debug mode only runs a few batches\n",
" if self.debug and self.step == self.debug_batches:\n",
" break\n",
"\n",
" if (\n",
" self.checkpointer is not None\n",
" and self.ckpt_interval_minutes > 0\n",
" and time.time() - last_ckpt_time\n",
" >= self.ckpt_interval_minutes * 60.0\n",
" ):\n",
" run_on_main(self._save_intra_epoch_ckpt)\n",
" last_ckpt_time = time.time()\n",
"\n",
" # Run train \"on_stage_end\" on all processes\n",
" if epoch == epoch_counter.limit:\n",
" avg_loss = self.avg_train_loss\n",
"\n",
" self.on_stage_end(sb.Stage.TRAIN, self.avg_train_loss, epoch)\n",
" self.avg_train_loss = 0.0\n",
" self.step = 0\n",
"\n",
" # Validation stage\n",
" if valid_set is not None:\n",
" self.on_stage_start(sb.Stage.VALID, epoch)\n",
" self.modules.eval()\n",
" avg_valid_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch in tqdm(\n",
" valid_set, dynamic_ncols=True, disable=not enable\n",
" ):\n",
" self.step += 1\n",
" loss = self.evaluate_batch(batch, stage=sb.Stage.VALID)\n",
" avg_valid_loss = self.update_average(\n",
" loss, avg_valid_loss\n",
" )\n",
"\n",
" # Debug mode only runs a few batches\n",
" if self.debug and self.step == self.debug_batches:\n",
" break\n",
"\n",
" # Only run validation \"on_stage_end\" on main process\n",
" self.step = 0\n",
" valid_wer = self.on_stage_end(sb.Stage.VALID, avg_valid_loss, epoch)\n",
" if epoch == epoch_counter.limit:\n",
" valid_wer_last = valid_wer\n",
"\n",
" # Debug mode only runs a few epochs\n",
" if self.debug and epoch == self.debug_epochs:\n",
" break\n",
"\n",
" return batch_count, avg_loss, valid_wer_last\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yKX7H6NBCpCJ"
},
"source": [
"### Initialise Brain class and dataset\n",
"Next, we instantiate the `ASR` Brain class defined before, as well as the dataset. In SpeechBrain, this would be the `Main` function of your speech recipe. Here, we need to encapsulate this because Flower will call this function for each client of the federated setup to initialise it properly!\n",
"\n",
"```python\n",
"def int_model(\n",
" flower_path,\n",
" tr_path,\n",
" dev_path,\n",
" test_path,\n",
" save_path,\n",
" data_path,\n",
" config_file=\"CRDNN.yaml\",\n",
" tokenizer_path=None,\n",
" eval_device=\"cuda:0\",\n",
" evaluate=False,\n",
" add_train=False):\n",
"\n",
" # Load hyperparameters file with command-line overrides\n",
" params_file = flower_path + config_file\n",
"\n",
" # Override with FLOWER PARAMS\n",
" if evaluate:\n",
" overrides = {\n",
" \"output_folder\": save_path,\n",
" \"number_of_epochs\": 1,\n",
" \"test_batch_size\": 4,\n",
" \"device\": eval_device,\n",
" # \"device\": 'cpu'\n",
" }\n",
" elif add_train:\n",
" overrides = {\n",
" \"output_folder\": save_path,\n",
" \"lr\": 0.01\n",
" }\n",
"\n",
" else:\n",
" overrides = {\n",
" \"output_folder\": save_path\n",
" }\n",
" run_opts = None\n",
"\n",
" with open(params_file) as fin:\n",
" params = load_hyperpyyaml(fin, overrides)\n",
"\n",
" params[\"data_folder\"] = data_path\n",
" params[\"train_tsv_file\"] = tr_path\n",
" params[\"dev_tsv_file\"] = dev_path\n",
" params[\"test_tsv_file\"] = test_path\n",
" params[\"save_folder\"] = params[\"output_folder\"] + \"/save\"\n",
" params[\"train_csv\"] = params[\"save_folder\"] + \"/train.csv\"\n",
" params[\"valid_csv\"] = params[\"save_folder\"] + \"/dev.csv\"\n",
" params[\"test_csv\"] = params[\"save_folder\"] + \"/test.csv\"\n",
" params[\"tokenizer_csv\"] = tokenizer_path if tokenizer_path is not None else params[\"train_csv\"]\n",
"\n",
" # Dataset preparation (parsing CommonVoice)\n",
" from common_voice_prepare import prepare_common_voice # noqa\n",
"\n",
" # Create experiment directory\n",
" sb.create_experiment_directory(\n",
" experiment_directory=params[\"output_folder\"],\n",
" hyperparams_to_save=params_file,\n",
" overrides=overrides,\n",
" )\n",
"\n",
" # Due to DDP, we do the preparation ONLY on the main python process\n",
" run_on_main(\n",
" prepare_common_voice,\n",
" kwargs={\n",
" \"data_folder\": params[\"data_folder\"],\n",
" \"save_folder\": params[\"save_folder\"],\n",
" \"train_tsv_file\": params[\"train_tsv_file\"],\n",
" \"dev_tsv_file\": params[\"dev_tsv_file\"],\n",
" \"test_tsv_file\": params[\"test_tsv_file\"],\n",
" \"accented_letters\": params[\"accented_letters\"],\n",
" \"language\": params[\"language\"],\n",
" },\n",
" )\n",
"\n",
" # Defining tokenizer and loading it\n",
" tokenizer = SentencePiece(\n",
" model_dir=params[\"save_folder\"],\n",
" vocab_size=params[\"output_neurons\"],\n",
" annotation_train=params[\"train_csv\"],\n",
" annotation_read=\"wrd\",\n",
" model_type=params[\"token_type\"],\n",
" character_coverage=params[\"character_coverage\"],\n",
" )\n",
"\n",
" # Create the datasets objects as well as tokenization and encoding :-D\n",
" train_data, valid_data, test_data = dataio_prepare(params, tokenizer)\n",
"\n",
" # Trainer initialization\n",
" asr_brain = ASR(\n",
" modules=params[\"modules\"],\n",
" hparams=params,\n",
" run_opts=run_opts,\n",
" opt_class=params[\"opt_class\"],\n",
" checkpointer=params[\"checkpointer\"],\n",
" )\n",
"\n",
" # Adding objects to trainer.\n",
" asr_brain.tokenizer = tokenizer\n",
"\n",
" return asr_brain, [train_data, valid_data, test_data]\n",
"\n",
"asr_brain, dataset = int_model(...)\n",
"```\n",
"This function can also load all hyper-parameters from provided `yaml` file as normal SpeechBrain model training. Additionally, we can overwrite the hyper-parameters of `yaml` file here."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQxY4EWKKa3F"
},
"source": [
"### Define a SpeechBrain client\n",
"We define a customised Flower client that can mainly achieve three features:\n",
"* Set server weights to SpeechBrain model.\n",
"* Trigger SpeechBrain model training.\n",
"* Extract model weights after training.\n",
"\n",
"Let's first see `set_weights` and `get_weights` functions. This is quite simple, just the transformation between pytorch tensor and NumPy ndarrays.\n",
"\n",
"If you are familiar with SpeechBrain, you will recognize the **modules** argument. If not, this simply is all the PyTorch blocks of your pipeline. This means that we can iterated over the state_dict to obtain absolutely all the parameters of the speech pipeline.\n",
"\n",
"```python\n",
"def set_weights(weights, modules, device) -> None:\n",
" \"\"\"Set model weights from a list of NumPy ndarrays.\"\"\"\n",
" state_dict = OrderedDict()\n",
" valid_keys = [k for k in modules.state_dict().keys()]\n",
" for k, v in zip(valid_keys, weights):\n",
" v_ = torch.Tensor(np.array(v))\n",
" v_ = v_.to(device)\n",
" state_dict[k] = v_\n",
" modules.load_state_dict(state_dict, strict=False)\n",
"\n",
"def get_weights(modules):\n",
" \"\"\"Get model weights as a list of NumPy ndarrays.\"\"\"\n",
" w = []\n",
" for k, v in modules.state_dict().items():\n",
" w.append(v.cpu().numpy())\n",
" return w\n",
"```\n",
"\n",
"Then, we define the `SpeechBrainClient` class.\n",
"\n",
"```python\n",
"class SpeechBrainClient(fl.client.NumPyClient):\n",
" def __init__(self,\n",
" cid: int,\n",
" asr_brain,\n",
" dataset,\n",
" pre_train_model_path=None):\n",
"\n",
" self.cid = cid\n",
" self.params = asr_brain.hparams\n",
" self.modules = asr_brain.modules\n",
" self.asr_brain = asr_brain\n",
" self.dataset = dataset\n",
" self.pre_train_model_path = pre_train_model_path\n",
"\n",
" def get_parameters(self, config):\n",
" print(f\"Client {self.cid}: get_parameters\")\n",
" weights = get_weights(self.modules)\n",
" return weights\n",
"\n",
" def fit(self, parameters, config):\n",
" print(f\"Client {self.cid}: fit\")\n",
"\n",
" # Read training configuration\n",
" global_rounds = int(config[\"epoch_global\"])\n",
" print(\"Current global round: \", global_rounds)\n",
" epochs = int(config[\"epochs\"])\n",
"\n",
" (\n",
" new_weights,\n",
" num_examples,\n",
" num_examples_ceil,\n",
" fit_duration,\n",
" avg_loss,\n",
" avg_wer\n",
" ) = self.train_speech_recogniser(\n",
" parameters,\n",
" epochs,\n",
" global_rounds=global_rounds\n",
" )\n",
"\n",
" metrics = {\"train_loss\": avg_loss, \"wer\": avg_wer}\n",
"\n",
" # Release GPU VRAM\n",
" torch.cuda.empty_cache()\n",
"\n",
" return self.get_parameters(config={}), num_examples, metrics\n",
"\n",
" def evaluate(self, parameters, config):\n",
" print(f\"Client {self.cid}: evaluate\")\n",
"\n",
" num_examples, loss, wer = self.train_speech_recogniser(\n",
" server_params=parameters,\n",
" epochs=1,\n",
" evaluate=True\n",
" )\n",
" torch.cuda.empty_cache()\n",
"\n",
" # Return the number of evaluation examples and the evaluation result (loss)\n",
" return float(loss), num_examples, {\"accuracy\": float(wer)}\n",
"\n",
"\n",
" def train_speech_recogniser(\n",
" self,\n",
" server_params,\n",
" epochs,\n",
" evaluate=False,\n",
" add_train=False,\n",
" global_rounds=None\n",
" ):\n",
" '''\n",
" This function aims to trigger client local training or evaluation\n",
" via calling the fit() or evaluate() function of SpeechBrain Brain\n",
" class. It can also load a pre-trained model before training.\n",
"\n",
" Arguments\n",
" ---------\n",
" server_params : Parameter\n",
" The parameters given by the server.\n",
" epochs : int\n",
" The total number of local epochs for training.\n",
" evaluate : bool\n",
" Evaluation or not.\n",
" add_train : bool\n",
" The additional training on the server or not.\n",
" global_rounds : int\n",
" The current global round.\n",
" \n",
" Returns\n",
" -------\n",
" model weights after training,\n",
" number of total training samples,\n",
" number of training samples ceil,\n",
" training duration,\n",
" training loss,\n",
" valid WER\n",
" '''\n",
" self.params.epoch_counter.limit = epochs\n",
" self.params.epoch_counter.current = 0\n",
"\n",
" train_data, valid_data, test_data = self.dataset\n",
"\n",
" # Set the parameters to the ones given by the server\n",
" if server_params is not None:\n",
" set_weights(server_params, self.modules, evaluate, add_train, self.params.device)\n",
"\n",
" # Load the pre-trained model at global round 1\n",
" if global_rounds == 1 and not add_train and not evaluate:\n",
" if self.pre_train_model_path is not None:\n",
" print(\"loading pre-trained model...\")\n",
" state_dict = torch.load(self.pre_train_model_path)\n",
" self.params.model.load_state_dict(state_dict)\n",
"\n",
" # Exclude two layers which do not join the aggregation\n",
" if global_rounds != 1:\n",
" # Two layer names that do not join aggregation\n",
" k1 = \"enc.DNN.block_0.norm.norm.num_batches_tracked\"\n",
" k2 = \"enc.DNN.block_1.norm.norm.num_batches_tracked\"\n",
"\n",
" state_dict_norm = OrderedDict()\n",
" state_dict_norm[k1] = torch.tensor(1, device=self.params.device)\n",
" state_dict_norm[k2] = torch.tensor(0, device=self.params.device)\n",
" self.modules.load_state_dict(state_dict_norm, strict=False)\n",
"\n",
" # Load best checkpoint for evaluation\n",
" if evaluate:\n",
" self.params.test_wer_file = self.params.output_folder + \"/wer_test.txt\"\n",
" batch_count, loss, wer = self.asr_brain.evaluate(\n",
" test_data,\n",
" test_loader_kwargs=self.params.test_dataloader_options,\n",
" )\n",
" return batch_count, loss, wer\n",
"\n",
" # Training\n",
" fit_begin = timeit.default_timer()\n",
"\n",
" count_sample, avg_loss, avg_wer = self.asr_brain.fit(\n",
" self.params.epoch_counter,\n",
" train_data,\n",
" valid_data,\n",
" train_loader_kwargs=self.params.dataloader_options,\n",
" valid_loader_kwargs=self.params.test_dataloader_options,\n",
" )\n",
"\n",
" # Exp operation to avg_loss and avg_wer\n",
" avg_wer = 100 if avg_wer > 100 else avg_wer\n",
" avg_loss = exp(- avg_loss)\n",
" avg_wer = exp(100 - avg_wer)\n",
"\n",
" # Retrieve the parameters to return\n",
" params_list = get_weights(self.modules)\n",
"\n",
" if add_train:\n",
" return params_list\n",
"\n",
" fit_duration = timeit.default_timer() - fit_begin\n",
"\n",
" # Manage when last batch isn't full w.r.t batch size\n",
" train_set = sb.dataio.dataloader.make_dataloader(train_data, **self.params.dataloader_options)\n",
" if count_sample > len(train_set) * self.params.batch_size * epochs:\n",
" count_sample = len(train_set) * self.params.batch_size * epochs\n",
"\n",
" return (\n",
" params_list,\n",
" count_sample,\n",
" len(train_set) * self.params.batch_size * epochs,\n",
" fit_duration,\n",
" avg_loss,\n",
" avg_wer\n",
" )\n",
"\n",
"client = SpeechBrainClient(...)\n",
"```\n",
"\n",
"The training process happens in `fit()` method of our defined `SpeechBrainClient` class.. A function named `train_speech_recogniser()` is called inside of `fit()`. This function aims to trigger client local training by calling `fit()` method of SpeechBrain Brain class. Also, we can load a pre-trained model at 1st global round for initialisation.\n",
"\n",
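"The listing above turns the average loss and WER into scores before returning them; the server-side strategy can then read such values from `metrics` when loss-based or WER-based weighting is selected. A minimal standalone sketch of that transform (the helper name `to_aggregation_scores` is hypothetical and not part of the tutorial code):\n",
"\n",
"```python\n",
"from math import exp\n",
"\n",
"\n",
"def to_aggregation_scores(avg_loss, avg_wer):\n",
"    # Hypothetical helper mirroring the three transform lines in the listing above.\n",
"    capped_wer = min(avg_wer, 100.0)  # WER can exceed 100, so cap it first\n",
"    loss_score = exp(-avg_loss)  # lower loss -> larger score\n",
"    wer_score = exp(100.0 - capped_wer)  # lower WER -> larger score\n",
"    return loss_score, wer_score\n",
"\n",
"\n",
"loss_score, wer_score = to_aggregation_scores(avg_loss=5.40, avg_wer=101.0)\n",
"```\n",
"\n",
"Because both scores grow as the loss or WER falls, a client with a better local model contributes more to a score-weighted average.\n",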
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HYXRGO_-R1Ol"
},
"source": [
"### Define a Flower Strategy on the server side\n",
"To support the different aggregation weighting strategies and an additional training step after aggregation, we need to define a customised Flower Strategy class.\n",
"\n",
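"As background for the class below, all three weighting strategies reduce to a weighted average of the client parameters; only the per-client weight changes (number of examples, loss score, or WER score). A minimal NumPy sketch of that idea, assuming each client result is a `(list_of_arrays, weight)` pair; `weighted_average` is a hypothetical helper, not Flower's implementation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def weighted_average(results):\n",
"    # results: list of (list_of_arrays, weight) pairs, one pair per client.\n",
"    total = sum(weight for _, weight in results)\n",
"    num_layers = len(results[0][0])\n",
"    # Average each layer separately, scaling every client's layer by its weight.\n",
"    return [\n",
"        sum(layers[i] * weight for layers, weight in results) / total\n",
"        for i in range(num_layers)\n",
"    ]\n",
"\n",
"\n",
"client_a = ([np.array([1.0, 1.0])], 1.0)\n",
"client_b = ([np.array([3.0, 3.0])], 3.0)\n",
"averaged = weighted_average([client_a, client_b])\n",
"```\n",
"\n",
"Here `client_b` has three times the weight of `client_a`, so the averaged layer sits closer to `client_b`'s values.\n",
"\n",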
"```python\n",
"class TrainAfterAggregateStrategy(fl.server.strategy.FedAvg):\n",
" def aggregate_fit(\n",
" self,\n",
" server_round: int,\n",
" results: List[Tuple[ClientProxy, FitRes]],\n",
" failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n",
" ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
" \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
"\n",
" if not results:\n",
" return None, {}\n",
" # Do not aggregate if there are failures and failures are not accepted\n",
" if not self.accept_failures and failures:\n",
" return None, {}\n",
"\n",
" # Convert results\n",
" key_name = 'train_loss' if args.weight_strategy == 'loss' else 'wer'\n",
" weights = None\n",
"\n",
" # Standard FedAvg\n",
" if args.weight_strategy == 'num':\n",
" weights_results = [\n",
" (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
" for _, fit_res in results\n",
" ]\n",
" # Here we do aggregation\n",
" weights = aggregate(weights_results)\n",
"\n",
" # If loss-based or WER-based aggregation, fetch the values of loss or WER from `metrics`\n",
" elif args.weight_strategy == 'loss' or args.weight_strategy == 'wer':\n",
" weights_results = [\n",
" (parameters_to_ndarrays(fit_res.parameters), fit_res.metrics[key_name])\n",
" for client, fit_res in results\n",
" ]\n",
" # Here we do aggregation\n",
" weights = aggregate(weights_results)\n",
"\n",
" # Aggregate custom metrics if aggregation fn was provided\n",
" metrics_aggregated = {}\n",
" if self.fit_metrics_aggregation_fn:\n",
" fit_metrics = [(res.num_examples, res.metrics) for _, res in results]\n",
" metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)\n",
" elif server_round == 1: # Only log this warning once\n",
" log(WARNING, \"No fit_metrics_aggregation_fn provided\")\n",
"\n",
" # Train model after aggregation\n",
" if weights is not None:\n",
" print(f\"Train model after aggregation\")\n",
" save_path = args.save_path + \"add_train\"\n",
" # Initial Brain class and dataset\n",
" asr_brain, dataset = int_model(args.config_path, args.tr_add_path, args.tr_path, args.tr_path,\n",
" save_path,\n",
" args.data_path, args.config_file, args.tokenizer_path, add_train=True)\n",
" # Initial SpeechBrain client\n",
" client = SpeechBrainClient(None, asr_brain, dataset)\n",
"\n",
" # Call the training function\n",
" weights_after_server_side_training = client.train_speech_recogniser(\n",
" server_params=weights,\n",
" epochs=1,\n",
" add_train=True\n",
" )\n",
" # Release cuda memory after training\n",
" torch.cuda.empty_cache()\n",
" return ndarrays_to_parameters(weights_after_server_side_training), metrics_aggregated \n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1N7R8UhLJwsX"
},
"source": [
"## Run an experiment\n",
"\n",
"OK, it's time to launch our experiment!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 176603,
"status": "ok",
"timestamp": 1709075987779,
"user": {
"displayName": "Mirco Ravanelli",
"userId": "06892056361698510975"
},
"user_tz": 300
},
"id": "eIinvOYVgKaz",
"outputId": "b940513c-62b7-4ad9-f534-e0c6ab221fbe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All background processes were killed.\n",
"2024-02-27 23:16:54.932275: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-27 23:16:54.932334: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-27 23:16:54.933651: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-02-27 23:16:54.940829: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"Starting 1 clients.\n",
"Starting client(cid=0) with partition 0 out of 1 clients.\n",
"2024-02-27 23:16:56.317841: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"Server IP: 172.28.0.12\n",
"WARNING flwr 2024-02-27 23:17:00,830 | fedavg.py:118 | \n",
"Setting `min_available_clients` lower than `min_fit_clients` or\n",
"`min_evaluate_clients` can cause the server to fail when there are too few clients\n",
"connected to the server. `min_available_clients` must be set to a value larger\n",
"than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.\n",
"\n",
"INFO flwr 2024-02-27 23:17:00,838 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=1, round_timeout=None)\n",
"INFO flwr 2024-02-27 23:17:00,890 | app.py:176 | Flower ECE: gRPC server running (1 rounds), SSL is disabled\n",
"INFO flwr 2024-02-27 23:17:00,890 | server.py:89 | Initializing global parameters\n",
"INFO flwr 2024-02-27 23:17:00,890 | server.py:276 | Requesting initial parameters from one random client\n",
"Started 1 clients.\n",
"2024-02-27 23:17:06.200995: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-27 23:17:06.201049: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-27 23:17:06.202383: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-02-27 23:17:06.209160: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-02-27 23:17:07.312437: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"speechbrain.core - Beginning experiment!\n",
"speechbrain.core - Experiment folder: ./results/client_0\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/client_0/save/train.csv ...\n",
"100% 50/50 [00:00<00:00, 129.06it/s]\n",
"common_voice_prepare - ./results/client_0/save/train.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/client_0/save/dev.csv ...\n",
"100% 50/50 [00:00<00:00, 159.40it/s]\n",
"common_voice_prepare - ./results/client_0/save/dev.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/client_0/save/test.csv ...\n",
"100% 50/50 [00:00<00:00, 178.27it/s]\n",
"common_voice_prepare - ./results/client_0/save/test.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"speechbrain.tokenizers.SentencePiece - Train tokenizer with type:unigram\n",
"speechbrain.tokenizers.SentencePiece - Extract wrd sequences from:./results/client_0/save/train.csv\n",
"speechbrain.tokenizers.SentencePiece - Text file created at: ./results/client_0/save/train.txt\n",
"sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=./results/client_0/save/train.txt --model_prefix=./results/client_0/save/250_unigram --model_type=unigram --bos_id=-1 --eos_id=-1 --pad_id=-1 --unk_id=0 --max_sentencepiece_length=10 --character_coverage=1.0 --add_dummy_prefix=True --vocab_size=250\n",
"sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : \n",
"trainer_spec {\n",
" input: ./results/client_0/save/train.txt\n",
" input_format: \n",
" model_prefix: ./results/client_0/save/250_unigram\n",
" model_type: UNIGRAM\n",
" vocab_size: 250\n",
" self_test_sample_size: 0\n",
" character_coverage: 1\n",
" input_sentence_size: 0\n",
" shuffle_input_sentence: 1\n",
" seed_sentencepiece_size: 1000000\n",
" shrinking_factor: 0.75\n",
" max_sentence_length: 4192\n",
" num_threads: 16\n",
" num_sub_iterations: 2\n",
" max_sentencepiece_length: 10\n",
" split_by_unicode_script: 1\n",
" split_by_number: 1\n",
" split_by_whitespace: 1\n",
" split_digits: 0\n",
" pretokenization_delimiter: \n",
" treat_whitespace_as_suffix: 0\n",
" allow_whitespace_only_pieces: 0\n",
" required_chars: \n",
" byte_fallback: 0\n",
" vocabulary_output_piece_score: 1\n",
" train_extremely_large_corpus: 0\n",
" hard_vocab_limit: 1\n",
" use_all_vocab: 0\n",
" unk_id: 0\n",
" bos_id: -1\n",
" eos_id: -1\n",
" pad_id: -1\n",
" unk_piece: \n",
" bos_piece: \n",
" eos_piece: \n",
" pad_piece: \n",
" unk_surface: ⁇ \n",
" enable_differential_privacy: 0\n",
" differential_privacy_noise_level: 0\n",
" differential_privacy_clipping_threshold: 0\n",
"}\n",
"normalizer_spec {\n",
" name: nmt_nfkc\n",
" add_dummy_prefix: 1\n",
" remove_extra_whitespaces: 1\n",
" escape_whitespaces: 1\n",
" normalization_rule_tsv: \n",
"}\n",
"denormalizer_spec {}\n",
"trainer_interface.cc(351) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n",
"trainer_interface.cc(183) LOG(INFO) Loading corpus: ./results/client_0/save/train.txt\n",
"trainer_interface.cc(407) LOG(INFO) Loaded all 50 sentences\n",
"trainer_interface.cc(423) LOG(INFO) Adding meta_piece: \n",
"trainer_interface.cc(428) LOG(INFO) Normalizing sentences...\n",
"trainer_interface.cc(537) LOG(INFO) all chars count=3018\n",
"trainer_interface.cc(558) LOG(INFO) Alphabet size=37\n",
"trainer_interface.cc(559) LOG(INFO) Final character coverage=1\n",
"trainer_interface.cc(591) LOG(INFO) Done! preprocessed 50 sentences.\n",
"unigram_model_trainer.cc(222) LOG(INFO) Making suffix array...\n",
"unigram_model_trainer.cc(226) LOG(INFO) Extracting frequent sub strings... node_num=1348\n",
"unigram_model_trainer.cc(274) LOG(INFO) Initialized 635 seed sentencepieces\n",
"trainer_interface.cc(597) LOG(INFO) Tokenizing input sentences with whitespace: 50\n",
"trainer_interface.cc(608) LOG(INFO) Done! 328\n",
"unigram_model_trainer.cc(564) LOG(INFO) Using 328 sentences for EM training\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=411 obj=14.1573 num_tokens=965 num_tokens/piece=2.34793\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=365 obj=13.6029 num_tokens=976 num_tokens/piece=2.67397\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=275 obj=14.1354 num_tokens=1061 num_tokens/piece=3.85818\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=274 obj=13.813 num_tokens=1061 num_tokens/piece=3.87226\n",
"trainer_interface.cc(686) LOG(INFO) Saving model: ./results/client_0/save/250_unigram.model\n",
"trainer_interface.cc(698) LOG(INFO) Saving vocabs: ./results/client_0/save/250_unigram.vocab\n",
"speechbrain.tokenizers.SentencePiece - ==== Loading Tokenizer ===\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer path: ./results/client_0/save/250_unigram.model\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer vocab_size: 250\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer type: unigram\n",
"speechbrain.core - Info: device arg from hparam file is used\n",
"speechbrain.core - Info: precision arg from hparam file is used\n",
"speechbrain.core - Gradscaler enabled: False. Using precision: fp32.\n",
"speechbrain.core - 46.3M trainable parameters in ASR\n",
"INFO flwr 2024-02-27 23:17:31,020 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)\n",
"flwr - Opened insecure gRPC connection (no certificates were passed)\n",
"DEBUG flwr 2024-02-27 23:17:31,023 | connection.py:55 | ChannelConnectivity.IDLE\n",
"DEBUG flwr 2024-02-27 23:17:31,027 | connection.py:55 | ChannelConnectivity.CONNECTING\n",
"DEBUG flwr 2024-02-27 23:17:31,031 | connection.py:55 | ChannelConnectivity.READY\n",
"Client 0: get_parameters\n",
"INFO flwr 2024-02-27 23:17:34,462 | server.py:280 | Received initial parameters from one random client\n",
"INFO flwr 2024-02-27 23:17:34,462 | server.py:91 | Evaluating initial parameters\n",
"speechbrain.core - Beginning experiment!\n",
"speechbrain.core - Experiment folder: ./results/evaluation\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/evaluation/save/train.csv ...\n",
"100% 50/50 [00:00<00:00, 162.24it/s]\n",
"common_voice_prepare - ./results/evaluation/save/train.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/evaluation/save/dev.csv ...\n",
"100% 50/50 [00:00<00:00, 169.96it/s]\n",
"common_voice_prepare - ./results/evaluation/save/dev.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/evaluation/save/test.csv ...\n",
"100% 50/50 [00:00<00:00, 172.94it/s]\n",
"common_voice_prepare - ./results/evaluation/save/test.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"speechbrain.tokenizers.SentencePiece - Train tokenizer with type:unigram\n",
"speechbrain.tokenizers.SentencePiece - Extract wrd sequences from:./results/evaluation/save/train.csv\n",
"speechbrain.tokenizers.SentencePiece - Text file created at: ./results/evaluation/save/train.txt\n",
"sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=./results/evaluation/save/train.txt --model_prefix=./results/evaluation/save/250_unigram --model_type=unigram --bos_id=-1 --eos_id=-1 --pad_id=-1 --unk_id=0 --max_sentencepiece_length=10 --character_coverage=1.0 --add_dummy_prefix=True --vocab_size=250\n",
"sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : \n",
"trainer_spec {\n",
" input: ./results/evaluation/save/train.txt\n",
" input_format: \n",
" model_prefix: ./results/evaluation/save/250_unigram\n",
" model_type: UNIGRAM\n",
" vocab_size: 250\n",
" self_test_sample_size: 0\n",
" character_coverage: 1\n",
" input_sentence_size: 0\n",
" shuffle_input_sentence: 1\n",
" seed_sentencepiece_size: 1000000\n",
" shrinking_factor: 0.75\n",
" max_sentence_length: 4192\n",
" num_threads: 16\n",
" num_sub_iterations: 2\n",
" max_sentencepiece_length: 10\n",
" split_by_unicode_script: 1\n",
" split_by_number: 1\n",
" split_by_whitespace: 1\n",
" split_digits: 0\n",
" pretokenization_delimiter: \n",
" treat_whitespace_as_suffix: 0\n",
" allow_whitespace_only_pieces: 0\n",
" required_chars: \n",
" byte_fallback: 0\n",
" vocabulary_output_piece_score: 1\n",
" train_extremely_large_corpus: 0\n",
" hard_vocab_limit: 1\n",
" use_all_vocab: 0\n",
" unk_id: 0\n",
" bos_id: -1\n",
" eos_id: -1\n",
" pad_id: -1\n",
" unk_piece: \n",
" bos_piece: \n",
" eos_piece: \n",
" pad_piece: \n",
" unk_surface: ⁇ \n",
" enable_differential_privacy: 0\n",
" differential_privacy_noise_level: 0\n",
" differential_privacy_clipping_threshold: 0\n",
"}\n",
"normalizer_spec {\n",
" name: nmt_nfkc\n",
" add_dummy_prefix: 1\n",
" remove_extra_whitespaces: 1\n",
" escape_whitespaces: 1\n",
" normalization_rule_tsv: \n",
"}\n",
"denormalizer_spec {}\n",
"trainer_interface.cc(351) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n",
"trainer_interface.cc(183) LOG(INFO) Loading corpus: ./results/evaluation/save/train.txt\n",
"trainer_interface.cc(407) LOG(INFO) Loaded all 50 sentences\n",
"trainer_interface.cc(423) LOG(INFO) Adding meta_piece: \n",
"trainer_interface.cc(428) LOG(INFO) Normalizing sentences...\n",
"trainer_interface.cc(537) LOG(INFO) all chars count=3018\n",
"trainer_interface.cc(558) LOG(INFO) Alphabet size=37\n",
"trainer_interface.cc(559) LOG(INFO) Final character coverage=1\n",
"trainer_interface.cc(591) LOG(INFO) Done! preprocessed 50 sentences.\n",
"unigram_model_trainer.cc(222) LOG(INFO) Making suffix array...\n",
"unigram_model_trainer.cc(226) LOG(INFO) Extracting frequent sub strings... node_num=1348\n",
"unigram_model_trainer.cc(274) LOG(INFO) Initialized 635 seed sentencepieces\n",
"trainer_interface.cc(597) LOG(INFO) Tokenizing input sentences with whitespace: 50\n",
"trainer_interface.cc(608) LOG(INFO) Done! 328\n",
"unigram_model_trainer.cc(564) LOG(INFO) Using 328 sentences for EM training\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=411 obj=14.1573 num_tokens=965 num_tokens/piece=2.34793\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=365 obj=13.6029 num_tokens=976 num_tokens/piece=2.67397\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=275 obj=14.1354 num_tokens=1061 num_tokens/piece=3.85818\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=274 obj=13.813 num_tokens=1061 num_tokens/piece=3.87226\n",
"trainer_interface.cc(686) LOG(INFO) Saving model: ./results/evaluation/save/250_unigram.model\n",
"trainer_interface.cc(698) LOG(INFO) Saving vocabs: ./results/evaluation/save/250_unigram.vocab\n",
"speechbrain.tokenizers.SentencePiece - ==== Loading Tokenizer ===\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer path: ./results/evaluation/save/250_unigram.model\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer vocab_size: 250\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer type: unigram\n",
"speechbrain.core - Info: device arg from hparam file is used\n",
"speechbrain.core - Info: precision arg from hparam file is used\n",
"speechbrain.core - Gradscaler enabled: False. Using precision: fp32.\n",
"speechbrain.core - 46.3M trainable parameters in ASR\n",
"100% 13/13 [00:14<00:00, 1.15s/it]\n",
"speechbrain.utils.train_logger - Epoch loaded: 0 - test loss: 0.00e+00, test CER: 1.00e+02, test WER: 1.00e+02\n",
"INFO flwr 2024-02-27 23:18:08,737 | server.py:94 | initial parameters (loss, other metrics): 0.0, {'accuracy': 100.0}\n",
"flwr - initial parameters (loss, other metrics): 0.0, {'accuracy': 100.0}\n",
"INFO flwr 2024-02-27 23:18:08,739 | server.py:104 | FL starting\n",
"flwr - FL starting\n",
"DEBUG flwr 2024-02-27 23:18:08,739 | server.py:222 | fit_round 1: strategy sampled 1 clients (out of 1)\n",
"Client 0: fit\n",
"Current global round: 1\n",
"speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.\n",
"speechbrain.utils.epoch_loop - Going into epoch 1\n",
"100% 12/12 [00:11<00:00, 1.05it/s, train_loss=7.47]\n",
"100% 13/13 [00:03<00:00, 4.20it/s]\n",
"speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 7.47 - valid loss: 5.40, valid CER: 89.62, valid WER: 1.01e+02\n",
"Client 0: get_parameters\n",
"DEBUG flwr 2024-02-27 23:18:29,403 | server.py:236 | fit_round 1 received 1 results and 0 failures\n",
"WARNING flwr 2024-02-27 23:18:29,615 | server.py:103 | No fit_metrics_aggregation_fn provided\n",
"flwr - No fit_metrics_aggregation_fn provided\n",
"Train model after aggregation\n",
"speechbrain.core - Beginning experiment!\n",
"speechbrain.core - Experiment folder: ./results/add_train\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/add_train/save/train.csv ...\n",
"100% 50/50 [00:00<00:00, 165.90it/s]\n",
"common_voice_prepare - ./results/add_train/save/train.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/add_train/save/dev.csv ...\n",
"100% 50/50 [00:00<00:00, 170.89it/s]\n",
"common_voice_prepare - ./results/add_train/save/dev.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"common_voice_prepare - Preparing CSV files for 50 samples ...\n",
"common_voice_prepare - Creating csv lists in ./results/add_train/save/test.csv ...\n",
"100% 50/50 [00:00<00:00, 169.93it/s]\n",
"common_voice_prepare - ./results/add_train/save/test.csv sucessfully created!\n",
"common_voice_prepare - Number of samples: 50 \n",
"common_voice_prepare - Total duration: 0.08 Hours\n",
"speechbrain.tokenizers.SentencePiece - Train tokenizer with type:unigram\n",
"speechbrain.tokenizers.SentencePiece - Extract wrd sequences from:./results/add_train/save/train.csv\n",
"speechbrain.tokenizers.SentencePiece - Text file created at: ./results/add_train/save/train.txt\n",
"sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=./results/add_train/save/train.txt --model_prefix=./results/add_train/save/250_unigram --model_type=unigram --bos_id=-1 --eos_id=-1 --pad_id=-1 --unk_id=0 --max_sentencepiece_length=10 --character_coverage=1.0 --add_dummy_prefix=True --vocab_size=250\n",
"sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : \n",
"trainer_spec {\n",
" input: ./results/add_train/save/train.txt\n",
" input_format: \n",
" model_prefix: ./results/add_train/save/250_unigram\n",
" model_type: UNIGRAM\n",
" vocab_size: 250\n",
" self_test_sample_size: 0\n",
" character_coverage: 1\n",
" input_sentence_size: 0\n",
" shuffle_input_sentence: 1\n",
" seed_sentencepiece_size: 1000000\n",
" shrinking_factor: 0.75\n",
" max_sentence_length: 4192\n",
" num_threads: 16\n",
" num_sub_iterations: 2\n",
" max_sentencepiece_length: 10\n",
" split_by_unicode_script: 1\n",
" split_by_number: 1\n",
" split_by_whitespace: 1\n",
" split_digits: 0\n",
" pretokenization_delimiter: \n",
" treat_whitespace_as_suffix: 0\n",
" allow_whitespace_only_pieces: 0\n",
" required_chars: \n",
" byte_fallback: 0\n",
" vocabulary_output_piece_score: 1\n",
" train_extremely_large_corpus: 0\n",
" hard_vocab_limit: 1\n",
" use_all_vocab: 0\n",
" unk_id: 0\n",
" bos_id: -1\n",
" eos_id: -1\n",
" pad_id: -1\n",
" unk_piece: \n",
" bos_piece: \n",
" eos_piece: \n",
" pad_piece: \n",
" unk_surface: ⁇ \n",
" enable_differential_privacy: 0\n",
" differential_privacy_noise_level: 0\n",
" differential_privacy_clipping_threshold: 0\n",
"}\n",
"normalizer_spec {\n",
" name: nmt_nfkc\n",
" add_dummy_prefix: 1\n",
" remove_extra_whitespaces: 1\n",
" escape_whitespaces: 1\n",
" normalization_rule_tsv: \n",
"}\n",
"denormalizer_spec {}\n",
"trainer_interface.cc(351) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n",
"trainer_interface.cc(183) LOG(INFO) Loading corpus: ./results/add_train/save/train.txt\n",
"trainer_interface.cc(407) LOG(INFO) Loaded all 50 sentences\n",
"trainer_interface.cc(423) LOG(INFO) Adding meta_piece: \n",
"trainer_interface.cc(428) LOG(INFO) Normalizing sentences...\n",
"trainer_interface.cc(537) LOG(INFO) all chars count=3018\n",
"trainer_interface.cc(558) LOG(INFO) Alphabet size=37\n",
"trainer_interface.cc(559) LOG(INFO) Final character coverage=1\n",
"trainer_interface.cc(591) LOG(INFO) Done! preprocessed 50 sentences.\n",
"unigram_model_trainer.cc(222) LOG(INFO) Making suffix array...\n",
"unigram_model_trainer.cc(226) LOG(INFO) Extracting frequent sub strings... node_num=1348\n",
"unigram_model_trainer.cc(274) LOG(INFO) Initialized 635 seed sentencepieces\n",
"trainer_interface.cc(597) LOG(INFO) Tokenizing input sentences with whitespace: 50\n",
"trainer_interface.cc(608) LOG(INFO) Done! 328\n",
"unigram_model_trainer.cc(564) LOG(INFO) Using 328 sentences for EM training\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=411 obj=14.1573 num_tokens=965 num_tokens/piece=2.34793\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=365 obj=13.6029 num_tokens=976 num_tokens/piece=2.67397\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=0 size=275 obj=14.1354 num_tokens=1061 num_tokens/piece=3.85818\n",
"unigram_model_trainer.cc(580) LOG(INFO) EM sub_iter=1 size=274 obj=13.813 num_tokens=1061 num_tokens/piece=3.87226\n",
"trainer_interface.cc(686) LOG(INFO) Saving model: ./results/add_train/save/250_unigram.model\n",
"trainer_interface.cc(698) LOG(INFO) Saving vocabs: ./results/add_train/save/250_unigram.vocab\n",
"speechbrain.tokenizers.SentencePiece - ==== Loading Tokenizer ===\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer path: ./results/add_train/save/250_unigram.model\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer vocab_size: 250\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer type: unigram\n",
"speechbrain.core - Info: device arg from hparam file is used\n",
"speechbrain.core - Info: precision arg from hparam file is used\n",
"speechbrain.core - Gradscaler enabled: False. Using precision: fp32.\n",
"speechbrain.core - 46.3M trainable parameters in ASR\n",
"speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.\n",
"speechbrain.utils.epoch_loop - Going into epoch 1\n",
"100% 12/12 [00:12<00:00, 1.06s/it, train_loss=5.56]\n",
"100% 13/13 [00:10<00:00, 1.24it/s]\n",
"speechbrain.utils.train_logger - epoch: 1, lr: 1.00e-02 - train loss: 5.56 - valid loss: 5.34, valid CER: 1.18e+02, valid WER: 2.27e+02\n",
"speechbrain.core - Beginning experiment!\n",
"speechbrain.core - Experiment folder: ./results/evaluation\n",
"common_voice_prepare - ./results/evaluation/save/train.csv already exists, skipping data preparation!\n",
"common_voice_prepare - ./results/evaluation/save/dev.csv already exists, skipping data preparation!\n",
"common_voice_prepare - ./results/evaluation/save/test.csv already exists, skipping data preparation!\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer is already trained.\n",
"speechbrain.tokenizers.SentencePiece - ==== Loading Tokenizer ===\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer path: ./results/evaluation/save/250_unigram.model\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer vocab_size: 250\n",
"speechbrain.tokenizers.SentencePiece - Tokenizer type: unigram\n",
"speechbrain.core - Info: device arg from hparam file is used\n",
"speechbrain.core - Info: precision arg from hparam file is used\n",
"speechbrain.core - Gradscaler enabled: False. Using precision: fp32.\n",
"speechbrain.core - 46.3M trainable parameters in ASR\n",
"100% 13/13 [00:12<00:00, 1.03it/s]\n",
"speechbrain.utils.train_logger - Epoch loaded: 0 - test loss: 0.00e+00, test CER: 1.00e+02, test WER: 1.00e+02\n",
"INFO flwr 2024-02-27 23:19:43,732 | server.py:125 | fit progress: (1, 0.0, {'accuracy': 100.0}, 94.99289715399999)\n",
"flwr - fit progress: (1, 0.0, {'accuracy': 100.0}, 94.99289715399999)\n",
"INFO flwr 2024-02-27 23:19:43,733 | client_manager.py:196 | Sampling failed: number of available clients (1) is less than number of requested clients (2).\n",
"flwr - Sampling failed: number of available clients (1) is less than number of requested clients (2).\n",
"INFO flwr 2024-02-27 23:19:43,733 | server.py:171 | evaluate_round 1: no clients selected, cancel\n",
"flwr - evaluate_round 1: no clients selected, cancel\n",
"INFO flwr 2024-02-27 23:19:43,733 | server.py:153 | FL finished in 94.99359605999996\n",
"flwr - FL finished in 94.99359605999996\n",
"INFO flwr 2024-02-27 23:19:43,743 | app.py:226 | app_fit: losses_distributed []\n",
"flwr - app_fit: losses_distributed []\n",
"INFO flwr 2024-02-27 23:19:43,743 | app.py:227 | app_fit: metrics_distributed_fit {}\n",
"flwr - app_fit: metrics_distributed_fit {}\n",
"INFO flwr 2024-02-27 23:19:43,743 | app.py:228 | app_fit: metrics_distributed {}\n",
"flwr - app_fit: metrics_distributed {}\n",
"INFO flwr 2024-02-27 23:19:43,743 | app.py:229 | app_fit: losses_centralized [(0, 0.0), (1, 0.0)]\n",
"flwr - app_fit: losses_centralized [(0, 0.0), (1, 0.0)]\n",
"INFO flwr 2024-02-27 23:19:43,743 | app.py:230 | app_fit: metrics_centralized {'accuracy': [(0, 100.0), (1, 100.0)]}\n",
"flwr - app_fit: metrics_centralized {'accuracy': [(0, 100.0), (1, 100.0)]}\n",
"DEBUG flwr 2024-02-27 23:19:43,792 | connection.py:220 | gRPC channel closed\n",
"INFO flwr 2024-02-27 23:19:43,793 | app.py:398 | Disconnect and shut down\n",
"flwr - Disconnect and shut down\n"
]
}
],
"source": [
"%killbgscripts\n",
"!((./server.sh & sleep 5s); ./clients.sh)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IUDpDzQsCkKI"
},
"source": [
"As Colab only allows one cell to run at a time, logs from the server and all the clients are interleaved in the cell output. Here are a few tips on reading the log and dealing with the environment:\n",
"\n",
"* At the start, the clients first load the data, and you will see `common_voice_prepare - Preparing CSV files for ... samples`, followed by statistics about the loaded data. The next lines report the tokenizer training. Afterwards, you'll see the expected training or evaluation progress bar in the log.\n",
"* To see the evaluation WER, look for the line `speechbrain.utils.train_logger - Epoch loaded: 0 - test loss: ..., test CER: ..., test WER: ...`. To see the training loss and the validation CER and WER, look for the line `speechbrain.utils.train_logger - epoch: ..., lr: ... - train loss: ... - valid loss: ..., valid CER: ..., valid WER: ...`.\n",
"* To terminate the experiment early, press the stop icon to the left of the cell. The stop icon is equivalent to `Ctrl+C` in a terminal, so you might have to press it several times to terminate more quickly. If a pop-up says that the environment became unresponsive, press `Cancel` rather than `Terminate`; it should come back within a few seconds and you will not lose your progress.\n",
"\n",
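"Since the metrics are buried in a long interleaved log, it can help to extract them programmatically. A minimal sketch, assuming the `train_logger` line format shown in the output above (the helper `parse_metrics` and the regular expression are illustrative, not part of SpeechBrain or Flower):\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Matches fragments such as 'valid WER: 1.01e+02' or 'train loss: 7.47'.\n",
"METRIC_PATTERN = re.compile(r'(train|valid|test) (loss|CER|WER): ([0-9.eE+-]+)')\n",
"\n",
"\n",
"def parse_metrics(log_line):\n",
"    # Return a dict mapping names such as 'valid WER' to float values.\n",
"    return {\n",
"        f'{split} {name}': float(value)\n",
"        for split, name, value in METRIC_PATTERN.findall(log_line)\n",
"    }\n",
"\n",
"\n",
"line = (\n",
"    'speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 7.47'\n",
"    ' - valid loss: 5.40, valid CER: 89.62, valid WER: 1.01e+02'\n",
")\n",
"metrics = parse_metrics(line)\n",
"```\n",
"\n",
"Applied to every line of a saved log, this gives one dictionary per logged epoch or evaluation, which is easier to compare across rounds than the raw output.\n",
"\n",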
"As you can see, the results are poor. This is because we did not use a pre-trained model for initialisation and only trained on a small toy dataset. Don't worry about these numbers: you can get acceptable results by running on a real dataset."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sb_auto_footer",
"tags": [
"sb_auto_footer"
]
},
"source": [
"## Citing SpeechBrain\n",
"\n",
"If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:\n",
"\n",
"```bibtex\n",
"@misc{speechbrainV1,\n",
" title={Open-Source Conversational AI with {SpeechBrain} 1.0},\n",
" author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},\n",
" year={2024},\n",
" eprint={2407.00463},\n",
" archivePrefix={arXiv},\n",
" primaryClass={cs.LG},\n",
" url={https://arxiv.org/abs/2407.00463},\n",
"}\n",
"@misc{speechbrain,\n",
" title={{SpeechBrain}: A General-Purpose Speech Toolkit},\n",
" author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},\n",
" year={2021},\n",
" eprint={2106.04624},\n",
" archivePrefix={arXiv},\n",
" primaryClass={eess.AS},\n",
" note={arXiv:2106.04624}\n",
"}\n",
"```"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [
{
"file_id": "17tKZMghjFF0ZqHnDGty26Yn1RXW67DrX",
"timestamp": 1635935282341
}
]
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}