speechbrain.decoders.utils module

Utility functions for the decoding modules.

Authors
  • Adel Moumen 2023

  • Ju-Chieh Chou 2020

  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Sung-Lin Yeh 2020

Summary

Functions:

batch_filter_seq2seq_output

Apply filter_seq2seq_output to each of the batch_size sequences in a batch.

filter_seq2seq_output

Truncate the output at the first eos; the eos itself and everything after it are dropped.

inflate_tensor

Inflate the tensor along dim, repeating each slice consecutively the number of times given by times.

mask_by_condition

Replace the elements of tensor with fill_value wherever cond is False.

Reference

speechbrain.decoders.utils.inflate_tensor(tensor, times, dim)

Inflate the tensor along dim, repeating each slice consecutively the number of times given by times.

Parameters:
  • tensor (torch.Tensor) – The tensor to be inflated.

  • times (int) – How many times each slice along dim is repeated.

  • dim (int) – The dimension along which the tensor is inflated.

Returns:

The inflated tensor.

Return type:

torch.Tensor

Example

>>> import torch
>>> from speechbrain.decoders.utils import inflate_tensor
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> new_tensor = inflate_tensor(tensor, 2, dim=0)
>>> new_tensor
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.]])
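The example output is consistent with torch.repeat_interleave, which repeats each slice next to itself (a, a, b, b) rather than tiling the whole tensor (a, b, a, b). The sketch below assumes that equivalence; it is not copied from the SpeechBrain source. It shows a typical beam-search use: giving every utterance one copy of its encoder states per beam. The name inflate_tensor_sketch, the shapes and beam_size are illustrative only.

```python
import torch


def inflate_tensor_sketch(tensor: torch.Tensor, times: int, dim: int) -> torch.Tensor:
    # Assumed equivalent of inflate_tensor: each slice along `dim` is repeated
    # `times` times consecutively (a, a, b, b), not tiled (a, b, a, b).
    return torch.repeat_interleave(tensor, times, dim=dim)


# Hypothetical encoder output: 2 utterances, 3 frames, 4 features.
enc_states = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
beam_size = 3

# One copy of each utterance per beam, kept adjacent along the batch dim.
inflated = inflate_tensor_sketch(enc_states, beam_size, dim=0)
print(inflated.shape)  # torch.Size([6, 3, 4])

# Rows 0..2 are copies of utterance 0, rows 3..5 copies of utterance 1.
assert torch.equal(inflated[2], enc_states[0])
assert torch.equal(inflated[3], enc_states[1])

# Contrast: Tensor.repeat tiles the whole batch instead (0, 1, 0, 1, 0, 1).
tiled = enc_states.repeat(beam_size, 1, 1)
assert torch.equal(tiled[1], enc_states[1])
```

With dim=0 the copies of each utterance stay contiguous, so row b * beam_size + k holds beam k of utterance b.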
speechbrain.decoders.utils.mask_by_condition(tensor, cond, fill_value)

Replace the elements of tensor with fill_value wherever cond is False.

Parameters:
  • tensor (torch.Tensor) – The tensor to be masked.

  • cond (torch.BoolTensor) – A tensor of the same size as tensor. Each element indicates whether to keep the corresponding value of tensor (True) or replace it with fill_value (False).

  • fill_value (float) – The value written into the masked elements.

Returns:

The masked tensor.

Return type:

torch.Tensor

Example

>>> import torch
>>> from speechbrain.decoders.utils import mask_by_condition
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
>>> mask_by_condition(tensor, cond, 0)
tensor([[1., 2., 0.],
        [4., 0., 0.]])
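The documented behaviour (keep a value where cond is True, otherwise write fill_value) is what torch.where(cond, tensor, other) computes. The sketch below assumes that equivalence rather than quoting the library, and also shows one common way to build cond: a padding mask derived from per-sequence lengths. The names mask_by_condition_sketch, scores and lengths are hypothetical.

```python
import torch


def mask_by_condition_sketch(tensor: torch.Tensor, cond: torch.Tensor, fill_value: float) -> torch.Tensor:
    # Keep tensor where cond is True; write fill_value where it is False.
    # full_like gives the fill tensor the same shape, dtype and device as `tensor`.
    return torch.where(cond, tensor, torch.full_like(tensor, fill_value))


# Hypothetical scores for a padded batch of 2 sequences, max length 4.
scores = torch.tensor([[0.5, 1.5, 2.5, 3.5],
                       [4.5, 5.5, 6.5, 7.5]])
lengths = torch.tensor([3, 1])  # assumed number of valid steps per sequence

# cond[b, t] is True for valid (non-padded) steps, i.e. t < lengths[b].
cond = torch.arange(scores.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

masked = mask_by_condition_sketch(scores, cond, float("-inf"))
print(masked)
```

Filling with -inf rather than 0 is the usual choice when the masked tensor later goes through a softmax or a max over log-probabilities, because masked positions then receive zero probability and can never be selected.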
speechbrain.decoders.utils.batch_filter_seq2seq_output(prediction, eos_id=-1)

Apply filter_seq2seq_output to each of the batch_size sequences in a batch.

Parameters:
  • prediction (list of torch.Tensor) – A list with one tensor per batch element, holding the integer token ids predicted by the seq2seq system.

  • eos_id (int, string) – The id of the end-of-sequence (eos) token.

Returns:

The predicted sequences, one list per batch element, each truncated before its first eos.

Return type:

list

Example

>>> import torch
>>> from speechbrain.decoders.utils import batch_filter_seq2seq_output
>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
>>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
>>> predictions
[[1, 2, 3], [2, 3]]
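A minimal sketch of the batch wrapper, assuming it converts each tensor to a Python list and cuts it at the first eos, as the example output suggests. The helper name batch_filter_sketch is hypothetical. The extra batch entries illustrate two edge cases: a hypothesis with no eos is kept whole, and one that starts with eos becomes an empty list.

```python
from typing import List, Sequence

import torch


def batch_filter_sketch(prediction: Sequence[torch.Tensor], eos_id: int = -1) -> List[list]:
    outputs = []
    for hyp in prediction:
        tokens = hyp.tolist()  # tensor -> list of Python ints
        # Index of the first eos, or len(tokens) if the hypothesis has none.
        cut = tokens.index(eos_id) if eos_id in tokens else len(tokens)
        outputs.append(tokens[:cut])  # the eos itself is excluded
    return outputs


batch = [
    torch.IntTensor([1, 2, 3, 4]),     # eos at position 3
    torch.IntTensor([2, 3, 4, 5, 6]),  # eos at position 2
    torch.IntTensor([7, 8]),           # no eos: kept whole
    torch.IntTensor([4, 9]),           # eos first: empty output
]
print(batch_filter_sketch(batch, eos_id=4))  # [[1, 2, 3], [2, 3], [7, 8], []]
```

Because the filtered sequences have different lengths, the result is a list of lists rather than a padded tensor.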
speechbrain.decoders.utils.filter_seq2seq_output(string_pred, eos_id=-1)

Truncate the output at the first eos; the eos itself and everything after it are dropped.

Parameters:
  • string_pred (list) – A list containing the output strings/ints predicted by the seq2seq system.

  • eos_id (int, string) – The id of the end-of-sequence (eos) token.

Returns:

The predicted output, truncated before the first eos.

Return type:

list

Example

>>> from speechbrain.decoders.utils import filter_seq2seq_output
>>> string_pred = ['a','b','c','d','eos','e']
>>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
>>> string_out
['a', 'b', 'c', 'd']
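The same truncation applies to any list, whether it holds token strings or integer ids. The plain-Python sketch below relies only on the documented behaviour (stop before the first eos, keep everything when there is none) and uses itertools.takewhile, which may differ from how the library locates the eos. The name filter_until_eos is hypothetical.

```python
from itertools import takewhile
from typing import Any, List


def filter_until_eos(string_pred: List[Any], eos_id: Any = -1) -> List[Any]:
    # Keep elements up to, but not including, the first eos_id.
    # Later eos tokens, and anything after them, are never reached.
    return list(takewhile(lambda tok: tok != eos_id, string_pred))


print(filter_until_eos(['a', 'b', 'c', 'd', 'eos', 'e'], eos_id='eos'))  # ['a', 'b', 'c', 'd']
print(filter_until_eos([3, 1, 2], eos_id=2))  # [3, 1]
print(filter_until_eos(['x', 'y'], eos_id='eos'))  # ['x', 'y']
```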