"""A UNet model implementation for use with diffusion models
Adapted from OpenAI guided diffusion, with slight modifications
and additional features
https://github.com/openai/guided-diffusion
MIT License
Copyright (c) 2021 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Authors
* Artem Ploujnikov 2022
"""
import math
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from speechbrain.utils.data_utils import pad_divisible
from .autoencoders import NormalizingAutoencoder
def fixup(module, use_fixup_init=True):
"""
Zero out the parameters of a module and return it.
Arguments
---------
module: torch.nn.Module
a module
use_fixup_init: bool
whether to zero out the parameters. If set to
False, the function is a no-op
Returns
-------
module: torch.nn.Module
the module, with parameters zeroed out when enabled
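Example
-------
A quick, illustrative check that the parameters are zeroed in place:
>>> lin = fixup(nn.Linear(2, 2))
>>> lin.weight.abs().sum().item()
0.0
>>> lin.bias.abs().sum().item()
0.0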
"""
if use_fixup_init:
for p in module.parameters():
p.detach().zero_()
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
Arguments
---------
dims: int
the number of dimensions (1, 2 or 3)
*args, **kwargs
any remaining arguments are passed through to the
convolution constructor
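Example
-------
An illustrative shape check using a 1D convolution:
>>> conv = conv_nd(1, 4, 8, 3, padding=1)
>>> x = torch.randn(2, 4, 10)
>>> conv(x).shape
torch.Size([2, 8, 10])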
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
Arguments
---------
dims: int
the number of dimensions (1, 2 or 3)
*args, **kwargs
any remaining arguments are passed through to the
pooling constructor
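Example
-------
An illustrative shape check using 2D average pooling:
>>> pool = avg_pool_nd(2, kernel_size=2)
>>> x = torch.randn(2, 4, 16, 16)
>>> pool(x).shape
torch.Size([2, 4, 8, 8])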
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Arguments
---------
timesteps: torch.Tensor
a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim: int
the dimension of the output.
max_period: int
controls the minimum frequency of the embeddings.
Returns
-------
result: torch.Tensor
an [N x dim] Tensor of positional embeddings.
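Example
-------
An illustrative call showing the output shape:
>>> timesteps = torch.tensor([10, 100, 50])
>>> emb = timestep_embedding(timesteps, dim=8)
>>> emb.shape
torch.Size([3, 8])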
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
class AttentionPool2d(nn.Module):
"""Two-dimensional attentional pooling
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
Arguments
---------
spatial_dim: int
the size of the spatial dimension
embed_dim: int
the embedding dimension
num_heads_channels: int
the number of channels in each attention head
output_dim: int
the output dimension
Example
-------
>>> attn_pool = AttentionPool2d(
... spatial_dim=64,
... embed_dim=16,
... num_heads_channels=2,
... output_dim=4
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> x_pool = attn_pool(x)
>>> x_pool.shape
torch.Size([4, 4])
"""
def __init__(
self,
spatial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(embed_dim, spatial_dim ** 2 + 1) / embed_dim ** 0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
"""Computes the attention forward pass
Arguments
---------
x: torch.Tensor
the tensor to be attended to
Returns
-------
result: torch.Tensor
the attention output
"""
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb=None):
"""
Apply the module to `x` given `emb` timestep embeddings.
Arguments
---------
x: torch.Tensor
the data tensor
emb: torch.Tensor
the embedding tensor
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""A sequential module that passes timestep embeddings to the children that
support it as an extra input.
Example
-------
>>> from speechbrain.nnet.linear import Linear
>>> class MyBlock(TimestepBlock):
... def __init__(self, input_size, output_size, emb_size):
... super().__init__()
... self.lin = Linear(
... n_neurons=output_size,
... input_size=input_size
... )
... self.emb_proj = Linear(
... n_neurons=output_size,
... input_size=emb_size,
... )
... def forward(self, x, emb):
... return self.lin(x) + self.emb_proj(emb)
>>> tes = TimestepEmbedSequential(
... MyBlock(128, 64, 16),
... Linear(
... n_neurons=32,
... input_size=64
... )
... )
>>> x = torch.randn(4, 10, 128)
>>> emb = torch.randn(4, 10, 16)
>>> out = tes(x, emb)
>>> out.shape
torch.Size([4, 10, 32])
"""
def forward(self, x, emb=None):
"""Computes a sequential pass with sequential embeddings where applicable
Arguments
---------
x: torch.Tensor
the data tensor
emb: torch.Tensor
timestep embeddings"""
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
Arguments
---------
channels: int
channels in the inputs and outputs.
use_conv: bool
a bool determining if a convolution is applied.
dims: int
determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
out_channels: int
if specified, the number of output channels (defaults to
the number of input channels)
Example
-------
>>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_up = ups(x)
>>> x_up.shape
torch.Size([8, 8, 64, 64])
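A 1D sketch of the same layer (illustrative):
>>> ups_1d = Upsample(channels=4, use_conv=True, dims=1, out_channels=8)
>>> x_1d = torch.randn(8, 4, 32)
>>> ups_1d(x_1d).shape
torch.Size([8, 8, 64])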
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=1
)
def forward(self, x):
"""Computes the upsampling pass
Arguments
---------
x: torch.Tensor
layer inputs
Returns
-------
result: torch.Tensor
upsampled outputs"""
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
Arguments
---------
channels: int
channels in the inputs and outputs.
use_conv: bool
a bool determining if a convolution is applied.
dims: int
determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
out_channels: int
if specified, the number of output channels (defaults to
the number of input channels)
Example
-------
>>> down = Downsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_down = down(x)
>>> x_down.shape
torch.Size([8, 8, 16, 16])
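A 1D sketch of the same layer (illustrative):
>>> down_1d = Downsample(channels=4, use_conv=True, dims=1, out_channels=8)
>>> x_1d = torch.randn(8, 4, 32)
>>> down_1d(x_1d).shape
torch.Size([8, 8, 16])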
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=1,
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
"""Computes the downsampling pass
Arguments
---------
x: torch.Tensor
layer inputs
Returns
-------
result: torch.Tensor
downsampled outputs
"""
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
Arguments
---------
channels: int
the number of input channels.
emb_channels: int
the number of timestep embedding channels.
dropout: float
the rate of dropout.
out_channels: int
if specified, the number of out channels.
use_conv: bool
if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
dims: int
determines if the signal is 1D, 2D, or 3D.
up: bool
if True, use this block for upsampling.
down: bool
if True, use this block for downsampling.
norm_num_groups: int
the number of groups for group normalization
use_fixup_init: bool
whether to use FixUp initialization
Example
-------
>>> res = ResBlock(
... channels=4,
... emb_channels=8,
... dropout=0.1,
... norm_num_groups=2,
... use_conv=True,
... )
>>> x = torch.randn(2, 4, 32, 32)
>>> emb = torch.randn(2, 8)
>>> res_out = res(x, emb)
>>> res_out.shape
torch.Size([2, 4, 32, 32])
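An illustrative sketch of the same block used for downsampling:
>>> res_down = ResBlock(
...     channels=4,
...     emb_channels=8,
...     dropout=0.1,
...     norm_num_groups=2,
...     down=True,
... )
>>> res_down(x, emb).shape
torch.Size([2, 4, 16, 16])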
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
dims=2,
up=False,
down=False,
norm_num_groups=32,
use_fixup_init=True,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.in_layers = nn.Sequential(
nn.GroupNorm(norm_num_groups, channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
if emb_channels is not None:
self.emb_layers = nn.Sequential(
nn.SiLU(), nn.Linear(emb_channels, self.out_channels,),
)
else:
self.emb_layers = None
self.out_layers = nn.Sequential(
nn.GroupNorm(norm_num_groups, self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
fixup(
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
),
use_fixup_init=use_fixup_init,
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
Arguments
---------
x: torch.Tensor
an [N x C x ...] Tensor of features.
emb: torch.Tensor
an [N x emb_channels] Tensor of timestep embeddings.
Returns
-------
result: torch.Tensor
an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if emb is not None:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
else:
emb_out = torch.zeros_like(h)
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Arguments
---------
channels: int
the number of channels
num_heads: int
the number of attention heads
num_head_channels: int
the number of channels in each attention head
norm_num_groups: int
the number of groups used for group normalization
use_fixup_init: bool
whether to use FixUp initialization
Example
-------
>>> attn = AttentionBlock(
... channels=8,
... num_heads=4,
... num_head_channels=4,
... norm_num_groups=2
... )
>>> x = torch.randn(4, 8, 16, 16)
>>> out = attn(x)
>>> out.shape
torch.Size([4, 8, 16, 16])
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
norm_num_groups=32,
use_fixup_init=True,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(norm_num_groups, channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
self.proj_out = fixup(conv_nd(1, channels, channels, 1), use_fixup_init)
def forward(self, x):
"""Completes the forward pass
Arguments
---------
x: torch.Tensor
the data to be attended to
Returns
-------
result: torch.Tensor
The data, with attention applied
"""
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
Example
-------
>>> attn = QKVAttention(4)
>>> n = 4
>>> c = 8
>>> h = 64
>>> w = 16
>>> qkv = torch.randn(4, (3 * h * c), w)
>>> out = attn(qkv)
>>> out.shape
torch.Size([4, 512, 16])
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""Apply QKV attention.
Arguments
---------
qkv: torch.Tensor
an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
Returns
-------
result: torch.Tensor
an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum(
"bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length)
def build_emb_proj(emb_config, proj_dim=None, use_emb=None):
"""Builds a dictionary of embedding modules for embedding
projections
Arguments
---------
emb_config: dict
a configuration dictionary
proj_dim: int
the target projection dimension
use_emb: dict
an optional dictionary of "switches" to turn
embeddings on and off
Returns
-------
result: torch.nn.ModuleDict
a ModuleDict with a module for each embedding
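Example
-------
A minimal sketch with two hypothetical embedding entries:
>>> proj = build_emb_proj(
...     {"speaker": {"emb_dim": 128}, "label": {"emb_dim": 16}},
...     proj_dim=64,
... )
>>> sorted(proj.keys())
['label', 'speaker']
>>> proj["speaker"](torch.randn(4, 128)).shape
torch.Size([4, 64])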
"""
emb_proj = {}
if emb_config is not None:
for key, item_config in emb_config.items():
if use_emb is None or use_emb.get(key):
if "emb_proj" in item_config:
emb_proj[key] = item_config["emb_proj"]
else:
emb_proj[key] = EmbeddingProjection(
emb_dim=item_config["emb_dim"], proj_dim=proj_dim
)
return nn.ModuleDict(emb_proj)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
Arguments
---------
in_channels: int
channels in the input Tensor.
model_channels: int
base channel count for the model.
out_channels: int
channels in the output Tensor.
num_res_blocks: int
number of residual blocks per downsample.
attention_resolutions: list
a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
dropout: float
the dropout probability.
channel_mult: tuple
channel multiplier for each level of the UNet.
conv_resample: bool
if True, use learned convolutions for upsampling and
downsampling
emb_dim: int
time embedding dimension (defaults to model_channels * 4)
cond_emb: dict
embeddings on which the model will be conditioned
Example:
{
"speaker": {
"emb_dim": 256
},
"label": {
"emb_dim": 12
}
}
use_cond_emb: dict
a dictionary with keys corresponding to keys in cond_emb
and values corresponding to Booleans that turn embeddings
on and off. This is useful in combination with hparams files
to turn embeddings on and off with simple switches
Example:
{"speaker": False, "label": True}
dims: int
determines if the signal is 1D, 2D, or 3D.
num_heads: int
the number of attention heads in each attention layer.
num_head_channels: int
if specified, ignore num_heads and instead use
a fixed channel width per attention head.
num_heads_upsample: int
works with num_heads to set a different number
of heads for upsampling. Deprecated.
norm_num_groups: int
the number of groups used for group normalization
resblock_updown: bool
use residual blocks for up/downsampling.
use_fixup_init: bool
whether to use FixUp initialization
Example
-------
>>> model = UNetModel(
... in_channels=3,
... model_channels=32,
... out_channels=1,
... num_res_blocks=1,
... attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 16, 32])
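A sketch of conditioning the model on an extra embedding (the
"speaker" key below is only an illustration):
>>> model_cond = UNetModel(
...     in_channels=3,
...     model_channels=32,
...     out_channels=1,
...     num_res_blocks=1,
...     attention_resolutions=[1],
...     cond_emb={"speaker": {"emb_dim": 16}}
... )
>>> spk_emb = torch.randn(4, 16)
>>> out = model_cond(x, ts, cond_emb={"speaker": spk_emb})
>>> out.shape
torch.Size([4, 1, 16, 32])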
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
emb_dim=None,
cond_emb=None,
use_cond_emb=None,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
norm_num_groups=32,
resblock_updown=False,
use_fixup_init=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.dtype = torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.cond_emb = cond_emb
self.use_cond_emb = use_cond_emb
if emb_dim is None:
emb_dim = model_channels * 4
self.time_embed = EmbeddingProjection(model_channels, emb_dim)
self.cond_emb_proj = build_emb_proj(
emb_config=cond_emb, proj_dim=emb_dim, use_emb=use_cond_emb
)
ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, ch, 3, padding=1)
)
]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
emb_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
out_channels=out_ch,
dims=dims,
down=True,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
emb_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
emb_dim,
dropout,
out_channels=out_ch,
dims=dims,
up=True,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
if resblock_updown
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(norm_num_groups, ch),
nn.SiLU(),
fixup(
conv_nd(dims, input_ch, out_channels, 3, padding=1),
use_fixup_init=use_fixup_init,
),
)
def forward(self, x, timesteps, cond_emb=None, **kwargs):
"""Apply the model to an input batch.
Arguments
---------
x: torch.Tensor
an [N x C x ...] Tensor of inputs.
timesteps: torch.Tensor
a 1-D batch of timesteps.
cond_emb: dict
a string -> tensor dictionary of conditional
embeddings (multiple embeddings are supported)
Returns
-------
result: torch.Tensor
an [N x C x ...] Tensor of outputs.
"""
hs = []
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
if cond_emb is not None:
for key, value in cond_emb.items():
emb_proj = self.cond_emb_proj[key](value)
emb += emb_proj
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h)
class EncoderUNetModel(nn.Module):
"""
The downsampling (encoder) half of the UNet model, with attention and timestep embedding.
For usage, see UNetModel.
Arguments
---------
in_channels: int
channels in the input Tensor.
model_channels: int
base channel count for the model.
out_channels: int
channels in the output Tensor.
num_res_blocks: int
number of residual blocks per downsample.
attention_resolutions: list
a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
dropout: float
the dropout probability.
channel_mult: tuple
channel multiplier for each level of the UNet.
conv_resample: bool
if True, use learned convolutions for upsampling and
downsampling
dims: int
determines if the signal is 1D, 2D, or 3D.
num_heads: int
the number of attention heads in each attention layer.
num_head_channels: int
if specified, ignore num_heads and instead use
a fixed channel width per attention head.
num_heads_upsample: int
works with num_heads to set a different number
of heads for upsampling. Deprecated.
resblock_updown: bool
use residual blocks for up/downsampling.
use_fixup_init: bool
whether to use FixUp initialization
out_kernel_size: int
the kernel size of the output convolution
norm_num_groups: int
the number of groups used for group normalization
pool: str
the pooling strategy: None, "adaptive", "attention",
"spatial" or "spatial_v2"
attention_pool_dim: int
the input spatial dimension used to size the attention
pooling layer (only used when pool="attention")
Example
-------
>>> model = EncoderUNetModel(
... in_channels=3,
... model_channels=32,
... out_channels=1,
... num_res_blocks=1,
... attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 2, 4])
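A sketch of the "adaptive" pooling head, which collapses the
spatial dimensions into a flat vector (illustrative):
>>> model_pool = EncoderUNetModel(
...     in_channels=3,
...     model_channels=32,
...     out_channels=16,
...     num_res_blocks=1,
...     attention_resolutions=[1],
...     pool="adaptive",
... )
>>> out = model_pool(x, ts)
>>> out.shape
torch.Size([4, 16])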
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
norm_num_groups=32,
resblock_updown=False,
pool=None,
attention_pool_dim=None,
out_kernel_size=3,
use_fixup_init=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.dtype = torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.out_kernel_size = out_kernel_size
emb_dim = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(model_channels, emb_dim),
nn.SiLU(),
nn.Linear(emb_dim, emb_dim),
)
ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, ch, 3, padding=1)
)
]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
emb_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
out_channels=out_ch,
dims=dims,
down=True,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
)
self._feature_size += ch
self.pool = pool
self.spatial_pooling = False
if pool is None:
self.out = nn.Sequential(
nn.GroupNorm(
num_channels=ch, num_groups=norm_num_groups, eps=1e-6
),
nn.SiLU(),
conv_nd(
dims,
ch,
out_channels,
kernel_size=out_kernel_size,
padding="same",
),
)
elif pool == "adaptive":
self.out = nn.Sequential(
nn.GroupNorm(norm_num_groups, ch),
nn.SiLU(),
nn.AdaptiveAvgPool2d((1, 1)),
fixup(
conv_nd(dims, ch, out_channels, 1),
use_fixup_init=use_fixup_init,
),
nn.Flatten(),
)
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
nn.GroupNorm(norm_num_groups, ch),
nn.SiLU(),
AttentionPool2d(
attention_pool_dim // ds,
ch,
num_head_channels,
out_channels,
),
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
self.spatial_pooling = True
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.GroupNorm(norm_num_groups, 2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
)
self.spatial_pooling = True
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
def forward(self, x, timesteps=None):
"""
Apply the model to an input batch.
Arguments
---------
x: torch.Tensor
an [N x C x ...] Tensor of inputs.
timesteps: torch.Tensor
a 1-D batch of timesteps.
Returns
-------
result: torch.Tensor
an [N x K] Tensor of outputs if pooling is used,
or an [N x C x ...] Tensor otherwise.
"""
emb = None
if timesteps is not None:
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.spatial_pooling:
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.spatial_pooling:
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = torch.cat(results, dim=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)
class EmbeddingProjection(nn.Module):
"""A simple module that computes the projection of an
embedding vector onto the specified number of dimensions
Arguments
---------
emb_dim: int
the original embedding dimensionality
proj_dim: int
the dimensionality of the target projection
space
Example
-------
>>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64)
>>> emb = torch.randn(4, 16)
>>> emb_proj = mod_emb_proj(emb)
>>> emb_proj.shape
torch.Size([4, 64])
"""
def __init__(self, emb_dim, proj_dim):
super().__init__()
self.emb_dim = emb_dim
self.proj_dim = proj_dim
self.input = nn.Linear(emb_dim, proj_dim)
self.act = nn.SiLU()
self.output = nn.Linear(proj_dim, proj_dim)
def forward(self, emb):
"""Computes the forward pass
Arguments
---------
emb: torch.Tensor
the original embedding tensor
Returns
-------
result: torch.Tensor
the embedding projected into the target space
"""
x = self.input(emb)
x = self.act(x)
x = self.output(x)
return x
class DecoderUNetModel(nn.Module):
"""
The upsampling (decoder) half of the UNet model, with attention and timestep embedding.
For usage, see UNetModel.
Arguments
---------
in_channels: int
channels in the input Tensor.
model_channels: int
base channel count for the model.
out_channels: int
channels in the output Tensor.
num_res_blocks: int
number of residual blocks per downsample.
attention_resolutions: list
a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
dropout: float
the dropout probability.
channel_mult: tuple
channel multiplier for each level of the UNet.
conv_resample: bool
if True, use learned convolutions for upsampling and
downsampling
dims: int
determines if the signal is 1D, 2D, or 3D.
num_heads: int
the number of attention heads in each attention layer.
num_head_channels: int
if specified, ignore num_heads and instead use
a fixed channel width per attention head.
num_heads_upsample: int
works with num_heads to set a different number
of heads for upsampling. Deprecated.
resblock_updown: bool
use residual blocks for up/downsampling.
norm_num_groups: int
the number of groups used for group normalization
out_kernel_size: int
the kernel size of the output convolution
use_fixup_init: bool
whether to use FixUp initialization
Example
-------
>>> model = DecoderUNetModel(
... in_channels=1,
... model_channels=32,
... out_channels=3,
... num_res_blocks=1,
... attention_resolutions=[1]
... )
>>> x = torch.randn(4, 1, 2, 4)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 3, 16, 32])
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
resblock_updown=False,
norm_num_groups=32,
out_kernel_size=3,
use_fixup_init=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.dtype = torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
emb_dim = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(model_channels, emb_dim),
nn.SiLU(),
nn.Linear(emb_dim, emb_dim),
)
ch = int(channel_mult[0] * model_channels)
self.input_block = TimestepEmbedSequential(
conv_nd(dims, in_channels, ch, 3, padding=1)
)
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
ResBlock(
ch,
emb_dim,
dropout,
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
),
)
self.upsample_blocks = nn.ModuleList()
self._feature_size = ch
ds = 1
for level, mult in enumerate(reversed(channel_mult)):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
emb_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
)
self.upsample_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
if level != len(channel_mult) - 1:
out_ch = ch
self.upsample_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
emb_dim,
dropout,
out_channels=out_ch,
dims=dims,
up=True,
norm_num_groups=norm_num_groups,
use_fixup_init=use_fixup_init,
)
if resblock_updown
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
ds *= 2
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(num_channels=ch, num_groups=norm_num_groups, eps=1e-6),
nn.SiLU(),
conv_nd(
dims,
ch,
out_channels,
kernel_size=out_kernel_size,
padding="same",
),
)
self._feature_size += ch
def forward(self, x, timesteps=None):
"""
Apply the model to an input batch.
Arguments
---------
x: torch.Tensor
an [N x C x ...] Tensor of inputs.
timesteps: torch.Tensor
a 1-D batch of timesteps.
Returns
-------
result: torch.Tensor
an [N x C x ...] Tensor of outputs.
"""
emb = None
if timesteps is not None:
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
h = x.type(self.dtype)
h = self.input_block(h, emb)
h = self.middle_block(h, emb)
for module in self.upsample_blocks:
h = module(h, emb)
h = self.out(h)
return h
DEFAULT_PADDING_DIMS = [2, 3]
class DownsamplingPadding(nn.Module):
"""A wrapper module that applies the necessary padding for
the downsampling factor
Arguments
---------
factor: int
the downsampling / divisibility factor
len_dim: int
the index of the dimension in which the length will vary
dims: list
the list of dimensions to be included in padding
Example
-------
>>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1)
>>> x = torch.randn(4, 7, 14)
>>> length = torch.tensor([1., 0.8, 1., 0.7])
>>> x, length_new = padding(x, length)
>>> x.shape
torch.Size([4, 8, 16])
>>> length_new
tensor([0.8750, 0.7000, 0.8750, 0.6125])
"""
def __init__(self, factor, len_dim=2, dims=None):
super().__init__()
self.factor = factor
self.len_dim = len_dim
if dims is None:
dims = DEFAULT_PADDING_DIMS
self.dims = dims
def forward(self, x, length=None):
"""Applies the padding
Arguments
---------
x: torch.Tensor
the sample
length: torch.Tensor
the length tensor
Returns
-------
x_pad: torch.Tensor
the padded tensor
lens: torch.Tensor
the new, adjusted lengths, if applicable
"""
updated_length = length
for dim in self.dims:
# TODO: Consider expanding pad_divisible to support multiple dimensions
x, length_pad = pad_divisible(x, length, self.factor, len_dim=dim)
if dim == self.len_dim:
updated_length = length_pad
return x, updated_length
class UNetNormalizingAutoencoder(NormalizingAutoencoder):
"""A convenience class for a UNet-based Variational Autoencoder (VAE) -
useful in constructing Latent Diffusion models
Arguments
---------
in_channels: int
the number of input channels
model_channels: int
the number of channels in the convolutional layers of the
UNet encoder and decoder
encoder_out_channels: int
the number of channels the encoder will output
latent_channels: int
the number of channels in the latent space
encoder_num_res_blocks: int
the number of residual blocks in the encoder
encoder_attention_resolutions: list
the resolutions at which to apply attention layers in the encoder
decoder_num_res_blocks: int
the number of residual blocks in the decoder
decoder_attention_resolutions: list
the resolutions at which to apply attention layers in the decoder
dropout: float
the dropout probability
channel_mult: tuple
channel multipliers for each layer
dims: int
the convolution dimension to use (1, 2 or 3)
num_heads: int
the number of attention heads
num_head_channels: int
the number of channels in attention heads
num_heads_upsample: int
the number of upsampling heads
resblock_updown: bool
whether to use residual blocks for upsampling and downsampling
out_kernel_size: int
the kernel size for output convolution layers (if applicable)
use_fixup_norm: bool
whether to use FixUp normalization
norm_num_groups: int
the number of groups for group normalization
len_dim: int
the index of the length dimension
out_mask_value: float
the mask value used for the outputs
latent_mask_value: float
the mask value used for the latent space
downsampling_padding: int
the divisibility factor used for downsampling padding
(defaults to 2 ** len(channel_mult))
Example
-------
>>> unet_ae = UNetNormalizingAutoencoder(
... in_channels=1,
... model_channels=4,
... encoder_out_channels=16,
... latent_channels=3,
... encoder_num_res_blocks=1,
... encoder_attention_resolutions=[],
... decoder_num_res_blocks=1,
... decoder_attention_resolutions=[],
... norm_num_groups=2,
... )
>>> x = torch.randn(4, 1, 32, 32)
>>> x_enc = unet_ae.encode(x)
>>> x_enc.shape
torch.Size([4, 3, 4, 4])
>>> x_dec = unet_ae.decode(x_enc)
>>> x_dec.shape
torch.Size([4, 1, 32, 32])
"""
def __init__(
self,
in_channels,
model_channels,
encoder_out_channels,
latent_channels,
encoder_num_res_blocks,
encoder_attention_resolutions,
decoder_num_res_blocks,
decoder_attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
dims=2,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
norm_num_groups=32,
resblock_updown=False,
out_kernel_size=3,
len_dim=2,
out_mask_value=0.0,
latent_mask_value=0.0,
use_fixup_norm=False,
downsampling_padding=None,
):
encoder_unet = EncoderUNetModel(
in_channels=in_channels,
model_channels=model_channels,
out_channels=encoder_out_channels,
num_res_blocks=encoder_num_res_blocks,
attention_resolutions=encoder_attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
dims=dims,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
norm_num_groups=norm_num_groups,
resblock_updown=resblock_updown,
out_kernel_size=out_kernel_size,
use_fixup_init=use_fixup_norm,
)
encoder = nn.Sequential(
encoder_unet,
conv_nd(
dims=dims,
in_channels=encoder_out_channels,
out_channels=latent_channels,
kernel_size=1,
),
)
if downsampling_padding is None:
downsampling_padding = 2 ** len(channel_mult)
encoder_pad = DownsamplingPadding(downsampling_padding)
decoder = DecoderUNetModel(
in_channels=latent_channels,
out_channels=in_channels,
model_channels=model_channels,
num_res_blocks=decoder_num_res_blocks,
attention_resolutions=decoder_attention_resolutions,
dropout=dropout,
channel_mult=list(channel_mult),
dims=dims,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
norm_num_groups=norm_num_groups,
resblock_updown=resblock_updown,
out_kernel_size=out_kernel_size,
use_fixup_init=use_fixup_norm,
)
super().__init__(
encoder=encoder,
latent_padding=encoder_pad,
decoder=decoder,
len_dim=len_dim,
out_mask_value=out_mask_value,
latent_mask_value=latent_mask_value,
)