# Source code for speechbrain.utils.callchains

"""Chaining together callables, if some require relative lengths"""
import inspect


def lengths_arg_exists(func):
    """Check whether a callable declares a parameter named ``lengths``.

    Arguments
    ---------
    func : callable
        The function, method, or other callable whose signature is inspected.

    Returns
    -------
    bool
        True if ``lengths`` is one of the named positional parameters or one
        of the keyword-only parameters of ``func``, False otherwise.
    """
    argspec = inspect.getfullargspec(func)
    # Only explicitly named parameters are considered; ``*args`` and
    # ``**kwargs`` catch-alls are not inspected here.
    if "lengths" in argspec.args:
        return True
    return "lengths" in argspec.kwonlyargs
class LengthsCapableChain:
    """Chain together callables. Can handle relative lengths.

    This is a more light-weight version of
    speechbrain.nnet.containers.LengthsCapableSequential

    Arguments
    ---------
    *funcs : list, optional
        Any number of functions or other callables, given in
        order of execution.

    Returns
    -------
    Any
        The input as processed by each function. If no functions were given,
        simply returns the input.

    Example
    -------
    >>> chain = LengthsCapableChain(
    ...     lambda x: x + 1, lambda x, lengths: x * lengths
    ... )
    >>> chain(1, lengths=3)
    6
    """

    def __init__(self, *funcs):
        # Parallel lists: takes_lengths[i] records whether funcs[i] declares
        # a ``lengths`` parameter (decided once, when the callable is added).
        self.funcs = []
        self.takes_lengths = []
        for func in funcs:
            self.append(func)

    def __call__(self, x, lengths=None):
        """Run the chain of callables on the given input

        Arguments
        ---------
        x : Any
            The main input
        lengths : Any
            The lengths argument which will be conditionally passed to
            any functions in the chain that take a 'lengths' argument.
            It is always passed by keyword (``lengths=...``).
            In SpeechBrain the convention is to use relative lengths.

        Returns
        -------
        Any
            The output of the last callable in the chain, or ``x`` itself
            if the chain is empty.

        Note
        ----
        By convention, if a callable in the chain returns multiple outputs
        (returns a tuple), only the first output is passed to the next
        callable in the chain.
        """
        # An empty chain falls straight through the loop and returns x as-is.
        for func, give_lengths in zip(self.funcs, self.takes_lengths):
            if give_lengths:
                # Pass by keyword: ``lengths`` may be keyword-only or may not
                # be the second positional parameter of ``func``; a positional
                # pass would then fail or bind to the wrong parameter.
                x = func(x, lengths=lengths)
            else:
                x = func(x)
            if isinstance(x, tuple):
                # Only the first output is forwarded to the next callable.
                x = x[0]
        return x

    def append(self, func):
        """Add a function to the end of the chain

        Arguments
        ---------
        func : callable
            The callable to add. Whether it declares a ``lengths`` parameter
            is determined here, once, and stored alongside it.
        """
        self.funcs.append(func)
        self.takes_lengths.append(lengths_arg_exists(func))

    def __str__(self):
        clsname = self.__class__.__name__
        if not self.funcs:
            return f"Empty {clsname}"
        # Class name header followed by one line per chained callable.
        return f"{clsname}:\n" + "\n".join(str(f) for f in self.funcs)