API Reference
Pipelines
Pipeline for STADiffuser model
- stadiffuser.pipeline.get_recon(adata, autoencoder, use_net='spatial_net', apply_normalize=True, use_rep='latent', use_spatial='spatial', batch_mode=False, batch_size=256, num_neighbors=[5, 3], inplace=True, device='cuda:0', show_progress=True)
Get the reconstructed expression matrix and latent representation with the trained autoencoder.
- Parameters:
adata (AnnData) – The AnnData object.
autoencoder (torch.nn.Module) – The trained autoencoder model.
use_net (str) – The key of the network in adata.uns.
apply_normalize (bool) – Whether to normalize the data. Normalization is recommended for count data.
use_rep (str) – The key of the representation in adata.obsm.
use_spatial (str) – The key of the spatial coordinates in adata.obsm.
batch_mode (bool) – Whether to use batch mode. Batch mode is recommended for large datasets.
batch_size (int) – The batch size in batch mode to compute the latent representation.
num_neighbors (list) – The numbers of neighbors used to compute the latent representation, given as a list of integers.
inplace (bool) – Whether to modify the adata object in place.
device (str) – The device on which to compute the latent representation. Defaults to “cuda:0”.
show_progress (bool) – Whether to show the progress bar.
- Returns:
adata_recon – The AnnData object with the reconstructed expression matrix and latent representation.
- Return type:
AnnData
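A hedged usage sketch for get_recon, built only from the signature documented above. The helper name `reconstruct_large` and the constant `RECON_KWARGS` are hypothetical. `adata` is assumed to be an AnnData object that already holds a spatial network under `spatial_net`, and `autoencoder` is assumed to be already trained. The import is deferred so the snippet can be read and loaded without stadiffuser installed.

```python
# Keyword arguments copied from the documented signature of get_recon.
# batch_mode is switched on because the reference recommends it for large datasets.
RECON_KWARGS = dict(
    use_net="spatial_net",
    apply_normalize=True,
    use_rep="latent",
    use_spatial="spatial",
    batch_mode=True,
    batch_size=256,
    num_neighbors=[5, 3],
    inplace=True,
    device="cuda:0",
    show_progress=True,
)


def reconstruct_large(adata, autoencoder, **overrides):
    """Hypothetical helper: call get_recon in batch mode.

    Assumes `adata` already contains the spatial network and that
    `autoencoder` is a trained torch.nn.Module.
    """
    from stadiffuser import pipeline  # deferred import; requires stadiffuser

    kwargs = {**RECON_KWARGS, **overrides}
    return pipeline.get_recon(adata, autoencoder, **kwargs)
```

On a machine without a GPU, the same helper could be called with `device="cpu"` as an override.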
- stadiffuser.pipeline.prepare_dataset(adata: anndata.AnnData, use_rep=None, use_spatial='spatial', use_net='spatial_net', use_label=None, device='cpu')
Convert adata to a PyTorch Geometric dataset. This function is used for training the autoencoder model. Side effect: adds edge_list to adata.uns.
- Parameters:
adata (AnnData) – The AnnData object. adata.obsm[“spatial”] must exist.
factor (float) – The factor to divide the spatial coordinates.
used_mask (str) – The key of the mask in adata.obs.
use_rep (str) – The key of the expression matrix in adata.obsm. If None, use adata.X.
use_spatial (str) – The key of the spatial coordinates in adata.obsm.
use_net (str) – The key of the network in adata.uns.
use_label (str) – The key of the label in adata.obs.
device (str) – “cpu” or “cuda”.
- stadiffuser.pipeline.pretrain_autoencoder_multi(train_loaders, model, pretrain_epochs=100, lr=0.0001, weight_decay=1e-06, save_dir=None, model_name='autoencoder_pre', check_points=None, device='cpu')
Pretrain the autoencoder model on each slice.
- stadiffuser.pipeline.remove_edge(G, is_masked)
Remove edges in the graph.
- Parameters:
G – The graph in scipy.sparse.coo_matrix format.
is_masked – The mask of edges to be removed.
- Returns:
G – The new graph in scipy.sparse.coo_matrix format.
- Return type:
scipy.sparse.coo_matrix
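The edge removal described for remove_edge can be illustrated with a dependency-free sketch. It is not the library's implementation: the graph is represented by plain COO triplets (row, column, value lists) instead of a scipy.sparse.coo_matrix, the function name `remove_masked_edges` is hypothetical, and the mask is assumed to hold one boolean per stored edge, which the reference does not state explicitly.

```python
def remove_masked_edges(rows, cols, vals, is_masked):
    """Drop every COO entry whose mask flag is True.

    Assumption: is_masked has one boolean per stored edge,
    aligned with rows, cols and vals.
    """
    if not (len(rows) == len(cols) == len(vals) == len(is_masked)):
        raise ValueError("rows, cols, vals and is_masked must have equal length")
    kept = [i for i, masked in enumerate(is_masked) if not masked]
    return (
        [rows[i] for i in kept],
        [cols[i] for i in kept],
        [vals[i] for i in kept],
    )


# Four edges; the second and fourth are flagged for removal.
rows, cols, vals = [0, 0, 1, 2], [1, 2, 2, 0], [1.0, 1.0, 1.0, 1.0]
is_masked = [False, True, False, True]
new_rows, new_cols, new_vals = remove_masked_edges(rows, cols, vals, is_masked)
```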
- stadiffuser.pipeline.simulate(denoiser: torch.nn.Module | None = None, autoencoder: torch.nn.Module | None = None, ref_data=None, spatial_coord: ndarray | None = None, labels: ndarray | None = None, noise_scheduler: diffusers.SchedulerMixin | None = None, init_x: torch.Tensor | None = None, seed=None, progress=True, n_samples: int = 1, return_tuple=True, normarlizer=None, add_rep='latent', use_spatial='spatial', device='cpu', **kwargs)
Simulate the diffusion process of the model.
- Parameters:
denoiser – The denoising model.
spatial_coord – The spatial coordinates of the spots, of shape (n_spots, 2) or (n_spots, 3).
labels – Optional. The labels of the spots when the model is conditional.
noise_scheduler – The noise scheduler.
seed – The random seed controlling torch and numpy.
progress – Whether to show the progress bar.
n_samples (int) – The number of samples to simulate.
return_tuple (bool) – Whether to return a tuple of tensors.
device (str) – “cpu” or “cuda”.
**kwargs – Other arguments to pass to the model.
- Returns:
denoised – The denoised expression matrix, of shape (n_spots, n_samples * p) if return_tuple is False.
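The documented return shape (n_spots, n_samples * p) suggests that the simulated samples are laid side by side along the feature axis. The sketch below illustrates that layout with plain nested lists. The function name `concat_samples` is hypothetical, and the ordering of samples within each row is an assumption, since the reference only gives the shape.

```python
def concat_samples(samples):
    """Concatenate n_samples matrices of shape (n_spots, p) along the feature axis.

    Assumption: sample k occupies columns k*p .. (k+1)*p - 1 of the result.
    """
    n_spots = len(samples[0])
    if any(len(sample) != n_spots for sample in samples):
        raise ValueError("every sample must have the same number of spots")
    return [
        [value for sample in samples for value in sample[spot]]
        for spot in range(n_spots)
    ]


# Two samples, each with 3 spots and p = 2 features.
sample_a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
sample_b = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
combined = concat_samples([sample_a, sample_b])  # 3 rows, 4 columns
```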
- stadiffuser.pipeline.train_autoencoder(train_loader, model, n_epochs=1000, gradient_clip=5, lr=0.0001, weight_decay=1e-06, save_dir=None, model_name='autoencoder', device='cpu', check_points=None)
Train the autoencoder model.
- Parameters:
train_loader (torch_geometric.loader or torch_geometric.data.DataLoader) – The data loader for training.
model (torch.nn.Module) – The autoencoder model.
n_epochs (int) – The number of epochs to train.
gradient_clip (float) – The gradient clip.
lr (float) – The learning rate.
weight_decay (float) – The weight decay.
save_dir (str) – The directory to save the model. If None, do not save the model.
model_name (str) – The name of the model.
device (str) – The device to train the model.
check_points (list) – The epochs at which to save the model, as a list of integers.
- Returns:
model (torch.nn.Module) – The trained model.
loss_list (list) – The list of loss during training.
- stadiffuser.pipeline.train_autoencoder_multi(adata, model, use_batch=None, batch_list=None, n_epochs=400, lr=0.0001, weight_decay=1e-06, margin=1.0, update_interval=50, mnn_neighbors=15, save_dir=None, model_name='autoencoder_tri', device='cpu', check_points=None)
Train the autoencoder model with triplet loss using multiple slices.
- Parameters:
adata (AnnData) – The AnnData object.
model (torch.nn.Module) – The autoencoder model.
use_batch (str) – The key of the batch in adata.obs.
batch_list (list) – The list of batch names.
n_epochs (int) – The number of epochs to train.
lr (float) – The learning rate.
weight_decay (float) – The weight decay.
margin (float) – The margin of the triplet loss. Default is 1.0.
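To make the role of margin concrete, here is a minimal sketch of the standard triplet margin loss for a single triplet. It assumes Euclidean distance between embeddings. The reference does not say which distance train_autoencoder_multi uses or how triplets are averaged, so those details, and the function name `triplet_margin_loss`, are illustrative only.

```python
import math


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Standard triplet margin loss: max(d(a, p) - d(a, n) + margin, 0).

    Assumption: d is the Euclidean distance between embedding vectors.
    """
    d_pos = math.dist(anchor, positive)
    d_neg = math.dist(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)


# Positive is close and negative is far: the margin is already satisfied.
easy = triplet_margin_loss([0.0, 0.0], [0.0, 1.0], [3.0, 4.0])
# Positive is far and negative is close: the triplet is penalized.
hard = triplet_margin_loss([0.0, 0.0], [3.0, 4.0], [0.0, 1.0])
```

A larger margin asks the negative to sit further from the anchor than the positive before the loss reaches zero.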
- stadiffuser.pipeline.train_denoiser(train_loader, model, noise_scheduler: diffusers.SchedulerMixin, n_epochs: int = 1000, lr: float = 0.0001, lr_scheduler: str = 'cosine_annealing', weight_decay: float = 1e-06, gradient_clip: float = 5, num_class_embeds: int | None = None, save_dir: str | None = None, evaluate_interval: int | None = None, model_name: str = 'denoiser', device: str = 'cpu', check_points=None, eval_fn=None, eval_kwargs=None)
Train the denoising model with the noise scheduler.
- Parameters:
train_loader (torch_geometric.loader or torch_geometric.data.DataLoader) – The data loader for training.
model (torch.nn.Module) – The denoising model.
noise_scheduler (SchedulerMixin) – The noise scheduler.
n_epochs (int) – The number of epochs to train.
lr (float) – The learning rate.
lr_scheduler (str) – The learning rate scheduler. Only “cosine_annealing” is supported.
weight_decay (float) – The weight decay.
gradient_clip (float) – The gradient clip.
num_class_embeds (int) – The number of class embeddings. If None, the model is unconditional.
- Returns:
model (torch.nn.Module) – The trained model.
loss_list (list) – The list of loss during training.
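The “cosine_annealing” option for lr_scheduler presumably follows the usual cosine annealing rule, in which the learning rate decays from its initial value to a minimum along half a cosine period. The sketch below evaluates that closed form. Whether train_denoiser steps the schedule per epoch or per batch, and which minimum rate it uses, is not stated in this reference, so `eta_min=0.0` and per-epoch stepping are assumptions, and `cosine_annealing_lr` is a hypothetical name.

```python
import math


def cosine_annealing_lr(base_lr, step, total_steps, eta_min=0.0):
    """Closed-form cosine annealing without restarts.

    lr(step) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * step / total_steps))
    """
    cosine = math.cos(math.pi * step / total_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


# Learning rate at the start, middle and end of a 1000-epoch run with lr=1e-4.
lr_start = cosine_annealing_lr(1e-4, 0, 1000)
lr_middle = cosine_annealing_lr(1e-4, 500, 1000)
lr_end = cosine_annealing_lr(1e-4, 1000, 1000)
```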
Models
Denoising network for STADiffuser
- class stadiffuser.models.SpaUNet1DModel(*args: Any, **kwargs: Any)
Bases: ModelMixin, ConfigMixin
The denoising network for STADiffuser, extending the UNet1DModel from diffusers with spatial processing.
This model is designed to process one-dimensional data, such as time series or signals, with an emphasis on spatial information. It includes features for spatial encoding and can be configured to use three-dimensional spatial concatenation for enhanced performance.
- Parameters:
sample_size (int, optional) – The default length of the input sample. Defaults to 32.
spatial_encoding (str, optional) – The type of spatial encoding to use. Defaults to “sinusoidal”.
sample_rate (int, optional) – The sample rate of the input data, if applicable.
in_channels (int, optional) – The number of channels in the input sample. Defaults to 2.
out_channels (int, optional) – The number of channels in the output sample. Defaults to 2.
extra_in_channels (int, optional) – Additional channels to append to the input. Defaults to 0.
time_embedding_type (str, optional) – The type of time embedding to use. Defaults to “fourier”.
flip_sin_to_cos (bool, optional) – Whether to convert sine to cosine for Fourier time embedding. Defaults to True.
use_timestep_embedding (bool, optional) – Whether to use timestep embedding. Defaults to False.
freq_shift (float, optional) – Frequency shift for Fourier time embedding. Defaults to 0.0.
down_block_types (Tuple[str], optional) – Tuple of downsample block types.
up_block_types (Tuple[str], optional) – Tuple of upsample block types.
mid_block_type (str, optional) – The block type for the middle of the UNet. Defaults to “UNetMidBlock1D”.
out_block_type (str, optional) – The optional output processing block type for the UNet.
block_out_channels (Tuple[int], optional) – Tuple of block output channels.
act_fn (str, optional) – The optional activation function to use in UNet blocks.
norm_num_groups (int, optional) – The number of groups for group normalization. Defaults to 8.
layers_per_block (int, optional) – The number of layers per block in the UNet. Defaults to 1.
downsample_each_block (bool, optional) – Whether to downsample in each block. Defaults to False.
spatial3d_concat (bool, optional) – Whether to use three-dimensional spatial concatenation.
class_embed_type (str, optional) – The type of class embedding to use.
num_class_embeds (int, optional) – The number of class embeddings.
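As an illustration of what a “sinusoidal” spatial encoding can look like, the sketch below applies the common Transformer-style sine and cosine features to each spatial coordinate. The exact frequencies, dimensions and layout used inside SpaUNet1DModel are not given in this reference, so `dim`, `max_period` and the function name `sinusoidal_encoding` are assumptions made for the example.

```python
import math


def sinusoidal_encoding(coord, dim=8, max_period=10000.0):
    """Encode a spatial coordinate vector with sine and cosine features.

    Each coordinate contributes `dim` values: dim/2 sines followed by
    dim/2 cosines at geometrically spaced frequencies.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [max_period ** (-k / half) for k in range(half)]
    encoded = []
    for x in coord:
        encoded.extend(math.sin(x * f) for f in freqs)
        encoded.extend(math.cos(x * f) for f in freqs)
    return encoded


# A 2D spot at the origin: all sines are 0 and all cosines are 1.
origin = sinusoidal_encoding([0.0, 0.0], dim=4)
# A 3D spot yields 3 * dim features.
spot_3d = sinusoidal_encoding([1.5, 2.5, 3.5], dim=8)
```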
- forward(sample: torch.FloatTensor, timestep: torch.Tensor | float | int, spatial_coord: torch.FloatTensor | None = None, class_labels: torch.Tensor | None = None, return_dict: bool = True) → StaDiffuserOutput | Tuple
Forward pass of the STADiffuser model.
- Parameters:
sample (torch.FloatTensor) – Input tensor of shape (batch_size, num_channels, sample_size) containing noisy inputs.
timestep (Union[torch.Tensor, float, int]) – Batched timestep values.
spatial_coord (torch.FloatTensor, optional) – Spatial coordinates tensor of shape (batch_size, 2). Defaults to None.
class_labels (torch.LongTensor, optional) – Class labels tensor of shape (batch_size). Defaults to None.
return_dict (bool, optional) – Whether to return a StaDiffuserOutput instead of a plain tuple. Defaults to True.
- Returns:
If return_dict is True, returns a StaDiffuserOutput instance, otherwise returns a tuple.
- Return type:
StaDiffuserOutput or Tuple
Dataset
This file contains the functions and classes for dataset construction and processing.
- class stadiffuser.dataset.MaskNode(data: torch_geometric.data.Data, mask: ndarray)
Bases: object
A class to mask node features in a graph and remove edges connected to masked nodes.
This class provides a way to mask specific nodes in a graph and remove the edges that are connected to these nodes. It keeps track of the masked nodes and the edges that have been removed, allowing for the creation of a modified graph data object with the masked nodes and remaining edges.
- Parameters:
data (torch_geometric.data.Data) – The input data containing the node features data.x and the edges data.edge_index.
mask (numpy.ndarray) – A boolean array indicating which nodes to mask. True for masked nodes and False for unmasked nodes.
- mask
The boolean mask indicating which nodes are masked.
- Type:
numpy.ndarray
- masked_nodes
The indices of the masked nodes.
- Type:
numpy.ndarray
- remained_nodes
The indices of the nodes that remain after masking.
- Type:
numpy.ndarray
- node_mapping
A dictionary mapping the indices of the remaining nodes to new consecutive indices.
- Type:
dict
- get_data(data)
Create a new data object with masked node features and removed edges.
This method applies the mask to the node features and removes edges connected to masked nodes. It returns a new Data object with the modified node features and edge indices.
- Parameters:
data (torch_geometric.data.Data) – The input data containing the node features data.x and the edges data.edge_index.
- Returns:
The modified data with masked node features and removed edges.
- Return type:
torch_geometric.data.Data
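The masking and re-indexing that MaskNode describes can be sketched without PyTorch Geometric. The version below works on plain lists and is only an illustration: it assumes that masked nodes are dropped from the feature list and that the remaining nodes are renumbered consecutively, whereas the library may keep masked nodes and blank their features instead. The function name `mask_graph` is hypothetical.

```python
def mask_graph(features, edges, mask):
    """Mask nodes, drop edges touching them, and renumber the rest.

    features: one feature list per node.
    edges: list of (source, target) node index pairs.
    mask: one boolean per node, True for masked nodes.
    """
    remained_nodes = [i for i, masked in enumerate(mask) if not masked]
    # Map each remaining node's old index to a new consecutive index.
    node_mapping = {old: new for new, old in enumerate(remained_nodes)}
    new_features = [features[i] for i in remained_nodes]
    new_edges = [
        (node_mapping[src], node_mapping[dst])
        for src, dst in edges
        if src in node_mapping and dst in node_mapping
    ]
    return new_features, new_edges, node_mapping


features = [[1.0], [2.0], [3.0], [4.0]]
edges = [(0, 1), (1, 2), (2, 3), (0, 3)]
mask = [False, True, False, False]  # node 1 is masked
new_features, new_edges, node_mapping = mask_graph(features, edges, mask)
```

In this example the two edges touching node 1 disappear, and nodes 2 and 3 become nodes 1 and 2.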
- class stadiffuser.dataset.TripletSampler(adata: scanpy.AnnData, target: str | None = None, reference: str | None = None, use_batch: str | None = None, use_rep: str | None = None, num_neighbors: int = 30, random_seed: int = 0)
Bases: object
Construct triplet data for training from batch-disjoint AnnData and PyG Data objects.
This class creates triplets consisting of an anchor node, a positive node (nearest neighbor from a specified reference batch), and a negative node (randomly chosen from the target batch). It uses a specified representation from the AnnData object to find nearest neighbors and construct the triplets.
- Parameters:
adata (anndata.AnnData) – The AnnData object containing the original data of multiple batches.
target (str, optional) – The target batch id for anchor and negative nodes. Defaults to None.
reference (str, optional) – The reference batch id for positive nodes. Defaults to None.
use_batch (str, optional) – The key in adata.obs that stores the batch ids. Defaults to None.
use_rep (str, optional) – The key for the representation in adata.obsm to be used for nearest neighbor search. Defaults to None, in which case adata.X is used.
num_neighbors (int, optional) – The number of nearest neighbors to be used for triplet construction. Defaults to 30.
random_seed (int, optional) – The random seed for numpy’s random number generator. Defaults to 0.
- target_indices
The indices of nodes in the target batch.
- Type:
numpy.ndarray
- reference_indices
The indices of nodes in the reference batch.
- Type:
numpy.ndarray
- rng
The random number generator for sampling negative nodes.
- Type:
numpy.random.RandomState
Examples
>>> sampler = TripletSampler(adata, target='batch1', reference='batch2', use_batch='batch', use_rep='X_pca')
>>> anchor_indices, positive_indices, negative_indices = sampler.query(anchor_indices)
- query(anchor_indices)
Query the positive and negative indices for the given anchor indices.
For each anchor index, this method finds the corresponding positive indices (nearest neighbors from the reference batch) and samples negative indices (randomly chosen from the target batch).
- Parameters:
anchor_indices (numpy.ndarray) – The indices of the anchor nodes.
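The triplet construction described for query can be illustrated with a small standard-library sketch. It is not TripletSampler's implementation: the brute-force neighbor search, the exclusion of the anchor from the negative candidates, and the names `sample_triplets` and `embeddings` are assumptions made for the example.

```python
import math
import random


def sample_triplets(embeddings, anchor_indices, reference_indices,
                    target_indices, num_neighbors=1, seed=0):
    """Build (anchor, positive, negative) index lists.

    Positives are the nearest reference-batch nodes to each anchor;
    negatives are drawn at random from the target batch.
    """
    rng = random.Random(seed)
    anchors, positives, negatives = [], [], []
    for a in anchor_indices:
        # Brute-force nearest neighbors in the reference batch.
        nearest = sorted(
            reference_indices,
            key=lambda r: math.dist(embeddings[a], embeddings[r]),
        )[:num_neighbors]
        # Assumption: the anchor itself is never used as its own negative.
        candidates = [t for t in target_indices if t != a]
        for p in nearest:
            anchors.append(a)
            positives.append(p)
            negatives.append(rng.choice(candidates))
    return anchors, positives, negatives


# Nodes 0 and 1 belong to the target batch, nodes 2 and 3 to the reference batch.
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.1, 0.0], [5.0, 5.0]]
anchors, positives, negatives = sample_triplets(
    embeddings, anchor_indices=[0, 1], reference_indices=[2, 3], target_indices=[0, 1]
)
```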
- stadiffuser.dataset.get_slice_loader(adata, data, batch_name, use_batch='slice_ID', batch_size=32, **kwargs)
Get the NeighborLoader for the slice batch.
- Parameters:
adata (anndata.AnnData) – The AnnData object containing the original data of multiple batches.
data (torch_geometric.data.Data) – The torch_geometric.data.Data object containing the node features data.x and the edges data.edge_index, as constructed by the prepare_dataset function in the stadiffuser.pipeline module.
batch_name (str) – The name of the slice batch in the use_batch column of the adata.obs.
use_batch (str) – The column name in the adata.obs to indicate the batch information.
batch_size (int) – The batch size for the NeighborLoader.
**kwargs – The additional arguments for the NeighborLoader.
- Returns:
train_loader – The NeighborLoader for the slice batch.
- Return type:
torch_geometric.loader.NeighborLoader
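A loader restricted to one slice needs the indices of the nodes that belong to that slice. The sketch below shows that selection step on a plain list standing in for the use_batch column of adata.obs. That get_slice_loader performs exactly this selection, and that the result is passed to NeighborLoader as its `input_nodes` argument, are assumptions; `slice_node_indices` is a hypothetical name.

```python
def slice_node_indices(batch_column, batch_name):
    """Return the positions of all nodes whose batch label equals batch_name."""
    return [i for i, value in enumerate(batch_column) if value == batch_name]


# Stand-in for adata.obs["slice_ID"] with three slices.
slice_ids = ["s1", "s2", "s1", "s3", "s2"]
nodes_s2 = slice_node_indices(slice_ids, "s2")
nodes_missing = slice_node_indices(slice_ids, "s9")
```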