T4: 3D slice modeling of Drosophila embryo
[10]:
import scanpy as sc
import torch
import urllib.request
import warnings
import seaborn as sns
import matplotlib.pyplot as plt
from diffusers import DDPMScheduler
from torch_geometric.loader import NeighborLoader
from stadiffuser import pipeline
from stadiffuser.vae import SpaAE
from stadiffuser.models import SpaUNet1DModel
from stadiffuser import utils as sutils
from stadiffuser import metrics
from stadiffuser.dataset import get_slice_loader
warnings.filterwarnings("ignore")
Load data
[8]:
# Please manually download the file from https://drive.google.com/file/d/1zyZKeZljbsEqo3YqVc_2-quU1Esm55E1/view?usp=drive_link
# It is about 200 MB.
# Load the downloaded, processed Stereo-seq data.
adata = sc.read_h5ad("adata_processed.h5ad")
adata
[8]:
AnnData object with n_obs × n_vars = 14634 × 2000
obs: 'slice_ID', 'raw_x', 'raw_y', 'new_x', 'new_y', 'new_z', 'annotation'
var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
uns: 'hvg', 'log1p', 'spatial_net'
obsm: 'X_umap', 'spatial'
layers: 'raw_counts'
[9]:
adata.obs["slice_ID"].value_counts()
[9]:
slice_ID
E16-18h_a_S11 1193
E16-18h_a_S04 1189
E16-18h_a_S05 1181
E16-18h_a_S08 1131
E16-18h_a_S09 1113
E16-18h_a_S10 1111
E16-18h_a_S07 1096
E16-18h_a_S06 1076
E16-18h_a_S12 1049
E16-18h_a_S13 1022
E16-18h_a_S03 1021
E16-18h_a_S01 985
E16-18h_a_S02 965
E16-18h_a_S14 502
Name: count, dtype: int64
[13]:
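# Build the spatial network: one graph per slice (batches are given by "slice_ID"),
# followed by the bipartite network reported at the end of the log below.
adata = sutils.cal_spatial_net3D(adata, iter_comb=None, batch_id="slice_ID", rad_cutoff=1.4,
                                 add_key="spatial_net")
# Quantize the x, y and z coordinates and store them under a separate key.
new_spatial = adata.obsm["spatial"].copy()
new_spatial = sutils.quantize_coordination(new_spatial, methods=[("division", 0.8), ("division", 0.8), ("division", 0.35)])
adata.obsm["new_spatial"] = new_spatial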
------Calculating spatial network for each batch...
Calculating spatial network for batch E16-18h_a_S01...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3790 edges, 985 cells, 3.8477 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S02...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3718 edges, 965 cells, 3.8528 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S03...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3932 edges, 1021 cells, 3.8511 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S04...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4594 edges, 1189 cells, 3.8638 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S05...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4560 edges, 1181 cells, 3.8611 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S06...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4144 edges, 1076 cells, 3.8513 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S07...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4224 edges, 1096 cells, 3.8540 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S08...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4368 edges, 1131 cells, 3.8621 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S09...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4294 edges, 1113 cells, 3.8580 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S10...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4286 edges, 1111 cells, 3.8578 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S11...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4610 edges, 1193 cells, 3.8642 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S12...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4044 edges, 1049 cells, 3.8551 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S13...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3936 edges, 1022 cells, 3.8513 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S14...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 1910 edges, 502 cells, 3.8048 neighbors per cell on average.
------Calculating spatial bipartite network...
------Spatial network calculated.
Quantizing spatial coordinates...
Quantize 0th dimension of spatial coordinates to 1.25, mean deviation: 0.24953966933810226, pearson correlation: 0.9998366984519741
Quantize 1th dimension of spatial coordinates to 1.25, mean deviation: 0.24992933410836693, pearson correlation: 0.9994593611796051
Quantize 2th dimension of spatial coordinates to 2.857142857142857, mean deviation: 2.953008132466147e-16, pearson correlation: 1.0
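The two preprocessing steps in the cell above can be illustrated with a small NumPy sketch. This is not the stadiffuser implementation: `radius_graph` and `quantize_by_division` are hypothetical helpers, and the sketch assumes that the spatial network links spots within a Euclidean radius and that the "division" method divides each coordinate by the given step and rounds to the nearest integer. On a unit lattice, a cutoff of 1.4 keeps the four axis-aligned neighbors and excludes diagonals (√2 ≈ 1.414), which would be consistent with the roughly 3.85 neighbors per cell in the log.

```python
import numpy as np

def radius_graph(coords, rad_cutoff):
    """Hypothetical helper: directed (i, j) pairs with distance <= rad_cutoff."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Exclude self-loops on the diagonal.
    mask = (dist <= rad_cutoff) & ~np.eye(len(coords), dtype=bool)
    return np.argwhere(mask)

def quantize_by_division(values, step):
    """Hypothetical helper: divide by `step`, then round to the nearest integer.

    Returns the quantized values and the mean absolute rounding deviation.
    """
    scaled = values / step
    quantized = np.round(scaled)
    mean_dev = np.abs(scaled - quantized).mean()
    return quantized, mean_dev

# Toy 5 x 5 unit lattice standing in for one slice.
xs, ys = np.meshgrid(np.arange(5), np.arange(5))
grid = np.column_stack([xs.ravel(), ys.ravel()]).astype(float)

edges = radius_graph(grid, rad_cutoff=1.4)
avg_neighbors = len(edges) / len(grid)

# Toy z coordinates spaced by 0.35, mimicking evenly spaced slices.
z = np.array([0.0, 0.35, 0.7, 1.05])
z_idx, z_dev = quantize_by_division(z, step=0.35)

print(len(edges), avg_neighbors)
print(z_idx, z_dev)
```

In the toy z example the coordinates are exact multiples of the step, so the rounding deviation is at floating-point level, similar in spirit to the near-zero deviation logged for the third dimension.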
[21]:
import numpy as np
label_name = "annotation"
num_class_embeds = len(np.unique(adata.obs[label_name]))
class_dict = dict(zip(np.unique(adata.obs[label_name]), range(num_class_embeds)))
class_dict
[21]:
{'CNS': 0,
'carcass': 1,
'epidermis': 2,
'fat body': 3,
'foregut': 4,
'hemolymph': 5,
'midgut': 6,
'muscle': 7,
'salivary gland': 8,
'trachea': 9}
Training autoencoder
[17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = SpaAE(input_dim=adata.shape[1],
                    block_list=["AttnBlock"],
                    gat_dim=[512, 32],
                    block_out_dims=[32, 32])
Pretraining on each slice
[22]:
batch_list = adata.obs["slice_ID"].unique().tolist()
data = pipeline.prepare_dataset(adata, use_rep=None)
train_loaders = [get_slice_loader(adata, data, batch, use_batch="slice_ID",
                                  batch_size=256) for batch in batch_list]
autoencoder, autoencoder_loss = pipeline.pretrain_autoencoder_multi(train_loaders,
                                                                    autoencoder,
                                                                    pretrain_epochs=200,
                                                                    device=device)
Training with triplet loss to align the spot/cell embeddings
[23]:
autoencoder, autoencoder_loss = pipeline.train_autoencoder_multi(adata, autoencoder, use_batch="slice_ID",
                                                                 batch_list=batch_list,
                                                                 n_epochs=300,
                                                                 margin=1,
                                                                 lr=1e-4,
                                                                 update_interval=50,
                                                                 device=device)
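For readers unfamiliar with the objective named in this step, here is a minimal NumPy sketch of a standard triplet margin loss with `margin=1`, matching the value passed above. How stadiffuser selects anchors, positives and negatives across slices is not shown in this notebook, so the arrays below are made-up toy embeddings and `triplet_margin_loss` is a hypothetical helper, not a library function.

```python
import numpy as np

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over a batch of triplets."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()

# Two toy triplets of 2-dimensional embeddings (illustrative values only).
anchor = np.array([[0.0, 0.0], [1.0, 1.0]])
positive = np.array([[0.0, 1.0], [1.0, 1.0]])
negative = np.array([[3.0, 0.0], [1.0, 1.5]])

loss = triplet_margin_loss(anchor, positive, negative, margin=1.0)
print(loss)  # 0.25
```

The first triplet already satisfies the margin (the negative is two units farther than the positive) and contributes zero; the second has a negative only 0.5 away and contributes 0.5, so the batch mean is 0.25.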
Training latent diffusion model
[20]:
import numpy as np
cond_name = "annotation"
num_class_embeds = len(np.unique(adata.obs[cond_name]))
class_dict = dict(zip(np.unique(adata.obs[cond_name]), range(num_class_embeds)))
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)
class_dict
[20]:
{'CNS': 0,
'carcass': 1,
'epidermis': 2,
'fat body': 3,
'foregut': 4,
'hemolymph': 5,
'midgut': 6,
'muscle': 7,
'salivary gland': 8,
'trachea': 9}
[24]:
adata = pipeline.get_recon(adata, autoencoder, device=device,
                           apply_normalize=False, show_progress=True, batch_mode=True)
normalizer = sutils.MinMaxNormalize(adata.obsm["latent"], dim=0)
adata.obsm["normalized_latent"] = normalizer.normalize(adata.obsm["latent"])
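The normalization step above can be pictured with the following sketch of a per-dimension min-max scaler. It is an assumption about what `sutils.MinMaxNormalize` does, not its actual code: the class name `MinMaxSketch`, the `[0, 1]` target range, and the `denormalize` method are all hypothetical.

```python
import numpy as np

class MinMaxSketch:
    """Hypothetical per-dimension min-max scaler (target range assumed to be [0, 1])."""

    def __init__(self, data, axis=0, eps=1e-8):
        # Statistics are computed once from the reference latents.
        self.min = data.min(axis=axis, keepdims=True)
        # eps guards against division by zero for constant dimensions.
        self.range = data.max(axis=axis, keepdims=True) - self.min + eps

    def normalize(self, x):
        return (x - self.min) / self.range

    def denormalize(self, x):
        return x * self.range + self.min

rng = np.random.default_rng(0)
latent = rng.normal(size=(100, 32))   # stand-in for adata.obsm["latent"]

scaler = MinMaxSketch(latent, axis=0)
normalized = scaler.normalize(latent)
restored = scaler.denormalize(normalized)

print(normalized.min(), normalized.max())
print(np.allclose(restored, latent))
```

The simulation call at the end of this notebook receives the same normalizer object, presumably so that generated latents can be mapped back to the autoencoder's scale before decoding.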
[25]:
# For 3D slice modeling, in_channels = time embedding (16) + latent embedding (1)
# + z-axis embedding (1, concat mode) = 18
denoiser = SpaUNet1DModel(in_channels=18, out_channels=1, spatial_encoding="sinusoidal3d",
                          spatial3d_concat=True).to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)
data_latent = pipeline.prepare_dataset(adata, use_rep="normalized_latent", use_spatial="new_spatial",
                                       use_net="spatial_net", use_label="label_")
train_loader = NeighborLoader(data_latent, num_neighbors=[5, 3], batch_size=256)
denoiser, denoise_loss = pipeline.train_denoiser(train_loader, denoiser, noise_scheduler,
                                                 lr=1e-4, weight_decay=1e-6,
                                                 n_epochs=500,
                                                 num_class_embeds=num_class_embeds,
                                                 device=device)
-------------------Training Finished-------------------
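To make the role of the noise scheduler more concrete, the sketch below reproduces the DDPM forward (noising) process in plain NumPy. It assumes the default `DDPMScheduler` settings in diffusers, namely a linear beta schedule from 1e-4 to 0.02, and uses random stand-in latents of width 32 rather than the real `normalized_latent`. The helper `add_noise` here is a local function written for illustration; it mirrors the usual closed form x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε and is not the training loop used by `pipeline.train_denoiser`.

```python
import numpy as np

num_train_timesteps = 1000
# Linear beta schedule (assumed to match the diffusers defaults).
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

def add_noise(x0, noise, t):
    """Sample x_t from q(x_t | x_0) for a batch of integer timesteps t."""
    signal_scale = np.sqrt(alphas_cumprod[t])[:, None]
    noise_scale = np.sqrt(1.0 - alphas_cumprod[t])[:, None]
    return signal_scale * x0 + noise_scale * noise

rng = np.random.default_rng(2024)
x0 = rng.uniform(0.0, 1.0, size=(4, 32))   # stand-in normalized latents
noise = rng.standard_normal(x0.shape)
t = np.array([0, 250, 500, 999])           # one timestep per sample

xt = add_noise(x0, noise, t)
# Fraction of the original signal amplitude retained at each timestep.
print(np.sqrt(alphas_cumprod[t]))
```

A denoiser is typically trained to predict `noise` from `xt`, the timestep, and any conditioning (here, spatial coordinates and class labels); at the last timestep almost no signal remains, so sampling can start from pure Gaussian noise.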
Simulate a slice from the trained model
[26]:
data = pipeline.prepare_dataset(adata, use_net="spatial_net", use_spatial="new_spatial")
stadiff_sim = pipeline.simulate(denoiser, autoencoder, device=device, use_net="spatial_net",
                                ref_data=adata, spatial_coord=adata.obsm["new_spatial"],
                                labels=adata.obs["label_"].to_numpy(), seed=2024, normarlizer=normalizer)
Simulate with labels