"""Library implementing transducer_joint.
Author
Abdelwahab HEBA 2020
"""
import torch
import logging
import torch.nn as nn
logger = logging.getLogger(__name__)
[docs]
class Transducer_joint(nn.Module):
    """Computes joint tensor between Transcription network (TN) & Prediction network (PN)

    Arguments
    ---------
    joint_network : torch.class (neural network modules)
        if joint == "concat", we call this network after the concatenation
        of TN and PN
        if None, we don't use this network.
    joint : str
        How the two tensors are joined: "sum" or "concat".
    nonlinearity : torch class
        Activation function used after the joint between TN and PN
        Type of nonlinearity (tanh, relu). The class is instantiated with
        no arguments.

    Example
    -------
    >>> from speechbrain.nnet.transducer.transducer_joint import Transducer_joint
    >>> from speechbrain.nnet.linear import Linear
    >>> input_TN = torch.rand(8, 200, 1, 40)
    >>> input_PN = torch.rand(8, 1, 12, 40)
    >>> joint_network = Linear(input_size=80, n_neurons=80)
    >>> TJoint = Transducer_joint(joint_network, joint="concat")
    >>> output = TJoint(input_TN, input_PN)
    >>> output.shape
    torch.Size([8, 200, 12, 80])
    """

    def __init__(
        self, joint_network=None, joint="sum", nonlinearity=torch.nn.LeakyReLU
    ):
        super().__init__()
        self.joint_network = joint_network
        self.joint = joint
        self.nonlinearity = nonlinearity()

    def init_params(self, first_input):
        """Runs the joint network once on a first input.

        Does nothing when no joint network was given.

        Arguments
        ---------
        first_input : tensor
            A first input used for initializing the parameters.
        """
        # joint_network defaults to None; calling None would raise TypeError.
        if self.joint_network is not None:
            self.joint_network(first_input)

    def forward(self, input_TN, input_PN):
        """Returns the fusion of inputs tensors.

        Arguments
        ---------
        input_TN : torch.Tensor
            Input from Transcription Network.
        input_PN : torch.Tensor
            Input from Prediction Network.

        Returns
        -------
        torch.Tensor
            The nonlinearity applied to the joined tensor. With "sum" this is
            the (broadcast) element-wise sum. With "concat" the inputs are
            concatenated on the last dimension and then passed through the
            joint network, if one was given.

        Raises
        ------
        ValueError
            If the inputs do not have the same number of dimensions, if
            "concat" is used with inputs that are neither 1-D nor 4-D, or if
            ``joint`` is neither "sum" nor "concat".
        """
        n_dims = len(input_TN.shape)
        if n_dims != len(input_PN.shape):
            raise ValueError("Arg 1 and 2 must be have same size")

        if self.joint == "sum":
            # Plain broadcasting; any pair of broadcastable shapes is accepted.
            joint = input_TN + input_PN
        elif self.joint == "concat":
            if n_dims == 4:
                # For training: broadcast every leading dimension to the
                # larger of the two sizes, keeping each feature size as is.
                lead = [
                    max(i, j)
                    for i, j in zip(input_TN.shape[:-1], input_PN.shape[:-1])
                ]
                xs = input_TN.expand(torch.Size(lead + [input_TN.shape[-1]]))
                ymat = input_PN.expand(
                    torch.Size(lead + [input_PN.shape[-1]])
                )
                joint = torch.cat((xs, ymat), dim=-1)
            elif n_dims == 1:
                # For evaluation: two feature vectors.
                joint = torch.cat((input_TN, input_PN), dim=0)
            else:
                # The previous check could never fire, so other ranks fell
                # through and failed on an unassigned local.
                raise ValueError("Tensors 1 and 2 must have dim=1 or dim=4")

            if self.joint_network is not None:
                joint = self.joint_network(joint)
        else:
            raise ValueError(
                f"joint must be 'sum' or 'concat', got {self.joint!r}"
            )

        return self.nonlinearity(joint)