T4: 3D slice modeling of Drosophila embryo

[10]:
import urllib.request
import warnings

import numpy as np
import scanpy as sc
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from diffusers import DDPMScheduler
from torch_geometric.loader import NeighborLoader

from stadiffuser import pipeline
from stadiffuser.vae import SpaAE
from stadiffuser.models import SpaUNet1DModel
from stadiffuser import utils as sutils
from stadiffuser import metrics
from stadiffuser.dataset import get_slice_loader
warnings.filterwarnings("ignore")

Load data

[8]:
# The processed Stereo-seq dataset (~200 MB) is not fetched automatically.
# Download it manually from
# https://drive.google.com/file/d/1zyZKeZljbsEqo3YqVc_2-quU1Esm55E1/view?usp=drive_link
# and save it in the working directory as "adata_processed.h5ad".
adata_path = "adata_processed.h5ad"
adata = sc.read_h5ad(adata_path)
adata  # last expression: display the AnnData summary
[8]:
AnnData object with n_obs × n_vars = 14634 × 2000
    obs: 'slice_ID', 'raw_x', 'raw_y', 'new_x', 'new_y', 'new_z', 'annotation'
    var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg', 'log1p', 'spatial_net'
    obsm: 'X_umap', 'spatial'
    layers: 'raw_counts'
[9]:
# Count how many observations belong to each slice.
slice_ids = adata.obs["slice_ID"]
slice_ids.value_counts()
[9]:
slice_ID
E16-18h_a_S11    1193
E16-18h_a_S04    1189
E16-18h_a_S05    1181
E16-18h_a_S08    1131
E16-18h_a_S09    1113
E16-18h_a_S10    1111
E16-18h_a_S07    1096
E16-18h_a_S06    1076
E16-18h_a_S12    1049
E16-18h_a_S13    1022
E16-18h_a_S03    1021
E16-18h_a_S01     985
E16-18h_a_S02     965
E16-18h_a_S14     502
Name: count, dtype: int64
[13]:
# Build the spatial network across slices (batches are identified by "slice_ID")
# and store it under the "spatial_net" key.
# NOTE(review): rad_cutoff is presumably the neighbour search radius - confirm
# against the stadiffuser documentation.
adata = sutils.cal_spatial_net3D(
    adata,
    iter_comb=None,
    batch_id="slice_ID",
    rad_cutoff=1.4,
    add_key="spatial_net",
)

# Quantize a copy of the coordinates, one ("division", value) rule per axis,
# and keep the result in a separate obsm slot so the original "spatial" entry
# stays untouched.
quantize_rules = [("division", 0.8), ("division", 0.8), ("division", 0.35)]
new_spatial = sutils.quantize_coordination(
    adata.obsm["spatial"].copy(),
    methods=quantize_rules,
)
adata.obsm["new_spatial"] = new_spatial
------Calculating spatial network for each batch...
Calculating spatial network for batch E16-18h_a_S01...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3790 edges, 985 cells, 3.8477 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S02...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3718 edges, 965 cells, 3.8528 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S03...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3932 edges, 1021 cells, 3.8511 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S04...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4594 edges, 1189 cells, 3.8638 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S05...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4560 edges, 1181 cells, 3.8611 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S06...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4144 edges, 1076 cells, 3.8513 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S07...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4224 edges, 1096 cells, 3.8540 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S08...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4368 edges, 1131 cells, 3.8621 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S09...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4294 edges, 1113 cells, 3.8580 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S10...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4286 edges, 1111 cells, 3.8578 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S11...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4610 edges, 1193 cells, 3.8642 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S12...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4044 edges, 1049 cells, 3.8551 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S13...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3936 edges, 1022 cells, 3.8513 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S14...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 1910 edges, 502 cells, 3.8048 neighbors per cell on average.
------Calculating spatial bipartite network...
------Spatial network calculated.
Quantizing spatial coordinates...
Quantize 0th dimension of spatial coordinates to 1.25, mean deviation: 0.24953966933810226, pearson correlation: 0.9998366984519741
Quantize 1th dimension of spatial coordinates to 1.25, mean deviation: 0.24992933410836693, pearson correlation: 0.9994593611796051
Quantize 2th dimension of spatial coordinates to 2.857142857142857, mean deviation: 2.953008132466147e-16, pearson correlation: 1.0
[21]:
# Map every tissue annotation to an integer class id.
# numpy (np) is imported with the other dependencies in the notebook's first cell.
label_name = "annotation"

# np.unique returns the sorted unique labels; compute them once and reuse the
# result for both the class count and the label -> id mapping.
unique_labels = np.unique(adata.obs[label_name])
num_class_embeds = len(unique_labels)
class_dict = {label: idx for idx, label in enumerate(unique_labels)}
class_dict
[21]:
{'CNS': 0,
 'carcass': 1,
 'epidermis': 2,
 'fat body': 3,
 'foregut': 4,
 'hemolymph': 5,
 'midgut': 6,
 'muscle': 7,
 'salivary gland': 8,
 'trachea': 9}

Training autoencoder

[17]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# The autoencoder input width is the number of variables (columns) in adata.
n_input_features = adata.shape[1]
autoencoder = SpaAE(
    input_dim=n_input_features,
    block_list=["AttnBlock"],
    gat_dim=[512, 32],
    block_out_dims=[32, 32],
)

Pretraining on each slice

[22]:
# Slice identifiers in order of first appearance, and the dataset built from adata.
batch_list = adata.obs["slice_ID"].unique().tolist()
data = pipeline.prepare_dataset(adata, use_rep=None)

# One loader per slice, each with a batch size of 256.
train_loaders = []
for slice_id in batch_list:
    slice_loader = get_slice_loader(
        adata,
        data,
        slice_id,
        use_batch="slice_ID",
        batch_size=256,
    )
    train_loaders.append(slice_loader)

# Pretrain the autoencoder on all slice loaders for 200 epochs.
autoencoder, autoencoder_loss = pipeline.pretrain_autoencoder_multi(
    train_loaders,
    autoencoder,
    pretrain_epochs=200,
    device=device,
)

Training with triplet loss to align the spot/cell embeddings

[23]:
# Settings for the second training stage (the section heading describes it as
# triplet-loss alignment of the spot/cell embeddings across slices).
# NOTE(review): `margin` is presumably the triplet-loss margin and
# `update_interval` the number of epochs between triplet refreshes - confirm
# against the stadiffuser documentation.
align_kwargs = dict(
    use_batch="slice_ID",
    batch_list=batch_list,
    n_epochs=300,
    margin=1,
    lr=1e-4,
    update_interval=50,
    device=device,
)
autoencoder, autoencoder_loss = pipeline.train_autoencoder_multi(
    adata, autoencoder, **align_kwargs
)

Training the latent diffusion model

[20]:
# Encode the conditioning annotation as integer class ids.
# numpy (np) is imported with the other dependencies in the notebook's first cell.
cond_name = "annotation"

# np.unique returns the sorted unique labels; compute them once and reuse the
# result for both the class count and the label -> id mapping.
unique_labels = np.unique(adata.obs[cond_name])
num_class_embeds = len(unique_labels)
class_dict = {label: idx for idx, label in enumerate(unique_labels)}

# Store the integer id of every observation in a dedicated obs column.
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)
class_dict
[20]:
{'CNS': 0,
 'carcass': 1,
 'epidermis': 2,
 'fat body': 3,
 'foregut': 4,
 'hemolymph': 5,
 'midgut': 6,
 'muscle': 7,
 'salivary gland': 8,
 'trachea': 9}
[24]:
# Run the trained autoencoder over adata; the lines below read the resulting
# embedding from adata.obsm["latent"].
adata = pipeline.get_recon(
    adata,
    autoencoder,
    device=device,
    apply_normalize=False,
    show_progress=True,
    batch_mode=True,
)

# Fit a min-max normalizer on the latent embedding (dim=0) and store the
# normalized copy; `normalizer` is kept because the simulation step needs it.
latent = adata.obsm["latent"]
normalizer = sutils.MinMaxNormalize(latent, dim=0)
adata.obsm["normalized_latent"] = normalizer.normalize(latent)
[25]:
# For 3D slice modeling the denoiser takes 18 input channels:
# time embedding (16) + latent embedding (1) + z-axis embedding (concat mode).
denoiser = SpaUNet1DModel(
    in_channels=18,
    out_channels=1,
    spatial_encoding="sinusoidal3d",
    spatial3d_concat=True,
)
denoiser = denoiser.to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# (Re)write the integer class-id column used as the conditioning label.
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)

# Dataset built from the normalized latent, the quantized coordinates, the
# spatial network and the integer labels.
data_latent = pipeline.prepare_dataset(
    adata,
    use_rep="normalized_latent",
    use_spatial="new_spatial",
    use_net="spatial_net",
    use_label="label_",
)
train_loader = NeighborLoader(data_latent, num_neighbors=[5, 3], batch_size=256)

# Optimisation settings for the denoiser training run.
denoiser_train_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-6,
    n_epochs=500,
    num_class_embeds=num_class_embeds,
    device=device,
)
denoiser, denoise_loss = pipeline.train_denoiser(
    train_loader, denoiser, noise_scheduler, **denoiser_train_kwargs
)
-------------------Training Finished-------------------

Simulate a slice from the trained model

[26]:
# Dataset built from the spatial network and the quantized coordinates.
# NOTE(review): `data` is not passed to the simulate call below - confirm
# whether a later cell relies on it.
data = pipeline.prepare_dataset(adata, use_net="spatial_net", use_spatial="new_spatial")

# Inputs for the simulation: quantized coordinates and integer class labels.
sim_coords = adata.obsm["new_spatial"]
sim_labels = adata.obs["label_"].to_numpy()

# NOTE(review): the keyword is spelled `normarlizer` exactly as in the original
# call; presumably it matches the library signature - verify before renaming.
stadiff_sim = pipeline.simulate(
    denoiser,
    autoencoder,
    device=device,
    use_net="spatial_net",
    ref_data=adata,
    spatial_coord=sim_coords,
    labels=sim_labels,
    seed=2024,
    normarlizer=normalizer,
)
Simulate with labels