Source code for speechbrain.nnet.loss.transducer_loss

"""
Transducer loss implementation (depends on numba)

Authors
 * Abdelwahab Heba 2020
 * Titouan Parcollet 2023
"""

import torch
from torch.autograd import Function
from torch.nn import Module
import logging
import math
import warnings

NUMBA_VERBOSE = 0

logger = logging.getLogger(__name__)

try:
    from numba import cuda

    # Numba is extra verbose, and this may lead to a log.txt file of
    # multiple gigabytes, so we deactivate its logging by default.
    if not NUMBA_VERBOSE:
        logger.info(
            "Numba verbose is deactivated. To enable it, set NUMBA_VERBOSE to 1."
        )

        nb_logger = logging.getLogger("numba")
        nb_logger.setLevel(logging.ERROR)  # only show error

        from numba.core.errors import NumbaPerformanceWarning

        warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
    else:
        logger.info(
            "Numba verbose is enabled. To desactivate it, set NUMBA_VERBOSE to 0."
        )

except ImportError:
    err_msg = "The optional dependency Numba is needed to use this module\n"
    err_msg += "Cannot import numba. To use Transducer loss\n"
    err_msg += "Please follow the instructions below\n"
    err_msg += "=============================\n"
    err_msg += "If you use your localhost:\n"
    err_msg += "pip install numba\n"
    err_msg += "export NUMBAPRO_LIBDEVICE='/usr/local/cuda/nvvm/libdevice/' \n"
    err_msg += "export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so' \n"
    err_msg += "================================ \n"
    err_msg += "If you use conda:\n"
    err_msg += "conda install numba cudatoolkit"
    raise ImportError(err_msg)


@cuda.jit()
def cu_kernel_forward(log_probs, labels, alpha, log_p, T, U, blank, lock):
    """
    Compute forward pass for the forward-backward algorithm using Numba cuda kernel.
    Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf

    Arguments
    ---------
    log_probs : tensor
        4D Tensor of (batch x TimeLength x LabelLength x outputDim) from the Transducer network.
    labels : tensor
        2D Tensor of (batch x MaxSeqLabelLength) containing targets of the batch with zero padding.
    alpha : tensor
        3D Tensor of (batch x TimeLength x LabelLength) for forward computation.
    log_p : tensor
        1D Tensor of (batch) for forward cost computation.
    T : tensor
        1D Tensor of (batch) containing TimeLength of each target.
    U : tensor
        1D Tensor of (batch) containing LabelLength of each target.
    blank : int
        Blank index.
    lock : tensor
        2D Tensor of (batch x LabelLength) containing bool(1-0) lock for parallel computation.
    """
    # parallelize the forward algorithm over batch and target length dim
    b = cuda.blockIdx.x
    u = cuda.threadIdx.x
    t = 0
    if u <= U[b]:
        # for each (B,U) thread:
        # wait for the unlock of the previous computation of alpha[b,u-1,:],
        # do the computation over the whole time sequence on alpha[b,u,:],
        # and then unlock target u+1 for computation
        while t < T[b]:
            if u == 0:
                if t > 0:
                    alpha[b, t, 0] = (
                        alpha[b, t - 1, 0] + log_probs[b, t - 1, 0, blank]
                    )
                cuda.atomic.add(lock, (b, u + 1), -1)
                t += 1
            else:
                if cuda.atomic.add(lock, (b, u), 0) < 0:
                    if t == 0:
                        alpha[b, 0, u] = (
                            alpha[b, 0, u - 1]
                            + log_probs[b, 0, u - 1, labels[b, u - 1]]
                        )
                    else:
                        # compute emission prob
                        emit = (
                            alpha[b, t, u - 1]
                            + log_probs[b, t, u - 1, labels[b, u - 1]]
                        )
                        # compute no_emission prob
                        no_emit = (
                            alpha[b, t - 1, u] + log_probs[b, t - 1, u, blank]
                        )
                        # do logsumexp between log_emit and log_no_emit
                        alpha[b, t, u] = max(no_emit, emit) + math.log1p(
                            math.exp(-abs(no_emit - emit))
                        )
                    if u < U[b]:
                        cuda.atomic.add(lock, (b, u + 1), -1)
                    cuda.atomic.add(lock, (b, u), 1)
                    t += 1
    if u == U[b]:
        # for each thread b (utterance)
        # normalize the loss over time
        log_p[b] = (
            alpha[b, T[b] - 1, U[b]] + log_probs[b, T[b] - 1, U[b], blank]
        ) / T[b]
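

# A minimal, single-utterance CPU sketch of the same alpha recursion, handy for
# sanity-checking the kernel above on tiny inputs. The helper below is
# hypothetical (not part of the SpeechBrain API) and assumes `log_probs` is a
# CPU tensor of shape (T x U+1 x outputDim) and `labels` a 1D target tensor for
# one utterance. Note that the CUDA kernel additionally divides the result by T.
def _alpha_reference_cpu(log_probs, labels, blank):
    """Pure PyTorch alpha recursion for a single utterance (illustrative sketch)."""
    T, Up1, _ = log_probs.shape
    alpha = torch.zeros((T, Up1), dtype=log_probs.dtype)
    neg_inf = torch.tensor(-float("inf"), dtype=log_probs.dtype)
    for t in range(T):
        for u in range(Up1):
            if t == 0 and u == 0:
                continue  # alpha[0, 0] = 0 by definition
            # no_emit: arrive from (t-1, u) by emitting a blank
            no_emit = (
                alpha[t - 1, u] + log_probs[t - 1, u, blank] if t > 0 else neg_inf
            )
            # emit: arrive from (t, u-1) by emitting label u-1
            emit = (
                alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
                if u > 0
                else neg_inf
            )
            alpha[t, u] = torch.logaddexp(no_emit, emit)
    # total log-likelihood: last alpha node plus the final blank emission
    return alpha[T - 1, Up1 - 1] + log_probs[T - 1, Up1 - 1, blank]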


@cuda.jit()
def cu_kernel_backward(log_probs, labels, beta, log_p, T, U, blank, lock):
    """
    Compute backward pass for the forward-backward algorithm using Numba cuda kernel.
    Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf

    Arguments
    ---------
    log_probs : tensor
        4D Tensor of (batch x TimeLength x LabelLength x outputDim) from the Transducer network.
    labels : tensor
        2D Tensor of (batch x MaxSeqLabelLength) containing targets of the batch with zero padding.
    beta : tensor
        3D Tensor of (batch x TimeLength x LabelLength) for backward computation.
    log_p : tensor
        1D Tensor of (batch) for backward cost computation.
    T : tensor
        1D Tensor of (batch) containing TimeLength of each target.
    U : tensor
        1D Tensor of (batch) containing LabelLength of each target.
    blank : int
        Blank index.
    lock : tensor
        2D Tensor of (batch x LabelLength) containing bool(1-0) lock for parallel computation.
    """
    # parallelize the backward algorithm over batch and target length dim
    b = cuda.blockIdx.x
    u = cuda.threadIdx.x
    t = T[b] - 1
    if u <= U[b]:
        # for each (B,U) thread:
        # wait for the unlock of the next computation of beta[b,u+1,:],
        # do the computation over the whole time sequence on beta[b,u,:],
        # and then unlock target u-1 for computation
        while t >= 0:
            if u == U[b]:
                if t == T[b] - 1:
                    beta[b, t, u] = log_probs[b, t, u, blank]
                else:
                    beta[b, t, u] = (
                        beta[b, t + 1, u] + log_probs[b, t, u, blank]
                    )
                cuda.atomic.add(lock, (b, u - 1), -1)
                t -= 1
            else:
                if cuda.atomic.add(lock, (b, u), 0) < 0:
                    if t == T[b] - 1:
                        # at the last time step only the emission path is possible
                        beta[b, t, u] = (
                            beta[b, t, u + 1] + log_probs[b, t, u, labels[b, u]]
                        )
                    else:
                        # compute emission prob
                        emit = (
                            beta[b, t, u + 1]
                            + log_probs[b, t, u, labels[b, u]]
                        )
                        # compute no_emission prob
                        no_emit = beta[b, t + 1, u] + log_probs[b, t, u, blank]
                        # do logsumexp between log_emit and log_no_emit
                        beta[b, t, u] = max(no_emit, emit) + math.log1p(
                            math.exp(-abs(no_emit - emit))
                        )
                    if u > 0:
                        cuda.atomic.add(lock, (b, u - 1), -1)
                    cuda.atomic.add(lock, (b, u), 1)
                    t -= 1
    if u == 0:
        # for each thread b (utterance)
        # normalize the loss over time
        log_p[b] = beta[b, 0, 0] / T[b]
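

# The backward recursion mirrors the forward one:
#   beta[t, u] = logsumexp(beta[t + 1, u] + log_probs[t, u, blank],
#                          beta[t, u + 1] + log_probs[t, u, labels[u]])
# with beta[T-1, U] = log_probs[T-1, U, blank]. A cheap consistency check
# (hypothetical debugging helper, not part of the SpeechBrain API) is that the
# two passes agree on the total log-likelihood of every utterance:
def _check_alpha_beta_agree(alpha, beta, log_probs, T, U, blank, tol=1e-3):
    """Assert that forward and backward total log-likelihoods match (illustrative sketch)."""
    for b in range(alpha.shape[0]):
        fwd = alpha[b, T[b] - 1, U[b]] + log_probs[b, T[b] - 1, U[b], blank]
        bwd = beta[b, 0, 0]
        assert abs(float(fwd) - float(bwd)) < tol, (float(fwd), float(bwd))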


@cuda.jit()
def cu_kernel_compute_grad(log_probs, labels, alpha, beta, grads, T, U, blank):
    """
    Compute gradient for the forward-backward algorithm using Numba cuda kernel.
    Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf

    Arguments
    ---------
    log_probs : tensor
        4D Tensor of (batch x TimeLength x LabelLength x outputDim) from the Transducer network.
    labels : tensor
        2D Tensor of (batch x MaxSeqLabelLength) containing targets of the batch with zero padding.
    alpha : tensor
        3D Tensor of (batch x TimeLength x LabelLength) from the forward computation.
    beta : tensor
        3D Tensor of (batch x TimeLength x LabelLength) from the backward computation.
    grads : tensor
        4D Tensor of (batch x TimeLength x LabelLength x outputDim) where the gradients are written.
    T : tensor
        1D Tensor of (batch) containing TimeLength of each target.
    U : tensor
        1D Tensor of (batch) containing LabelLength of each target.
    blank : int
        Blank index.
    """
    # parallelize the gradient computation over batch and time sequence length dim
    t = cuda.blockIdx.x
    b = cuda.threadIdx.x
    if t < T[b]:
        # compute the gradient for no_emit prob
        if t == 0:
            grads[b, T[b] - 1, U[b], blank] = -math.exp(
                alpha[b, T[b] - 1, U[b]]
                + log_probs[b, T[b] - 1, U[b], blank]
                - beta[b, 0, 0]
            )
        if t < T[b] - 1:
            for u in range(U[b] + 1):
                grads[b, t, u, blank] = alpha[b, t, u] + beta[b, t + 1, u]
                grads[b, t, u, blank] = -math.exp(
                    grads[b, t, u, blank]
                    + log_probs[b, t, u, blank]
                    - beta[b, 0, 0]
                )
        # compute the gradient for emit prob
        for u, l in enumerate(labels[b]):
            if u < U[b]:
                grads[b, t, u, l] = alpha[b, t, u] + beta[b, t, u + 1]
                grads[b, t, u, l] = -math.exp(
                    grads[b, t, u, l] + log_probs[b, t, u, l] - beta[b, 0, 0]
                )
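

# The kernel above is the forward-backward gradient written in log-space: with
# logZ = beta[b, 0, 0], the posterior of a transition from node (t, u) is
#   exp(alpha[b, t, u] + log_probs[b, t, u, k] + beta[next node] - logZ),
# and the gradient of -logZ w.r.t. log_probs is minus that posterior. The CPU
# helper below is hypothetical (single utterance, for debugging only) and
# should reproduce the kernel output up to numerical precision.
def _grads_reference_cpu(log_probs, labels, alpha, beta, blank):
    """Pure PyTorch gradient of -logZ w.r.t. log_probs for one utterance (illustrative sketch)."""
    T, Up1, A = log_probs.shape
    U = Up1 - 1
    log_z = beta[0, 0]
    neg_inf = torch.tensor(-float("inf"), dtype=log_probs.dtype)
    grads = torch.zeros((T, Up1, A), dtype=log_probs.dtype)
    for t in range(T):
        for u in range(Up1):
            # blank transition (t, u) -> (t + 1, u); the terminal blank at
            # (T - 1, U) leaves the lattice, so its "beta_next" is 0.
            if t + 1 < T:
                beta_next = beta[t + 1, u]
            elif u == U:
                beta_next = torch.tensor(0.0, dtype=log_probs.dtype)
            else:
                beta_next = neg_inf  # impossible transition -> zero gradient
            grads[t, u, blank] = -torch.exp(
                alpha[t, u] + log_probs[t, u, blank] + beta_next - log_z
            )
            # label transition (t, u) -> (t, u + 1)
            if u < U:
                grads[t, u, labels[u]] = -torch.exp(
                    alpha[t, u]
                    + log_probs[t, u, labels[u]]
                    + beta[t, u + 1]
                    - log_z
                )
    return grads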


class Transducer(Function):
    """
    This class implements the Transducer loss computation with forward-backward algorithm.
    Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf

    This class uses torch.autograd.Function. Since the forward-backward algorithm is used,
    the gradient must be computed manually.

    This class can't be instantiated, please refer to the TransducerLoss class.

    It is also possible to use this class directly by calling Transducer.apply.
    """

    @staticmethod
    def forward(ctx, log_probs, labels, T, U, blank, reduction):
        """Computes the transducer loss."""
        log_probs = log_probs.detach()
        B, maxT, maxU, A = log_probs.shape
        grads = torch.zeros(
            (B, maxT, maxU, A), dtype=log_probs.dtype, device=log_probs.device
        )
        alpha = torch.zeros(
            (B, maxT, maxU), device=log_probs.device, dtype=log_probs.dtype
        )
        beta = torch.zeros(
            (B, maxT, maxU), device=log_probs.device, dtype=log_probs.dtype
        )
        lock = torch.zeros(
            (B, maxU), dtype=torch.int32, device=log_probs.device
        )
        log_p_alpha = torch.zeros(
            (B,), device=log_probs.device, dtype=log_probs.dtype
        )
        log_p_beta = torch.zeros(
            (B,), device=log_probs.device, dtype=log_probs.dtype
        )
        cu_kernel_forward[B, maxU](
            log_probs, labels, alpha, log_p_alpha, T, U, blank, lock,
        )
        lock = lock * 0
        cu_kernel_backward[B, maxU](
            log_probs, labels, beta, log_p_beta, T, U, blank, lock
        )
        cu_kernel_compute_grad[maxT, B](
            log_probs, labels, alpha, beta, grads, T, U, blank
        )
        ctx.grads = grads
        del alpha, beta, lock, log_p_beta, T, U, log_probs, labels
        torch.cuda.empty_cache()

        if reduction == "mean":
            return -log_p_alpha.mean()
        elif reduction == "sum":
            return sum(-log_p_alpha)
        elif reduction == "none":
            return -log_p_alpha
        else:
            raise Exception("Unexpected reduction {}".format(reduction))

    @staticmethod
    def backward(ctx, grad_output):
        """Backward computations for the transducer loss."""
        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul_(grad_output), None, None, None, None, None, None
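

# As mentioned in the Transducer docstring, the autograd Function can also be
# used directly through Transducer.apply, bypassing the TransducerLoss module.
# The helper below is purely illustrative (hypothetical name, requires a CUDA
# device); note that Transducer.apply expects log-probabilities, not raw logits.
def _example_transducer_apply():
    """Direct use of Transducer.apply on random data (illustrative sketch)."""
    log_probs = (
        torch.randn((1, 2, 3, 5), device="cuda").log_softmax(-1).requires_grad_()
    )
    labels = torch.tensor([[1, 2]], device="cuda", dtype=torch.int32)
    T = torch.tensor([2], device="cuda", dtype=torch.int32)
    U = torch.tensor([2], device="cuda", dtype=torch.int32)
    loss = Transducer.apply(log_probs, labels, T, U, 0, "mean")
    loss.backward()
    return log_probs.grad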


class TransducerLoss(Module):
    """
    This class implements the Transducer loss computation with forward-backward algorithm.
    Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf

    The TransducerLoss(nn.Module) uses Transducer(autograd.Function)
    to compute the forward-backward loss and gradients.

    Input tensors must be on a cuda device.

    Example
    -------
    >>> import torch
    >>> loss = TransducerLoss(blank=0)
    >>> logits = torch.randn((1,2,3,5)).cuda().requires_grad_()
    >>> labels = torch.Tensor([[1,2]]).cuda().int()
    >>> act_length = torch.Tensor([2]).cuda().int()
    >>> # U = label_length+1
    >>> label_length = torch.Tensor([2]).cuda().int()
    >>> l = loss(logits, labels, act_length, label_length)
    >>> l.backward()
    """

    def __init__(self, blank=0, reduction="mean"):
        super(TransducerLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction
        self.loss = Transducer.apply
        try:
            cuda.cuda_paths
        except ImportError:
            err_msg = "Cannot import numba. To use the Transducer loss:\n"
            err_msg += "=============================\n"
            err_msg += "If you use your localhost:\n"
            err_msg += "pip install numba\n"
            err_msg += (
                "export NUMBAPRO_LIBDEVICE='/usr/local/cuda/nvvm/libdevice/' \n"
            )
            err_msg += "export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so' \n"
            err_msg += "================================ \n"
            err_msg += "If you use conda:\n"
            err_msg += "conda install numba cudatoolkit=XX (XX is your cuda toolkit version)"
            raise ImportError(err_msg)

    def forward(self, logits, labels, T, U):
        """Computes the transducer loss."""
        # The Transducer.apply function takes a log_probs tensor.
        if all(t.is_cuda for t in (logits, labels, T, U)):
            log_probs = logits.log_softmax(-1)
            return self.loss(
                log_probs, labels, T, U, self.blank, self.reduction
            )
        else:
            raise ValueError(
                f"Found input tensors on {[logits.device, labels.device, T.device, U.device]} while they need to be on a 'cuda' device to use the transducer loss."
            )
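

# End-to-end sketch with a batch of two utterances of different lengths. This
# helper is purely illustrative (hypothetical name, requires a CUDA device);
# `logits` stands in for the joint-network output of shape
# (batch, max_time, max_label_length + 1, vocab), and labels are zero-padded.
def _example_transducer_loss():
    """Batched use of TransducerLoss on random data (illustrative sketch)."""
    loss_fn = TransducerLoss(blank=0, reduction="mean")
    B, maxT, maxU, V = 2, 4, 3, 5  # maxU = max label length + 1
    logits = torch.randn((B, maxT, maxU, V), device="cuda").requires_grad_()
    labels = torch.tensor([[1, 2], [3, 0]], device="cuda", dtype=torch.int32)
    T = torch.tensor([4, 3], device="cuda", dtype=torch.int32)  # time lengths
    U = torch.tensor([2, 1], device="cuda", dtype=torch.int32)  # label lengths
    loss = loss_fn(logits, labels, T, U)
    loss.backward()
    return loss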