speechbrain.nnet.unet module
A UNet model implementation for use with diffusion models.
Adapted from OpenAI guided diffusion (https://github.com/openai/guided-diffusion), with slight modifications and additional features.
MIT License
Copyright (c) 2021 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- Authors
Artem Ploujnikov 2022
Summary
Classes:

- AttentionBlock – An attention block that allows spatial positions to attend to each other.
- AttentionPool2d – Two-dimensional attentional pooling.
- DecoderUNetModel – The half UNet model with attention and timestep embedding.
- Downsample – A downsampling layer with an optional convolution.
- DownsamplingPadding – A wrapper module that applies the necessary padding for the downsampling factor.
- EmbeddingProjection – A simple module that computes the projection of an embedding vector onto the specified number of dimensions.
- EncoderUNetModel – The half UNet model with attention and timestep embedding.
- QKVAttention – A module which performs QKV attention and splits in a different order.
- ResBlock – A residual block that can optionally change the number of channels.
- TimestepBlock – Any module where forward() takes timestep embeddings as a second argument.
- TimestepEmbedSequential – A sequential module that passes timestep embeddings to the children that support it as an extra input.
- UNetModel – The full UNet model with attention and timestep embedding.
- UNetNormalizingAutoencoder – A convenience class for a UNet-based Variational Autoencoder (VAE), useful in constructing Latent Diffusion models.
- Upsample – An upsampling layer with an optional convolution.

Functions:

- avg_pool_nd – Create a 1D, 2D, or 3D average pooling module.
- build_emb_proj – Builds a dictionary of embedding modules for embedding projections.
- conv_nd – Create a 1D, 2D, or 3D convolution module.
- fixup – Zero out the parameters of a module and return it.
- timestep_embedding – Create sinusoidal timestep embeddings.
Reference
- speechbrain.nnet.unet.fixup(module, use_fixup_init=True)[source]
Zero out the parameters of a module and return it.
- Parameters:
module (torch.nn.Module) – a module
use_fixup_init (bool) – whether to zero out the parameters. If set to False, the function is a no-op.
- Returns:
result – the module, with its parameters zeroed out if use_fixup_init is True
- Return type:
torch.nn.Module
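A brief usage sketch (not part of the original docstring), assuming every parameter of the module is zeroed in place:
Example
>>> import torch
>>> lin = fixup(torch.nn.Linear(4, 4))
>>> float(lin.weight.abs().sum()), float(lin.bias.abs().sum())
(0.0, 0.0)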
- speechbrain.nnet.unet.conv_nd(dims, *args, **kwargs)[source]
Create a 1D, 2D, or 3D convolution module.
- Parameters:
dims (int) – the number of dimensions (1, 2, or 3)
*args – forwarded to the corresponding torch.nn.Conv1d/Conv2d/Conv3d constructor
**kwargs – forwarded to the corresponding torch.nn.Conv1d/Conv2d/Conv3d constructor
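A brief usage sketch (not part of the original docstring), assuming the extra arguments are forwarded to the matching torch.nn.Conv module:
Example
>>> conv = conv_nd(1, 4, 8, 3, padding=1)  # a Conv1d(4, 8, kernel_size=3, padding=1)
>>> x = torch.randn(2, 4, 16)
>>> conv(x).shape
torch.Size([2, 8, 16])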
- speechbrain.nnet.unet.avg_pool_nd(dims, *args, **kwargs)[source]
Create a 1D, 2D, or 3D average pooling module.
- Parameters:
dims (int) – the number of dimensions (1, 2, or 3)
*args – forwarded to the corresponding torch.nn.AvgPool1d/AvgPool2d/AvgPool3d constructor
**kwargs – forwarded to the corresponding torch.nn.AvgPool1d/AvgPool2d/AvgPool3d constructor
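A brief usage sketch (not part of the original docstring), under the same forwarding assumption as conv_nd:
Example
>>> pool = avg_pool_nd(2, kernel_size=2)  # an AvgPool2d(kernel_size=2)
>>> x = torch.randn(2, 4, 16, 16)
>>> pool(x).shape
torch.Size([2, 4, 8, 8])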
- speechbrain.nnet.unet.timestep_embedding(timesteps, dim, max_period=10000)[source]
Create sinusoidal timestep embeddings.
- Parameters:
timesteps (torch.Tensor) – a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int) – the dimension of the output
max_period (int) – controls the minimum frequency of the embeddings
- Returns:
result – an [N x dim] Tensor of positional embeddings.
- Return type:
torch.Tensor
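A brief usage sketch (not part of the original docstring), checking only the documented output shape:
Example
>>> ts = torch.tensor([10, 100, 50, 25])
>>> emb = timestep_embedding(ts, dim=32)
>>> emb.shape
torch.Size([4, 32])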
- class speechbrain.nnet.unet.AttentionPool2d(spatial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None)[source]
Bases:
Module
Two-dimensional attentional pooling
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- Parameters:
spatial_dim (int) – the size of the spatial dimension
embed_dim (int) – the embedding dimension
num_heads_channels (int) – the number of channels per attention head
output_dim (int) – the output dimension. If None, same as embed_dim.
Example
>>> attn_pool = AttentionPool2d(
...     spatial_dim=64,
...     embed_dim=16,
...     num_heads_channels=2,
...     output_dim=4
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> x_pool = attn_pool(x)
>>> x_pool.shape
torch.Size([4, 4])
- class speechbrain.nnet.unet.TimestepBlock(*args, **kwargs)[source]
Bases:
Module
Any module where forward() takes timestep embeddings as a second argument.
- class speechbrain.nnet.unet.TimestepEmbedSequential(*args: Module)[source]
- class speechbrain.nnet.unet.TimestepEmbedSequential(arg: OrderedDict[str, Module])
Bases:
Sequential, TimestepBlock
A sequential module that passes timestep embeddings to the children that support it as an extra input.
Example
>>> from speechbrain.nnet.linear import Linear
>>> class MyBlock(TimestepBlock):
...     def __init__(self, input_size, output_size, emb_size):
...         super().__init__()
...         self.lin = Linear(
...             n_neurons=output_size,
...             input_size=input_size
...         )
...         self.emb_proj = Linear(
...             n_neurons=output_size,
...             input_size=emb_size,
...         )
...     def forward(self, x, emb):
...         return self.lin(x) + self.emb_proj(emb)
>>> tes = TimestepEmbedSequential(
...     MyBlock(128, 64, 16),
...     Linear(
...         n_neurons=32,
...         input_size=64
...     )
... )
>>> x = torch.randn(4, 10, 128)
>>> emb = torch.randn(4, 10, 16)
>>> out = tes(x, emb)
>>> out.shape
torch.Size([4, 10, 32])
- class speechbrain.nnet.unet.Upsample(channels, use_conv, dims=2, out_channels=None)[source]
Bases:
Module
An upsampling layer with an optional convolution.
- Parameters:
channels (int) – channels in the inputs and outputs.
use_conv (bool) – a bool determining if a convolution is applied.
dims (int) – determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
out_channels (int) – Number of output channels. If None, same as input channels.
Example
>>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_up = ups(x)
>>> x_up.shape
torch.Size([8, 8, 64, 64])
- class speechbrain.nnet.unet.Downsample(channels, use_conv, dims=2, out_channels=None)[source]
Bases:
Module
A downsampling layer with an optional convolution.
- Parameters:
channels (int) – channels in the inputs and outputs.
use_conv (bool) – a bool determining if a convolution is applied.
dims (int) – determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
out_channels (int) – Number of output channels. If None, same as input channels.
Example
>>> down = Downsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_down = down(x)
>>> x_down.shape
torch.Size([8, 8, 16, 16])
- class speechbrain.nnet.unet.ResBlock(channels, emb_channels, dropout, out_channels=None, use_conv=False, dims=2, up=False, down=False, norm_num_groups=32, use_fixup_init=True)[source]
Bases:
TimestepBlock
A residual block that can optionally change the number of channels.
- Parameters:
channels (int) – the number of input channels.
emb_channels (int) – the number of timestep embedding channels.
dropout (float) – the rate of dropout.
out_channels (int) – if specified, the number of out channels.
use_conv (bool) – if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
dims (int) – determines if the signal is 1D, 2D, or 3D.
up (bool) – if True, use this block for upsampling.
down (bool) – if True, use this block for downsampling.
norm_num_groups (int) – the number of groups for group normalization
use_fixup_init (bool) – whether to use FixUp initialization
Example
>>> res = ResBlock(
...     channels=4,
...     emb_channels=8,
...     dropout=0.1,
...     norm_num_groups=2,
...     use_conv=True,
... )
>>> x = torch.randn(2, 4, 32, 32)
>>> emb = torch.randn(2, 8)
>>> res_out = res(x, emb)
>>> res_out.shape
torch.Size([2, 4, 32, 32])
- forward(x, emb=None)[source]
Apply the block to a torch.Tensor, conditioned on a timestep embedding.
- Parameters:
x (torch.Tensor) – an [N x C x …] Tensor of features.
emb (torch.Tensor) – an [N x emb_channels] Tensor of timestep embeddings.
- Returns:
result – an [N x C x …] Tensor of outputs.
- Return type:
torch.Tensor
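A hedged sketch (not part of the original docstring) of a downsampling block, assuming that, as in the original guided-diffusion implementation, down=True halves the spatial dimensions:
Example
>>> res_down = ResBlock(
...     channels=4,
...     emb_channels=8,
...     dropout=0.1,
...     norm_num_groups=2,
...     down=True,
... )
>>> x = torch.randn(2, 4, 32, 32)
>>> emb = torch.randn(2, 8)
>>> res_down(x, emb).shape
torch.Size([2, 4, 16, 16])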
- class speechbrain.nnet.unet.AttentionBlock(channels, num_heads=1, num_head_channels=-1, norm_num_groups=32, use_fixup_init=True)[source]
Bases:
Module
An attention block that allows spatial positions to attend to each other. Originally ported from https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66, but adapted to the N-d case.
- Parameters:
channels (int) – the number of channels
num_heads (int) – the number of attention heads
num_head_channels (int) – the number of channels in each attention head
norm_num_groups (int) – the number of groups used for group normalization
use_fixup_init (bool) – whether to use FixUp initialization
Example
>>> attn = AttentionBlock(
...     channels=8,
...     num_heads=4,
...     num_head_channels=4,
...     norm_num_groups=2
... )
>>> x = torch.randn(4, 8, 16, 16)
>>> out = attn(x)
>>> out.shape
torch.Size([4, 8, 16, 16])
- class speechbrain.nnet.unet.QKVAttention(n_heads)[source]
Bases:
Module
A module which performs QKV attention and splits in a different order.
- Parameters:
n_heads (int) – Number of attention heads.
Example
>>> attn = QKVAttention(4)
>>> n = 4
>>> c = 8
>>> h = 64
>>> w = 16
>>> qkv = torch.randn(n, (3 * h * c), w)
>>> out = attn(qkv)
>>> out.shape
torch.Size([4, 512, 16])
- speechbrain.nnet.unet.build_emb_proj(emb_config, proj_dim=None, use_emb=None)[source]
Builds a dictionary of embedding modules for embedding projections
- Parameters:
emb_config (dict) – a configuration dictionary of embeddings (see cond_emb in UNetModel)
proj_dim (int) – the dimension to which the embeddings will be projected
use_emb (dict) – a key -> bool dictionary that turns individual embeddings on and off (see use_cond_emb in UNetModel)
- Returns:
result – a ModuleDict with a module for each embedding
- Return type:
torch.nn.ModuleDict
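A hedged sketch (not part of the original docstring), assuming the configuration format matches the cond_emb/use_cond_emb parameters of UNetModel:
Example
>>> emb_proj = build_emb_proj(
...     emb_config={"speaker": {"emb_dim": 256}, "label": {"emb_dim": 12}},
...     proj_dim=64,
...     use_emb={"speaker": True, "label": True},
... )
>>> sorted(emb_proj.keys())
['label', 'speaker']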
- class speechbrain.nnet.unet.UNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, emb_dim=None, cond_emb=None, use_cond_emb=None, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, use_fixup_init=True)[source]
Bases:
Module
The full UNet model with attention and timestep embedding.
- Parameters:
in_channels (int) – channels in the input torch.Tensor.
model_channels (int) – base channel count for the model.
out_channels (int) – channels in the output torch.Tensor.
num_res_blocks (int) – number of residual blocks per downsample.
attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
dropout (float) – the dropout probability.
channel_mult (tuple) – channel multiplier for each level of the UNet.
conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling
dims (int) – determines if the signal is 1D, 2D, or 3D.
emb_dim (int) – time embedding dimension (defaults to model_channels * 4)
cond_emb (dict) – embeddings on which the model will be conditioned.
Example: {"speaker": {"emb_dim": 256}, "label": {"emb_dim": 12}}
use_cond_emb (dict) – a dictionary with keys corresponding to keys in cond_emb and values corresponding to Booleans that turn embeddings on and off. This is useful in combination with hparams files to turn embeddings on and off with simple switches.
Example: {"speaker": False, "label": True}
num_heads (int) – the number of attention heads in each attention layer.
num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.
num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.
norm_num_groups (int) – Number of groups in the norm, default 32
resblock_updown (bool) – use residual blocks for up/downsampling.
use_fixup_init (bool) – whether to use FixUp initialization
Example
>>> model = UNetModel(
...     in_channels=3,
...     model_channels=32,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 16, 32])
- forward(x, timesteps, cond_emb=None)[source]
Apply the model to an input batch.
- Parameters:
x (torch.Tensor) – an [N x C x …] Tensor of inputs.
timesteps (torch.Tensor) – a 1-D batch of timesteps.
cond_emb (dict) – a string -> tensor dictionary of conditional embeddings (multiple embeddings are supported)
- Returns:
result – an [N x C x …] Tensor of outputs.
- Return type:
torch.Tensor
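A hedged sketch (not part of the original docstring) of conditioning on an extra embedding, assuming the cond_emb configuration format shown above and a speaker embedding of the configured size:
Example
>>> model = UNetModel(
...     in_channels=3,
...     model_channels=32,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[1],
...     cond_emb={"speaker": {"emb_dim": 256}},
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> spk_emb = torch.randn(4, 256)  # hypothetical speaker embedding batch
>>> out = model(x, ts, cond_emb={"speaker": spk_emb})
>>> out.shape
torch.Size([4, 1, 16, 32])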
- class speechbrain.nnet.unet.EncoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, pool=None, attention_pool_dim=None, out_kernel_size=3, use_fixup_init=True)[source]
Bases:
Module
The half UNet model with attention and timestep embedding. For usage, see UNetModel.
- Parameters:
in_channels (int) – channels in the input torch.Tensor.
model_channels (int) – base channel count for the model.
out_channels (int) – channels in the output torch.Tensor.
num_res_blocks (int) – number of residual blocks per downsample.
attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
dropout (float) – the dropout probability.
channel_mult (tuple) – channel multiplier for each level of the UNet.
conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling
dims (int) – determines if the signal is 1D, 2D, or 3D.
num_heads (int) – the number of attention heads in each attention layer.
num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.
num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.
norm_num_groups (int) – Number of groups in the norm, default 32.
resblock_updown (bool) – use residual blocks for up/downsampling.
pool (str) – Type of pooling to use, one of: ["adaptive", "attention", "spatial", "spatial_v2"].
attention_pool_dim (int) – The dimension on which to apply attention pooling.
out_kernel_size (int) – the kernel size of the output convolution
use_fixup_init (bool) – whether to use FixUp initialization
Example
>>> model = EncoderUNetModel(
...     in_channels=3,
...     model_channels=32,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 2, 4])
- class speechbrain.nnet.unet.EmbeddingProjection(emb_dim, proj_dim)[source]
Bases:
Module
A simple module that computes the projection of an embedding vector onto the specified number of dimensions
- Parameters:
emb_dim (int) – the dimension of the input embedding
proj_dim (int) – the dimension of the projected embedding
Example
>>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64)
>>> emb = torch.randn(4, 16)
>>> emb_proj = mod_emb_proj(emb)
>>> emb_proj.shape
torch.Size([4, 64])
- class speechbrain.nnet.unet.DecoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, resblock_updown=False, norm_num_groups=32, out_kernel_size=3, use_fixup_init=True)[source]
Bases:
Module
The half UNet model with attention and timestep embedding. For usage, see UNetModel.
- Parameters:
in_channels (int) – channels in the input torch.Tensor.
model_channels (int) – base channel count for the model.
out_channels (int) – channels in the output torch.Tensor.
num_res_blocks (int) – number of residual blocks per downsample.
attention_resolutions (list) – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
dropout (float) – the dropout probability.
channel_mult (tuple) – channel multiplier for each level of the UNet.
conv_resample (bool) – if True, use learned convolutions for upsampling and downsampling
dims (int) – determines if the signal is 1D, 2D, or 3D.
num_heads (int) – the number of attention heads in each attention layer.
num_head_channels (int) – if specified, ignore num_heads and instead use a fixed channel width per attention head.
num_heads_upsample (int) – works with num_heads to set a different number of heads for upsampling. Deprecated.
resblock_updown (bool) – use residual blocks for up/downsampling.
norm_num_groups (int) – Number of groups to use in norm, default 32
out_kernel_size (int) – Output kernel size, default 3
use_fixup_init (bool) – whether to use FixUp initialization
Example
>>> model = DecoderUNetModel(
...     in_channels=1,
...     model_channels=32,
...     out_channels=3,
...     num_res_blocks=1,
...     attention_resolutions=[1]
... )
>>> x = torch.randn(4, 1, 2, 4)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 3, 16, 32])
- class speechbrain.nnet.unet.DownsamplingPadding(factor, len_dim=2, dims=None)[source]
Bases:
Module
A wrapper module that applies the necessary padding for the downsampling factor
- Parameters:
factor (int) – the downsampling factor; dimensions are padded up to a multiple of it
len_dim (int) – the index of the dimension in which the length varies
dims (list) – the list of dimensions to be padded
Example
>>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1)
>>> x = torch.randn(4, 7, 14)
>>> length = torch.tensor([1., 0.8, 1., 0.7])
>>> x, length_new = padding(x, length)
>>> x.shape
torch.Size([4, 8, 16])
>>> length_new
tensor([0.8750, 0.7000, 0.8750, 0.6125])
- class speechbrain.nnet.unet.UNetNormalizingAutoencoder(in_channels, model_channels, encoder_out_channels, latent_channels, encoder_num_res_blocks, encoder_attention_resolutions, decoder_num_res_blocks, decoder_attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, out_kernel_size=3, len_dim=2, out_mask_value=0.0, latent_mask_value=0.0, use_fixup_norm=False, downsampling_padding=None)[source]
Bases:
NormalizingAutoencoder
A convenience class for a UNet-based Variational Autoencoder (VAE), useful in constructing Latent Diffusion models
- Parameters:
in_channels (int) – the number of input channels
model_channels (int) – the number of channels in the convolutional layers of the UNet encoder and decoder
encoder_out_channels (int) – the number of channels the encoder will output
latent_channels (int) – the number of channels in the latent space
encoder_num_res_blocks (int) – the number of residual blocks in the encoder
encoder_attention_resolutions (list) – the resolutions at which to apply attention layers in the encoder
decoder_num_res_blocks (int) – the number of residual blocks in the decoder
decoder_attention_resolutions (list) – the resolutions at which to apply attention layers in the decoder
dropout (float) – the dropout probability
channel_mult (tuple) – channel multipliers for each layer
dims (int) – the convolution dimension to use (1, 2 or 3)
num_heads (int) – the number of attention heads
num_head_channels (int) – the number of channels in attention heads
num_heads_upsample (int) – the number of upsampling heads
norm_num_groups (int) – Number of norm groups, default 32
resblock_updown (bool) – whether to use residual blocks for upsampling and downsampling
out_kernel_size (int) – the kernel size for output convolution layers (if applicable)
len_dim (int) – the index of the length dimension
out_mask_value (float) – Value to fill when masking the output.
latent_mask_value (float) – Value to fill when masking the latent variable.
use_fixup_norm (bool) – whether to use FixUp normalization
downsampling_padding (int) – Amount of padding to apply in downsampling, default 2 ** len(channel_mult)
Example
>>> unet_ae = UNetNormalizingAutoencoder(
...     in_channels=1,
...     model_channels=4,
...     encoder_out_channels=16,
...     latent_channels=3,
...     encoder_num_res_blocks=1,
...     encoder_attention_resolutions=[],
...     decoder_num_res_blocks=1,
...     decoder_attention_resolutions=[],
...     norm_num_groups=2,
... )
>>> x = torch.randn(4, 1, 32, 32)
>>> x_enc = unet_ae.encode(x)
>>> x_enc.shape
torch.Size([4, 3, 4, 4])
>>> x_dec = unet_ae.decode(x_enc)
>>> x_dec.shape
torch.Size([4, 1, 32, 32])