speechbrain.nnet.unet module

A UNet model implementation for use with diffusion models

Adapted from OpenAI guided diffusion, with slight modifications and additional features: https://github.com/openai/guided-diffusion

MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Authors
  • Artem Ploujnikov 2022

Summary

Classes:

AttentionBlock

An attention block that allows spatial positions to attend to each other.

AttentionPool2d

Two-dimensional attentional pooling

DecoderUNetModel

The half UNet model with attention and timestep embedding.

Downsample

A downsampling layer with an optional convolution.

DownsamplingPadding

A wrapper module that applies the necessary padding for the downsampling factor

EmbeddingProjection

A simple module that computes the projection of an embedding vector onto the specified number of dimensions

EncoderUNetModel

The half UNet model with attention and timestep embedding.

QKVAttention

A module which performs QKV attention and splits in a different order.

ResBlock

A residual block that can optionally change the number of channels.

TimestepBlock

Any module where forward() takes timestep embeddings as a second argument.

TimestepEmbedSequential

A sequential module that passes timestep embeddings to the children that support it as an extra input.

UNetModel

The full UNet model with attention and timestep embedding.

UNetNormalizingAutoencoder

A convenience class for a UNet-based Variational Autoencoder (VAE), useful in constructing Latent Diffusion models

Upsample

An upsampling layer with an optional convolution.

Functions:

avg_pool_nd

Create a 1D, 2D, or 3D average pooling module.

build_emb_proj

Builds a dictionary of embedding modules for embedding projections

conv_nd

Create a 1D, 2D, or 3D convolution module.

fixup

Zero out the parameters of a module and return it.

timestep_embedding

Create sinusoidal timestep embeddings.

Reference

speechbrain.nnet.unet.fixup(module, use_fixup_init=True)[source]

Zero out the parameters of a module and return it.

Parameters:
  • module (torch.nn.Module) – a module

  • use_fixup_init (bool) – whether to zero out the parameters. If set to False, the function is a no-op

Returns:

module – the module, with its parameters zeroed out if requested

Return type:

torch.nn.Module
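
Example

A minimal usage sketch (assuming fixup zeroes every parameter in place, analogous to the zero_module helper in guided diffusion):

>>> lin = torch.nn.Linear(4, 4)
>>> lin = fixup(lin)
>>> (lin.weight == 0.0).all().item()
True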

speechbrain.nnet.unet.conv_nd(dims, *args, **kwargs)[source]

Create a 1D, 2D, or 3D convolution module.

Parameters:
  • dims (int) – the number of dimensions

  • *args, **kwargs – any remaining arguments are passed to the constructor of the underlying convolution module
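
Example

A minimal sketch (assuming dims=2 yields a standard torch.nn.Conv2d):

>>> conv = conv_nd(2, 4, 8, kernel_size=3, padding=1)
>>> x = torch.randn(2, 4, 16, 16)
>>> conv(x).shape
torch.Size([2, 8, 16, 16])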

speechbrain.nnet.unet.avg_pool_nd(dims, *args, **kwargs)[source]

Create a 1D, 2D, or 3D average pooling module.

Parameters:
  • dims (int) – the number of dimensions

  • *args, **kwargs – any remaining arguments are passed to the constructor of the underlying pooling module
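
Example

A minimal sketch (assuming dims=2 yields a standard torch.nn.AvgPool2d):

>>> pool = avg_pool_nd(2, kernel_size=2)
>>> x = torch.randn(2, 4, 16, 16)
>>> pool(x).shape
torch.Size([2, 4, 8, 8])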

speechbrain.nnet.unet.timestep_embedding(timesteps, dim, max_period=10000)[source]

Create sinusoidal timestep embeddings.

Parameters:
  • timesteps (torch.Tensor) – a 1-D Tensor of N indices, one per batch element. These may be fractional.

  • dim (int) – the dimension of the output.

  • max_period (int) – controls the minimum frequency of the embeddings.

Returns:

result – an [N x dim] Tensor of positional embeddings.

Return type:

torch.Tensor
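
Example

A minimal sketch of the expected shapes, following the [N x dim] return contract documented above:

>>> ts = torch.tensor([0, 10, 100, 1000])
>>> emb = timestep_embedding(ts, dim=16)
>>> emb.shape
torch.Size([4, 16])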

class speechbrain.nnet.unet.AttentionPool2d(spatial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int | None = None)[source]

Bases: Module

Two-dimensional attentional pooling

Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

Parameters:
  • spatial_dim (int) – the size of the spatial dimension

  • embed_dim (int) – the embedding dimension

  • num_heads_channels (int) – the number of channels per attention head

  • output_dim (int) – the output dimension

Example

>>> attn_pool = AttentionPool2d(
...     spatial_dim=64,
...     embed_dim=16,
...     num_heads_channels=2,
...     output_dim=4
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> x_pool = attn_pool(x)
>>> x_pool.shape
torch.Size([4, 4])
forward(x)[source]

Computes the attention forward pass

Parameters:

x (torch.Tensor) – the tensor to be attended to

Returns:

result – the attention output

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.TimestepBlock(*args, **kwargs)[source]

Bases: Module

Any module where forward() takes timestep embeddings as a second argument.

abstract forward(x, emb=None)[source]

Apply the module to x given emb timestep embeddings.

Parameters:
  • x (torch.Tensor) – the input tensor

  • emb (torch.Tensor) – the timestep embeddings

training: bool
class speechbrain.nnet.unet.TimestepEmbedSequential(*args: Module)[source]
class speechbrain.nnet.unet.TimestepEmbedSequential(arg: OrderedDict[str, Module])

Bases: Sequential, TimestepBlock

A sequential module that passes timestep embeddings to the children that support it as an extra input.

Example

>>> from speechbrain.nnet.linear import Linear
>>> class MyBlock(TimestepBlock):
...     def __init__(self, input_size, output_size, emb_size):
...         super().__init__()
...         self.lin = Linear(
...             n_neurons=output_size,
...             input_size=input_size
...         )
...         self.emb_proj = Linear(
...             n_neurons=output_size,
...             input_size=emb_size,
...         )
...     def forward(self, x, emb):
...         return self.lin(x) + self.emb_proj(emb)
>>> tes = TimestepEmbedSequential(
...     MyBlock(128, 64, 16),
...     Linear(
...         n_neurons=32,
...         input_size=64
...     )
... )
>>> x = torch.randn(4, 10, 128)
>>> emb = torch.randn(4, 10, 16)
>>> out = tes(x, emb)
>>> out.shape
torch.Size([4, 10, 32])
forward(x, emb=None)[source]

Computes a sequential pass, passing the timestep embeddings to the children that support them

Parameters:
  • x (torch.Tensor) – the input tensor

  • emb (torch.Tensor) – the timestep embeddings

class speechbrain.nnet.unet.Upsample(channels, use_conv, dims=2, out_channels=None)[source]

Bases: Module

An upsampling layer with an optional convolution.

Parameters:
  • channels (int) – channels in the inputs and outputs.

  • use_conv (bool) – a bool determining if a convolution is applied.

  • dims (int) – determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.

  • out_channels (int) – if specified, the number of output channels (defaults to channels)

Example

>>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_up = ups(x)
>>> x_up.shape
torch.Size([8, 8, 64, 64])
forward(x)[source]

Computes the upsampling pass

Parameters:

x (torch.Tensor) – layer inputs

Returns:

result – upsampled outputs

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.Downsample(channels, use_conv, dims=2, out_channels=None)[source]

Bases: Module

A downsampling layer with an optional convolution.

Parameters:
  • channels (int) – channels in the inputs and outputs.

  • use_conv (bool) – a bool determining if a convolution is applied.

  • dims (int) – determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.

  • out_channels (int) – if specified, the number of output channels (defaults to channels)

Example

>>> down = Downsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_down = down(x)
>>> x_down.shape
torch.Size([8, 8, 16, 16])
forward(x)[source]

Computes the downsampling pass

Parameters:

x (torch.Tensor) – layer inputs

Returns:

result – downsampled outputs

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.ResBlock(channels, emb_channels, dropout, out_channels=None, use_conv=False, dims=2, up=False, down=False, norm_num_groups=32, use_fixup_init=True)[source]

Bases: TimestepBlock

A residual block that can optionally change the number of channels.

Parameters:
  • channels (int) – the number of input channels.

  • emb_channels (int) – the number of timestep embedding channels.

  • dropout (float) – the rate of dropout.

  • out_channels (int) – if specified, the number of out channels.

  • use_conv (bool) – if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.

  • dims (int) – determines if the signal is 1D, 2D, or 3D.

  • up (bool) – if True, use this block for upsampling.

  • down (bool) – if True, use this block for downsampling.

  • norm_num_groups (int) – the number of groups for group normalization

  • use_fixup_init (bool) – whether to use FixUp initialization

Example

>>> res = ResBlock(
...     channels=4,
...     emb_channels=8,
...     dropout=0.1,
...     norm_num_groups=2,
...     use_conv=True,
... )
>>> x = torch.randn(2, 4, 32, 32)
>>> emb = torch.randn(2, 8)
>>> res_out = res(x, emb)
>>> res_out.shape
torch.Size([2, 4, 32, 32])
forward(x, emb=None)[source]

Apply the block to a Tensor, conditioned on a timestep embedding.

Parameters:
  • x (torch.Tensor) – an [N x C x …] Tensor of features.

  • emb (torch.Tensor) – an [N x emb_channels] Tensor of timestep embeddings.

Returns:

result – an [N x C x …] Tensor of outputs.

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.AttentionBlock(channels, num_heads=1, num_head_channels=-1, norm_num_groups=32, use_fixup_init=True)[source]

Bases: Module

An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

Parameters:
  • channels (int) – the number of channels

  • num_heads (int) – the number of attention heads

  • num_head_channels (int) – the number of channels in each attention head

  • norm_num_groups (int) – the number of groups used for group normalization

  • use_fixup_init (bool) – whether to use FixUp initialization

Example

>>> attn = AttentionBlock(
...     channels=8,
...     num_heads=4,
...     num_head_channels=4,
...     norm_num_groups=2
... )
>>> x = torch.randn(4, 8, 16, 16)
>>> out = attn(x)
>>> out.shape
torch.Size([4, 8, 16, 16])
forward(x)[source]

Completes the forward pass

Parameters:

x (torch.Tensor) – the data to be attended to

Returns:

result – The data, with attention applied

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.QKVAttention(n_heads)[source]

Bases: Module

A module which performs QKV attention and splits in a different order.

Example

>>> attn = QKVAttention(4)
>>> n = 4
>>> c = 8
>>> h = 64
>>> w = 16
>>> qkv = torch.randn(4, (3 * h * c), w)
>>> out = attn(qkv)
>>> out.shape
torch.Size([4, 512, 16])
forward(qkv)[source]

Apply QKV attention.

Parameters:

qkv (torch.Tensor) – an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.

Returns:

result – an [N x (H * C) x T] tensor after attention.

Return type:

torch.Tensor

training: bool
speechbrain.nnet.unet.build_emb_proj(emb_config, proj_dim=None, use_emb=None)[source]

Builds a dictionary of embedding modules for embedding projections

Parameters:
  • emb_config (dict) – a configuration dictionary

  • proj_dim (int) – the target projection dimension

  • use_emb (dict) – an optional dictionary of “switches” to turn embeddings on and off

Returns:

result – a ModuleDict with a module for each embedding

Return type:

torch.nn.ModuleDict
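
Example

An illustrative sketch; the configuration format follows the cond_emb example documented under UNetModel, and it is assumed that each enabled entry yields a projection onto proj_dim:

>>> emb_proj = build_emb_proj(
...     {"speaker": {"emb_dim": 128}, "label": {"emb_dim": 16}},
...     proj_dim=64
... )
>>> emb = torch.randn(4, 128)
>>> emb_proj["speaker"](emb).shape
torch.Size([4, 64])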

class speechbrain.nnet.unet.UNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, emb_dim=None, cond_emb=None, use_cond_emb=None, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, use_fixup_init=True)[source]

Bases: Module

The full UNet model with attention and timestep embedding.

Parameters:
  • in_channels (int) – channels in the input Tensor.

  • model_channels (int) – base channel count for the model.

  • out_channels (int) – channels in the output Tensor.

  • num_res_blocks (int) – number of residual blocks per downsample.

  • attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.

  • dropout (float) – the dropout probability.

  • channel_mult (tuple) – channel multiplier for each level of the UNet.

  • conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling

  • emb_dim (int) – time embedding dimension (defaults to model_channels * 4)

  • cond_emb (dict) – embeddings on which the model will be conditioned

    Example: {"speaker": {"emb_dim": 256}, "label": {"emb_dim": 12}}

  • use_cond_emb (dict) – a dictionary with keys corresponding to keys in cond_emb and values corresponding to Booleans that turn embeddings on and off. This is useful in combination with hparams files to turn embeddings on and off with simple switches

    Example: {"speaker": False, "label": True}

  • dims (int) – determines if the signal is 1D, 2D, or 3D.

  • num_heads (int) – the number of attention heads in each attention layer.

  • num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.

  • num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.

  • norm_num_groups (int) – the number of groups for group normalization

  • resblock_updown (bool) – use residual blocks for up/downsampling.

  • use_fixup_init (bool) – whether to use FixUp initialization

Example

>>> model = UNetModel(
...    in_channels=3,
...    model_channels=32,
...    out_channels=1,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 16, 32])
forward(x, timesteps, cond_emb=None, **kwargs)[source]

Apply the model to an input batch.

Parameters:
  • x (torch.Tensor) – an [N x C x …] Tensor of inputs.

  • timesteps (torch.Tensor) – a 1-D batch of timesteps.

  • cond_emb (dict) – a string -> tensor dictionary of conditional embeddings (multiple embeddings are supported)

Returns:

result – an [N x C x …] Tensor of outputs.

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.EncoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, pool=None, attention_pool_dim=None, out_kernel_size=3, use_fixup_init=True)[source]

Bases: Module

The half UNet model with attention and timestep embedding. For usage, see UNetModel.

Parameters:
  • in_channels (int) – channels in the input Tensor.

  • model_channels (int) – base channel count for the model.

  • out_channels (int) – channels in the output Tensor.

  • num_res_blocks (int) – number of residual blocks per downsample.

  • attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.

  • dropout (float) – the dropout probability.

  • channel_mult (tuple) – channel multiplier for each level of the UNet.

  • conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling

  • dims (int) – determines if the signal is 1D, 2D, or 3D.

  • num_heads (int) – the number of attention heads in each attention layer.

  • num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.

  • num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.

  • norm_num_groups (int) – the number of groups for group normalization

  • resblock_updown (bool) – use residual blocks for up/downsampling.

  • use_fixup_init (bool) – whether to use FixUp initialization

  • out_kernel_size (int) – the kernel size of the output convolution

Example

>>> model = EncoderUNetModel(
...    in_channels=3,
...    model_channels=32,
...    out_channels=1,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 2, 4])
forward(x, timesteps=None)[source]

Apply the model to an input batch.

Parameters:
  • x (torch.Tensor) – an [N x C x …] Tensor of inputs.

  • timesteps (torch.Tensor) – a 1-D batch of timesteps.

Returns:

result – an [N x K] Tensor of outputs when pooling is applied; otherwise, a downsampled [N x C x …] feature Tensor, as in the example above.

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.EmbeddingProjection(emb_dim, proj_dim)[source]

Bases: Module

A simple module that computes the projection of an embedding vector onto the specified number of dimensions

Parameters:
  • emb_dim (int) – the original embedding dimensionality

  • proj_dim (int) – the dimensionality of the target projection space

Example

>>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64)
>>> emb = torch.randn(4, 16)
>>> emb_proj = mod_emb_proj(emb)
>>> emb_proj.shape
torch.Size([4, 64])
forward(emb)[source]

Computes the forward pass

Parameters:

emb (torch.Tensor) – the original embedding tensor

Returns:

result – the embedding, projected onto the target space

Return type:

torch.Tensor

training: bool
class speechbrain.nnet.unet.DecoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, resblock_updown=False, norm_num_groups=32, out_kernel_size=3, use_fixup_init=True)[source]

Bases: Module

The half UNet model with attention and timestep embedding. For usage, see UNetModel.

Parameters:
  • in_channels (int) – channels in the input Tensor.

  • model_channels (int) – base channel count for the model.

  • out_channels (int) – channels in the output Tensor.

  • num_res_blocks (int) – number of residual blocks per downsample.

  • attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.

  • dropout (float) – the dropout probability.

  • channel_mult (tuple) – channel multiplier for each level of the UNet.

  • conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling

  • dims (int) – determines if the signal is 1D, 2D, or 3D.

  • num_heads (int) – the number of attention heads in each attention layer.

  • num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.

  • num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.

  • norm_num_groups (int) – the number of groups for group normalization

  • resblock_updown (bool) – use residual blocks for up/downsampling.

  • use_fixup_init (bool) – whether to use FixUp initialization

Example

>>> model = DecoderUNetModel(
...    in_channels=1,
...    model_channels=32,
...    out_channels=3,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 1, 2, 4)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 3, 16, 32])
training: bool
forward(x, timesteps=None)[source]

Apply the model to an input batch.

Parameters:
  • x (torch.Tensor) – an [N x C x …] Tensor of inputs.

  • timesteps (torch.Tensor) – a 1-D batch of timesteps.

Returns:

result – an [N x C x …] Tensor of outputs.

Return type:

torch.Tensor

class speechbrain.nnet.unet.DownsamplingPadding(factor, len_dim=2, dims=None)[source]

Bases: Module

A wrapper module that applies the necessary padding for the downsampling factor

Parameters:
  • factor (int) – the downsampling / divisibility factor

  • len_dim (int) – the index of the dimension in which the length will vary

  • dims (list) – the list of dimensions to be included in padding

Example

>>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1)
>>> x = torch.randn(4, 7, 14)
>>> length = torch.tensor([1., 0.8, 1., 0.7])
>>> x, length_new = padding(x, length)
>>> x.shape
torch.Size([4, 8, 16])
>>> length_new
tensor([0.8750, 0.7000, 0.8750, 0.6125])
training: bool
forward(x, length=None)[source]

Applies the padding

Parameters:
  • x (torch.Tensor) – the tensor to be padded

  • length (torch.Tensor) – the relative lengths, if applicable

Returns:

  • x_pad (torch.Tensor) – the padded tensor

  • lens (torch.Tensor) – the new, adjusted lengths, if applicable

class speechbrain.nnet.unet.UNetNormalizingAutoencoder(in_channels, model_channels, encoder_out_channels, latent_channels, encoder_num_res_blocks, encoder_attention_resolutions, decoder_num_res_blocks, decoder_attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, out_kernel_size=3, len_dim=2, out_mask_value=0.0, latent_mask_value=0.0, use_fixup_norm=False, downsampling_padding=None)[source]

Bases: NormalizingAutoencoder

A convenience class for a UNet-based Variational Autoencoder (VAE), useful in constructing Latent Diffusion models

Parameters:
  • in_channels (int) – the number of input channels

  • model_channels (int) – the number of channels in the convolutional layers of the UNet encoder and decoder

  • encoder_out_channels (int) – the number of channels the encoder will output

  • latent_channels (int) – the number of channels in the latent space

  • encoder_num_res_blocks (int) – the number of residual blocks in the encoder

  • encoder_attention_resolutions (list) – the resolutions at which to apply attention layers in the encoder

  • decoder_num_res_blocks (int) – the number of residual blocks in the decoder

  • decoder_attention_resolutions (list) – the resolutions at which to apply attention layers in the decoder

  • dropout (float) – the dropout probability

  • channel_mult (tuple) – channel multipliers for each layer

  • dims (int) – the convolution dimension to use (1, 2 or 3)

  • num_heads (int) – the number of attention heads

  • num_head_channels (int) – the number of channels in attention heads

  • num_heads_upsample (int) – the number of upsampling heads

  • resblock_updown (bool) – whether to use residual blocks for upsampling and downsampling

  • out_kernel_size (int) – the kernel size for output convolution layers (if applicable)

  • use_fixup_norm (bool) – whether to use FixUp normalization

  • norm_num_groups (int) – the number of groups for group normalization

  • len_dim (int) – the index of the dimension in which the length will vary

  • out_mask_value (float) – the value used to fill masked positions in the output

  • latent_mask_value (float) – the value used to fill masked positions in the latent representation

  • downsampling_padding (int) – the divisibility factor used to pad inputs before downsampling, if applicable

Example

>>> unet_ae = UNetNormalizingAutoencoder(
...     in_channels=1,
...     model_channels=4,
...     encoder_out_channels=16,
...     latent_channels=3,
...     encoder_num_res_blocks=1,
...     encoder_attention_resolutions=[],
...     decoder_num_res_blocks=1,
...     decoder_attention_resolutions=[],
...     norm_num_groups=2,
... )
>>> x = torch.randn(4, 1, 32, 32)
>>> x_enc = unet_ae.encode(x)
>>> x_enc.shape
torch.Size([4, 3, 4, 4])
>>> x_dec = unet_ae.decode(x_enc)
>>> x_dec.shape
torch.Size([4, 1, 32, 32])
training: bool