"""
Denoising network for STADiffuser
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from diffusers.utils import BaseOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
from .modules import SpatialEncoding, SpatialEncoding3D
@dataclass
class StaDiffuserOutput(BaseOutput):
    """
    Output container returned by the STADiffuser denoising network.

    Attributes
    ----------
    sample : torch.Tensor
        The tensor produced by the network's forward pass.
    spatial_coord : torch.Tensor, optional
        The spatial coordinates that were supplied to the forward pass, passed through unchanged
        (``None`` when the caller did not provide any).
    """

    sample: torch.Tensor
    # Echo of the caller-provided coordinates; may be None because the forward pass accepts None.
    spatial_coord: Optional[torch.Tensor]
class SpaUNet1DModel(ModelMixin, ConfigMixin):
    """
    The denoising network for STADiffuser: a 1D UNet conditioned on spatial coordinates.

    The architecture is modelled on the ``UNet1DModel`` from ``diffusers``. On top of the timestep
    embedding it adds an embedding of the spatial coordinates and, optionally, a class embedding; the
    combined embedding is handed to every down/mid/up/out block as ``temb``.

    Parameters
    ----------
    sample_size : int, optional
        The default length of the input sample. Also used as the width of the spatial encoding and of
        the class embedding. Defaults to 32.
    spatial_encoding : str, optional
        The type of spatial encoding to use, ``"sinusoidal"`` or ``"sinusoidal3d"``.
        Defaults to "sinusoidal".
    sample_rate : int, optional
        The sample rate of the input data, if applicable. Not used by the code in this class.
    in_channels : int, optional
        The number of channels in the input sample. Defaults to 2.
    out_channels : int, optional
        The number of channels in the output sample. Defaults to 1.
    extra_in_channels : int, optional
        Additional channels added to the input channel count of the first down block. Defaults to 0.
    time_embedding_type : str, optional
        The type of time embedding to use, ``"fourier"`` or ``"positional"``. Defaults to "fourier".
    flip_sin_to_cos : bool, optional
        Forwarded to the time projection layer. Defaults to True.
    use_timestep_embedding : bool, optional
        Whether to pass the projected timestep through an additional ``TimestepEmbedding`` MLP.
        Defaults to False.
    freq_shift : float, optional
        Frequency shift forwarded to the ``"positional"`` time projection. Defaults to 0.0.
    down_block_types : Tuple[str], optional
        Tuple of downsample block types.
    up_block_types : Tuple[str], optional
        Tuple of upsample block types.
    mid_block_type : str, optional
        The block type for the middle of the UNet. Defaults to "UNetMidBlock1D".
    out_block_type : str, optional
        The optional output processing block type for the UNet. Defaults to None.
    block_out_channels : Tuple[int], optional
        Tuple of block output channels. Defaults to (32, 32, 64).
    act_fn : str, optional
        The optional activation function name forwarded to the timestep MLP and the output block.
    norm_num_groups : int, optional
        The number of groups for group normalization in the output block. Defaults to 8.
    layers_per_block : int, optional
        The number of layers per block in the UNet. Defaults to 1.
    downsample_each_block : bool, optional
        Whether to downsample in each block (including the last down block and the mid block).
        Defaults to False.
    spatial3d_concat : bool, optional
        Whether the second channel of the spatial encoding is concatenated to the embedding instead
        of being summed into it. Defaults to False.
    class_embed_type : str, optional
        The type of class embedding to use, ``"embedding"`` or ``"timestep"``. Only consulted when
        ``num_class_embeds`` is given. Defaults to "embedding".
    num_class_embeds : int, optional
        The number of class embeddings. When None, no class conditioning is built.

    Attributes
    ----------
    spa_encoder : nn.Module
        The spatial encoding module.
    time_proj : nn.Module
        The timestep projection layer.
    time_mlp : nn.Module
        The timestep MLP; only present when ``use_timestep_embedding`` is True.
    class_embedding : nn.Module or None
        The class embedding module, or None when class conditioning is disabled.
    down_blocks, up_blocks : nn.ModuleList
        The down- and up-sampling blocks of the UNet.
    mid_block, out_block : nn.Module or None
        The middle block and the optional output block.

    Raises
    ------
    NotImplementedError
        If ``spatial_encoding``, ``time_embedding_type`` or ``class_embed_type`` is not one of the
        supported values.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 32,
        spatial_encoding: str = "sinusoidal",
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 1,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: str = "UNetMidBlock1D",
        out_block_type: Optional[str] = None,
        block_out_channels: Tuple[int, ...] = (32, 32, 64),
        act_fn: Optional[str] = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
        spatial3d_concat: Optional[bool] = False,
        class_embed_type: Optional[str] = "embedding",
        num_class_embeds: Optional[int] = None,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.spatial3d_concat = spatial3d_concat

        # spatial encoding
        if spatial_encoding == "sinusoidal":
            self.spa_encoder = SpatialEncoding(sample_size)
        elif spatial_encoding == "sinusoidal3d":
            self.spa_encoder = SpatialEncoding3D(concat=spatial3d_concat, channels=sample_size)
        else:
            raise NotImplementedError(f"Unknown spatial encoding {spatial_encoding}")

        # time encoding
        if time_embedding_type == "fourier":
            fourier_embedding_size = 8
            self.time_proj = GaussianFourierProjection(
                embedding_size=fourier_embedding_size,
                set_W_to_weight=False,
                log=False,
                flip_sin_to_cos=flip_sin_to_cos,
            )
            # The projection concatenates sin and cos features, so its output width is twice the
            # embedding size (16 here), independent of block_out_channels.
            timestep_input_dim = 2 * fourier_embedding_size
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]
        else:
            # Fail at construction time rather than with a missing attribute in forward().
            raise NotImplementedError(f"Unknown time embedding type {time_embedding_type}")

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        # class embedding
        if num_class_embeds is not None:
            if class_embed_type == "embedding":
                self.class_embedding = nn.Embedding(num_class_embeds, sample_size)
            elif class_embed_type == "timestep":
                self.class_embedding = TimestepEmbedding(timestep_input_dim, sample_size)
            else:
                raise NotImplementedError(f"Unknown class embedding type {class_embed_type}")
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # Every block receives an embedding whose channel count is declared as block_out_channels[0].
        temb_channels = block_out_channels[0]
        last_block_index = len(block_out_channels) - 1

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            if i == 0:
                input_channel += extra_in_channels
            is_final_block = i == last_block_index
            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=temb_channels,
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        # Without an output block, the last up block must itself produce the final channel count.
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )
            is_final_block = i == last_block_index
            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=temb_channels,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        spatial_coord: Optional[torch.FloatTensor] = None,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[StaDiffuserOutput, Tuple]:
        r"""
        Forward pass of the STADiffuser model.

        Args:
            sample (torch.FloatTensor):
                Input tensor of shape `(batch_size, num_channels, sample_size)` containing noisy inputs.
            timestep (Union[torch.Tensor, float, int]):
                Timestep value(s). A Python scalar or a 0-d tensor is promoted to a 1-element tensor
                and broadcast over the batch.
            spatial_coord (torch.FloatTensor, optional):
                Spatial coordinates tensor fed to the spatial encoder; presumably of shape
                `(batch_size, 2)` (or 3 columns for the 3D encoder -- confirm against the encoder).
                When None, no spatial term is added. Defaults to None.
            class_labels (torch.Tensor, optional):
                Class labels tensor of shape `(batch_size)`. Required when the model was built with
                class conditioning. Defaults to None.
            return_dict (bool, optional):
                Whether to return a `StaDiffuserOutput` instead of a plain tuple. Defaults to True.

        Returns:
            StaDiffuserOutput or tuple: If `return_dict` is True, a `StaDiffuserOutput` holding the
            denoised sample and the `spatial_coord` that was passed in; otherwise the 1-tuple
            `(sample,)`.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is None, or if
                `spatial3d_concat` is enabled but `spatial_coord` is None.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            # NOTE(review): this path leaves the embedding 2-D (no length axis is added) -- confirm
            # that the downstream blocks and the spatial sum below are meant to handle that.
            timestep_embed = self.time_mlp(timestep_embed)
        else:
            # Add a trailing length axis, tile it to the sample length, then expand a single
            # timestep over the whole batch.
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])

        # 2. spatial
        if spatial_coord is None:
            if self.spatial3d_concat:
                # The concat path needs a real encoding to slice; fail with a clear message.
                raise ValueError("spatial_coord must be provided when spatial3d_concat is enabled")
            spatial_embed = 0
        else:
            spatial_embed = self.spa_encoder(spatial_coord)

        if self.spatial3d_concat:
            # Channel 0 of the encoding is summed with the time embedding; channel 1 is appended as
            # an extra embedding channel.
            emb = spatial_embed[:, [0], :] + timestep_embed
            z_emb = spatial_embed[:, [1], :]
            emb = torch.cat([emb, z_emb], dim=1)
        else:
            # spatial_embed (batch_size, 1, sample_size)
            # timestep_embed (batch_size, 16, sample_size)
            # the sum broadcasts spatial_embed to (batch_size, 16, sample_size)
            emb = spatial_embed + timestep_embed

        # 3. class
        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels must be provided if class_embedding is not None")
            if self.config.class_embed_type == "timestep":
                # Project the labels like timesteps first, then embed them with the learned module
                # (previously the module was skipped and the raw projection was added instead).
                class_labels = self.time_proj(class_labels)
            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            # Insert a channel axis so the class term broadcasts over the embedding channels.
            class_emb = torch.unsqueeze(class_emb, dim=1)
            emb = emb + class_emb

        # 4. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        # 5. mid
        if self.mid_block is not None:
            sample = self.mid_block(sample, emb)

        # 6. up -- each up block consumes the most recent remaining residual from the down path
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=emb)

        # 7. post-process
        if self.out_block is not None:
            sample = self.out_block(sample, emb)

        if not return_dict:
            return (sample,)
        return StaDiffuserOutput(sample=sample, spatial_coord=spatial_coord)