"""
Pipeline for STADiffuser model
"""
import os
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import scipy.sparse as sp
import scanpy as sc
from anndata import AnnData
from diffusers import SchedulerMixin, DDPMScheduler
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from tqdm.auto import tqdm
from .dataset import get_slice_loader, TripletSampler
[docs]
def remove_edge(G, is_masked):
    """
    Remove every edge incident to a masked node of the graph.

    An entry ``G[i, j]`` is kept only when neither node ``i`` nor node ``j``
    is masked. Only the stored entries of the sparse matrix are inspected, so
    memory use is O(nnz) instead of the O(n^2) a dense mask would need.

    Parameters
    ----------
    G:
        The square (n x n) graph adjacency matrix in any scipy.sparse format;
        it is converted to COO internally.
    is_masked:
        Array-like of length n. Entries whose integer value is non-zero mark
        nodes whose edges (both as row and as column) are removed.

    Returns
    -------
    G:
        The new graph in scipy.sparse.coo_matrix format, same shape as the input.

    Raises
    ------
    ValueError
        If ``G`` is not square or ``is_masked`` does not have one entry per node.
    """
    G = sp.coo_matrix(G)
    # same truthiness rule as casting the mask to int
    masked = np.asarray(is_masked).astype(int).ravel() != 0
    if G.shape[0] != G.shape[1] or masked.shape[0] != G.shape[0]:
        raise ValueError("is_masked must have one entry per node of the square graph G; "
                         "got G.shape={} and {} mask entries".format(G.shape, masked.shape[0]))
    keep = ~(masked[G.row] | masked[G.col])
    return sp.coo_matrix((G.data[keep], (G.row[keep], G.col[keep])), shape=G.shape)
[docs]
def prepare_dataset(adata: AnnData,
                    use_rep=None,
                    use_spatial='spatial',
                    use_net='spatial_net',
                    use_label=None,
                    device='cpu'):
    """
    Transfer adata to a pytorch_geometric ``Data`` graph.

    This function is used for training the autoencoder model. The network
    table ``adata.uns[use_net]`` (columns ``Cell1`` and ``Cell2`` holding obs
    names) is converted into an edge list, and a self-loop is added for every
    spot.

    Side effect: stores the edge list, a ``(2, n_edges)`` integer array, in
    ``adata.uns['edge_list']``.

    Parameters
    ----------
    adata: AnnData
        ``adata.obsm[use_spatial]`` and ``adata.uns[use_net]`` must exist.
    use_rep: str
        The key of the expression matrix in adata.obsm. If None, use adata.X.
    use_spatial: str
        The key of the spatial coordinates in adata.obsm. The coordinates are
        converted with ``torch.LongTensor``.
    use_net: str
        The key of the spatial network table in adata.uns.
    use_label: str
        The key of the label in adata.obs. If None, ``data.label`` is None.
    device: str
        "cpu" or "cuda"

    Returns
    -------
    data: torch_geometric.data.Data
        Graph with ``x`` (float features), ``edge_index``, ``spatial`` and
        ``label`` attributes, moved to ``device``.

    Raises
    ------
    ValueError
        If the network table references cells that are not in ``adata.obs_names``.
    """
    G_df = adata.uns[use_net].copy()
    cells = np.array(adata.obs_names)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
    # cells absent from obs_names map to NaN and would otherwise break the COO constructor
    unmapped = G_df['Cell1'].isna() | G_df['Cell2'].isna()
    if unmapped.any():
        raise ValueError("adata.uns['{}'] contains {} edge(s) whose cells are not in adata.obs_names; "
                         "filter or rebuild the network for this AnnData.".format(use_net, int(unmapped.sum())))
    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
    # self-loops so every spot also sees its own features
    G = G + sp.eye(G.shape[0])
    x = adata.obsm[use_rep] if use_rep is not None else adata.X
    if sp.issparse(x):
        x = x.toarray()
    spatial = adata.obsm[use_spatial]
    normalized_spatial = torch.LongTensor(spatial)
    if use_label is not None:
        label = torch.LongTensor(adata.obs[use_label])
    else:
        label = None
    rows, cols = np.nonzero(G)
    edge_list = np.array([rows, cols])
    adata.uns['edge_list'] = edge_list
    data = Data(edge_index=torch.LongTensor(edge_list),
                x=torch.FloatTensor(x),
                spatial=normalized_spatial,
                label=label)
    return data.to(device)
[docs]
def get_recon(adata, autoencoder, use_net="spatial_net", apply_normalize=True, use_rep="latent",
              use_spatial="spatial", batch_mode=False, batch_size=256, num_neighbors=None,
              inplace=True, device="cuda:0", show_progress=True):
    """
    Get the reconstructed expression matrix and latent representation with the trained autoencoder.

    The input ``adata`` is never modified: all work is done on a copy.

    Parameters
    ----------
    adata: AnnData
        The AnnData object.
    autoencoder: torch.nn.Module
        The trained autoencoder model; it is moved to ``device`` and set to eval mode.
    use_net: str
        The key of the network in adata.uns.
    apply_normalize: bool
        If True, apply ``sc.pp.normalize_total(target_sum=1e4)`` followed by
        ``sc.pp.log1p`` to the copy before encoding. Recommended for count data.
    use_rep: str
        The key of adata_recon.obsm that receives the latent representation.
    use_spatial: str
        The key of the spatial coordinates in adata.obsm.
    batch_mode: bool
        Whether to use batch mode (NeighborLoader). Recommended for large datasets.
    batch_size: int
        The batch size in batch mode to compute the latent representation.
    num_neighbors: list
        The number of neighbors sampled per hop in batch mode, list of integers.
        Defaults to ``[5, 3]``.
    inplace: bool
        If True, store the reconstruction in ``adata_recon.X`` and the latent in
        ``adata_recon.obsm[use_rep]`` of the copy and return that copy.
        If False, return ``(latent, recon)`` as numpy arrays.
    device: str
        The device to compute the latent representation. default is "cuda:0".
    show_progress: bool
        Whether to show the progress bar (batch mode only).

    Returns
    -------
    adata_recon: AnnData
        The copied AnnData object with the reconstructed expression matrix and
        latent representation (``inplace=True``), otherwise ``(latent, recon)``.
    """
    if num_neighbors is None:
        num_neighbors = [5, 3]
    adata_recon = adata.copy()
    if apply_normalize:
        sc.pp.normalize_total(adata_recon, target_sum=1e4)
        sc.pp.log1p(adata_recon)
    data = prepare_dataset(adata_recon, use_net=use_net, use_spatial=use_spatial)
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    if not batch_mode:
        data = data.to(device)
        with torch.no_grad():
            latent, recon = autoencoder(data.x, data.edge_index)
        # release memory
        del data
    else:
        loader = NeighborLoader(data, num_neighbors=num_neighbors, batch_size=batch_size, shuffle=False)
        latent = []
        recon = []
        n_batches = len(loader)
        pbar = tqdm(total=n_batches) if show_progress else None
        try:
            for batch in loader:
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description("Batch: {} / {}".format(pbar.n, n_batches))
                    pbar.refresh()
                batch = batch.to(device)
                with torch.no_grad():
                    latent_batch, recon_batch = autoencoder(batch.x, batch.edge_index)
                # only the first input_num rows belong to this batch's seed nodes;
                # the remaining rows are sampled neighbours
                input_num = batch.input_id.shape[0]
                latent.append(latent_batch[:input_num, :].cpu())
                recon.append(recon_batch[:input_num, :].cpu())
        finally:
            if pbar is not None:
                pbar.close()
        latent = torch.cat(latent, dim=0)
        recon = torch.cat(recon, dim=0)
    if inplace:
        adata_recon.X = recon.cpu().numpy()
        adata_recon.obsm[use_rep] = latent.cpu().numpy()
        return adata_recon
    return latent.cpu().numpy(), recon.cpu().numpy()
[docs]
def train_autoencoder(train_loader, model,
                      n_epochs=1000, gradient_clip=5, lr=1e-4, weight_decay=1e-6,
                      save_dir=None, model_name="autoencoder",
                      device="cpu",
                      check_points=None):
    """
    Train the autoencoder model with an MSE reconstruction loss.

    Parameters
    ----------
    train_loader: torch_geometric.loader or torch_geometric.data.DataLoader
        The data loader for training.
    model: torch.nn.Module
        The autoencoder model; called as ``model(x, edge_index) -> (z, out)``.
        It is moved to ``device``.
    n_epochs: int
        The number of epochs to train.
    gradient_clip: float
        The maximum gradient norm.
    lr: float
        The learning rate.
    weight_decay: float
        The weight decay.
    save_dir: str
        The directory to save the model (created if missing). If None, do not save the model.
    model_name: str
        The name of the model.
    device: str
        The device to train the model.
    check_points: list
        1-based epochs after which ``{model_name}_{epoch}.pth`` is saved--list of integers.
        Requires ``save_dir``.

    Returns
    -------
    model: torch.nn.Module
        The trained model.
    loss_list: list
        The loss of every optimisation step (one entry per batch).

    Raises
    ------
    ValueError
        If ``check_points`` is given without ``save_dir``.
    """
    # fail before training rather than at the first checkpoint
    if check_points and save_dir is None:
        raise ValueError("save_dir must be provided when check_points is set.")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    loss_list = []
    pbar = tqdm(total=n_epochs)
    model.train()
    for epoch in range(1, n_epochs + 1):
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            _, out = model(batch.x, batch.edge_index)
            loss = F.mse_loss(out, batch.x)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()
            loss_list.append(loss.item())
            pbar.set_description(f"Epoch: {epoch}, Loss: {loss.item():.4f}")
        # the cosine schedule spans n_epochs, so step once per epoch
        scheduler.step()
        pbar.update(1)
        if check_points is not None and epoch in check_points:
            torch.save(model, os.path.join(save_dir, "{}_{}.pth".format(model_name, epoch)))
    pbar.close()
    if save_dir is not None:
        torch.save(model, os.path.join(save_dir, "{}.pth".format(model_name)))
    return model, loss_list
[docs]
def pretrain_autoencoder_multi(train_loaders, model,
                               pretrain_epochs=100, lr=1e-4, weight_decay=1e-6,
                               save_dir=None, model_name="autoencoder_pre", check_points=None,
                               device="cpu"):
    """
    Pretrain the autoencoder model on each slice with an MSE reconstruction loss.

    Every epoch iterates over all loaders in ``train_loaders`` in order and
    performs one optimisation step per batch.

    Parameters
    ----------
    train_loaders: list
        One data loader per slice.
    model: torch.nn.Module
        The autoencoder model; called as ``model(x, edge_index) -> (z, out)``.
    pretrain_epochs: int
        The number of epochs; the cosine learning-rate schedule spans exactly this many epochs.
    lr: float
        The learning rate.
    weight_decay: float
        The weight decay.
    save_dir: str
        The directory to save the model (created if missing). If None, do not save the model.
    model_name: str
        The name of the model.
    check_points: list
        1-based epochs after which ``{model_name}_{epoch}.pth`` is saved. Requires ``save_dir``.
    device: str
        The device to train the model.

    Returns
    -------
    model: torch.nn.Module
        The trained model.
    loss_list: list
        The loss of every optimisation step (one entry per batch).

    Raises
    ------
    ValueError
        If ``check_points`` is given without ``save_dir``.
    """
    if check_points and save_dir is None:
        raise ValueError("save_dir must be provided when check_points is set.")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)
    loss_list = []
    model.train()
    pbar = tqdm(total=pretrain_epochs)
    for epoch in range(1, pretrain_epochs + 1):
        for i, loader in enumerate(train_loaders):
            for batch_id, batch in enumerate(loader):
                batch = batch.to(device)
                optimizer.zero_grad()
                _, out = model(batch.x, batch.edge_index)
                loss = F.mse_loss(out, batch.x)
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
                pbar.set_description("Pretrain|Epoch: {}, Batch: {}-{}, Loss: {:.4f}".format(epoch, i + 1, batch_id, loss.item()))
        # exactly one scheduler step per epoch so the schedule matches T_max
        scheduler.step()
        pbar.update(1)
        if check_points is not None and epoch in check_points:
            torch.save(model, os.path.join(save_dir, "{}_{}.pth".format(model_name, epoch)))
    pbar.close()
    if save_dir is not None:
        torch.save(model, os.path.join(save_dir, "{}.pth".format(model_name)))
    return model, loss_list
[docs]
def train_autoencoder_multi(adata, model, use_batch=None, batch_list=None,
                            n_epochs=400, lr=1e-4, weight_decay=1e-6, margin=1.0, update_interval=50, mnn_neighbors=15,
                            save_dir=None, model_name="autoencoder_tri", device="cpu", check_points=None,
                            gradient_clip=5):
    """
    Train the autoencoder model with triplet loss using multiple slices.

    Consecutive slices ``(batch_list[i], batch_list[i + 1])`` are aligned: the
    loss is the mean of both slices' MSE reconstruction losses plus a triplet
    loss on latent codes, where triplets come from ``TripletSampler`` built on
    the current latent representation (refreshed every ``update_interval``
    epochs).

    Side effect: ``prepare_dataset`` stores ``adata.uns['edge_list']``.

    Parameters
    ----------
    adata: AnnData
        The AnnData object; must contain ``uns['spatial_net']`` and ``obsm['spatial']``.
    model: torch.nn.Module
        The autoencoder model.
    use_batch: str
        The key of the batch (slice) column in adata.obs. Required.
    batch_list: list
        The ordered list of batch names. Defaults to ``adata.obs[use_batch].unique()``.
    n_epochs: int
        The number of epochs to train; the cosine LR schedule spans exactly this many epochs.
    lr: float
        The learning rate.
    weight_decay: float
        The weight decay.
    margin: float
        The margin of the triplet loss. Default is 1.0.
    update_interval: int
        Rebuild the triplet samplers every this many epochs (starting at epoch 1).
    mnn_neighbors: int
        ``num_neighbors`` passed to ``TripletSampler``.
    save_dir: str
        The directory to save the model (created if missing). If None, do not save.
    model_name: str
        The name of the model.
    device: str
        The device to train the model.
    check_points: list
        1-based epochs after which ``{model_name}_{epoch}.pth`` is saved. Requires ``save_dir``.
    gradient_clip: float
        The maximum gradient norm. Default is 5.

    Returns
    -------
    model: torch.nn.Module
        The trained model.
    loss_list: list
        The loss of the last slice pair of every epoch.

    Raises
    ------
    ValueError
        If ``use_batch`` is None, fewer than two slices are given, ``update_interval < 1``,
        or ``check_points`` is given without ``save_dir``.
    """
    if use_batch is None:
        raise ValueError("use_batch must be the key of the slice column in adata.obs.")
    if update_interval < 1:
        raise ValueError("update_interval must be a positive integer.")
    if check_points and save_dir is None:
        raise ValueError("save_dir must be provided when check_points is set.")
    if batch_list is None:
        batch_list = adata.obs[use_batch].unique()
    if len(batch_list) < 2:
        raise ValueError("At least two slices are required for triplet alignment.")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    # (target, reference) index pairs of consecutive slices
    iter_combs = [(i, i + 1) for i in range(len(batch_list) - 1)]
    data = prepare_dataset(adata, use_net="spatial_net", use_spatial="spatial")
    train_loaders = []
    for batch_name in batch_list:
        num_spots = int((adata.obs[use_batch] == batch_name).values.sum())
        # batch_size equals the slice size -- presumably one batch per slice; depends on get_slice_loader
        train_loaders.append(get_slice_loader(adata, data, batch_name, use_batch=use_batch, batch_size=num_spots))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    triplet_loss = torch.nn.TripletMarginLoss(margin=margin)
    model.train()
    loss_list = []
    pbar = tqdm(total=n_epochs)
    for epoch in range(1, n_epochs + 1):
        if (epoch - 1) % update_interval == 0:
            pbar.set_description(f"Align|update MNN, Epoch: {epoch}")
            model.eval()
            adata_temp = get_recon(adata, model, device=device, apply_normalize=False,
                                   show_progress=False, batch_mode=True)
            tri_samplers = [TripletSampler(adata_temp, target=batch_list[target_id],
                                           use_rep="latent",
                                           reference=batch_list[ref_id],
                                           use_batch=use_batch,
                                           num_neighbors=mnn_neighbors)
                            for target_id, ref_id in iter_combs]
            model.train()
        for ind, (target_id, ref_id) in enumerate(iter_combs):
            optimizer.zero_grad()
            target_batch = next(iter(train_loaders[target_id])).to(device)
            reference_batch = next(iter(train_loaders[ref_id])).to(device)
            z_target, out_target = model(target_batch.x, target_batch.edge_index)
            z_reference, out_reference = model(reference_batch.x, reference_batch.edge_index)
            anchor_ids, positive_ids, negative_ids = tri_samplers[ind].query(
                target_batch.n_id.detach().cpu().numpy())
            # map global node ids to row positions of *these* batches, so the
            # mapping is always consistent with the rows of z_target / z_reference
            target_pos = {nid: pos for pos, nid in enumerate(target_batch.n_id.tolist())}
            reference_pos = {nid: pos for pos, nid in enumerate(reference_batch.n_id.tolist())}
            anchor_rows = [target_pos[int(i)] for i in anchor_ids]
            positive_rows = [reference_pos[int(i)] for i in positive_ids]
            negative_rows = [target_pos[int(i)] for i in negative_ids]
            loss_rmse = F.mse_loss(out_target, target_batch.x) * .5 + F.mse_loss(out_reference, reference_batch.x) * .5
            loss_tri = triplet_loss(z_target[anchor_rows], z_reference[positive_rows], z_target[negative_rows])
            loss = loss_tri + loss_rmse
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()
            pbar.set_description(f"Align|Epoch: {epoch}, Loss: {loss.item():.4f}")
        scheduler.step()
        pbar.update(1)
        loss_list.append(loss.item())
        if check_points is not None and epoch in check_points:
            torch.save(model, os.path.join(save_dir, "{}_{}.pth".format(model_name, epoch)))
    pbar.close()
    if save_dir is not None:
        torch.save(model, os.path.join(save_dir, "{}.pth".format(model_name)))
    return model, loss_list
[docs]
def train_denoiser(train_loader,
                   model,
                   noise_scheduler: SchedulerMixin,
                   n_epochs: int = 1000,
                   lr: float = 1e-4,
                   lr_scheduler: str = "cosine_annealing",
                   weight_decay: float = 1e-6,
                   gradient_clip: float = 5,
                   num_class_embeds: Optional[int] = None,
                   save_dir: Optional[str] = None,
                   evaluate_interval: Optional[int] = None,
                   model_name: str = "denoiser",
                   device: str = "cpu",
                   check_points=None,
                   eval_fn=None,
                   eval_kwargs=None,
                   ):
    r"""
    Train the denoising model with the noise scheduler.

    For each batch, noise is added to ``batch.x`` at uniformly sampled
    timesteps in ``[0, noise_scheduler.config.num_train_timesteps)`` and the
    model is trained (MSE) to predict that noise.

    Parameters
    ----------
    train_loader: torch_geometric.loader or torch_geometric.data.DataLoader
        The data loader for training; batches need ``x``, ``spatial`` and,
        when conditional, ``label``.
    model: torch.nn.Module
        The denoising model; it is moved to ``device``.
    noise_scheduler: SchedulerMixin
        The noise scheduler. Its ``config.num_train_timesteps`` (1000 if absent)
        bounds the sampled timesteps.
    n_epochs: int
        The number of epochs to train.
    lr: float
        The learning rate.
    lr_scheduler: str
        The learning rate scheduler. Only "cosine_annealing" is supported.
    weight_decay: float
        The weight decay.
    gradient_clip: float
        The maximum gradient norm.
    num_class_embeds: int
        The number of class embeddings; if None the model is unconditional.
    save_dir: str
        Directory for checkpoints and the final model (created if missing).
    evaluate_interval: int
        ``eval_fn`` runs at epochs 1, 1 + k, 1 + 2k, ...; None means every epoch.
    model_name: str
        The name of the saved model files.
    device: str
        The device to train on.
    check_points: list
        1-based epochs after which ``{model_name}_{epoch}.pth`` is saved. Requires ``save_dir``.
    eval_fn: callable
        Optional ``eval_fn(model, noise_scheduler, epoch, **eval_kwargs)``;
        the model is put back into train mode afterwards.
    eval_kwargs: dict
        Extra keyword arguments for ``eval_fn``.

    Returns
    -------
    model: torch.nn.Module
        The trained model.
    loss_list: list
        The list of loss during training (one entry per batch).

    Raises
    ------
    NotImplementedError
        If ``lr_scheduler`` is not "cosine_annealing".
    ValueError
        If ``check_points`` is given without ``save_dir``.
    """
    if lr_scheduler != "cosine_annealing":
        raise NotImplementedError("Unsupported lr_scheduler: {}".format(lr_scheduler))
    if check_points and save_dir is None:
        raise ValueError("save_dir must be provided when check_points is set.")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    if eval_kwargs is None:
        eval_kwargs = {}
    if evaluate_interval is None:
        evaluate_interval = 1
    # sample timesteps within the scheduler's own range instead of a hard-coded 1000
    scheduler_config = getattr(noise_scheduler, "config", None)
    num_train_timesteps = getattr(scheduler_config, "num_train_timesteps", 1000)
    model = model.to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    loss_list = []
    pbar = tqdm(total=n_epochs)
    for epoch in range(1, n_epochs + 1):
        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            clean_data = batch.x.unsqueeze(1)  # (batch_size, 1, num_channels)
            optimizer.zero_grad()
            noise = torch.randn_like(clean_data)
            timesteps = torch.randint(0, num_train_timesteps, (clean_data.shape[0],),
                                      device=clean_data.device, dtype=torch.long)
            noisy_data = noise_scheduler.add_noise(clean_data, noise, timesteps)
            if num_class_embeds is None:
                noise_pred = model(noisy_data, timesteps, batch.spatial).sample
            else:
                noise_pred = model(noisy_data, timesteps, batch.spatial, batch.label).sample
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            loss_list.append(loss.item())
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()
            pbar.set_description(f"Epoch: {epoch}, Loss: {loss.item():.4f}, batch_id: {batch_idx}")
        lr_sched.step()
        pbar.update(1)
        # (epoch - 1) % k == 0 also fires for k == 1, unlike epoch % k == 1
        if eval_fn is not None and (epoch - 1) % evaluate_interval == 0:
            eval_fn(model, noise_scheduler, epoch, **eval_kwargs)
            model.train()
        if check_points is not None and epoch in check_points:
            torch.save(model, os.path.join(save_dir, "{}_{}.pth".format(model_name, epoch)))
    pbar.close()
    if save_dir is not None:
        save_path = os.path.join(save_dir, "{}.pth".format(model_name))
        torch.save(model, save_path)
        print(" Save model to {}".format(save_path))
    print("-------------------Training Finished-------------------")
    return model, loss_list
[docs]
def simulate(denoiser: torch.nn.Module = None,
             autoencoder: torch.nn.Module = None,
             ref_data=None,
             spatial_coord: np.ndarray = None,
             labels: Optional[np.ndarray] = None,
             noise_scheduler: SchedulerMixin = None,
             init_x: torch.Tensor = None,
             seed=None,
             progress=True,
             n_samples: int = 1,
             return_tuple=True,
             normarlizer=None,
             add_rep="latent",
             use_spatial="spatial",
             device="cpu",
             **kwargs):
    r"""
    Simulate the reverse diffusion process of the denoising model.

    Starting from Gaussian noise (or ``init_x``), iterate over
    ``noise_scheduler.timesteps`` calling
    ``denoiser(x, t, spatial, class_labels=labels).sample`` and
    ``noise_scheduler.step(...)``.

    Parameters
    -----------
    denoiser:
        the denoising model; must expose ``sample_size`` (feature dimension) and ``device``
    autoencoder:
        the trained autoencoder; required when ``ref_data`` is given. Its
        ``decode(latent, edge_index)`` maps the simulated latent to expression.
    ref_data:
        optional AnnData used as output template; must contain ``uns['edge_list']``
        (see ``prepare_dataset``) when ``n_samples == 1``
    spatial_coord:
        numpy array of spatial coordinates of the spots (n_spots, 2) or (n_spots, 3)
    labels:
        optional numpy array with one label per spot when the model is conditional
    noise_scheduler:
        the noise scheduler; defaults to ``DDPMScheduler(num_train_timesteps=1000)``
    init_x:
        optional 2-D starting tensor used instead of random noise (moved to ``device``)
    seed:
        random seed to control torch and numpy
    progress:
        whether to show the progress bar
    n_samples:
        `int`, number of samples to simulate. Samples are stacked along dim 0
        and the spatial coordinates / labels are tiled so every row has a condition.
    return_tuple:
        `bool`, when ``n_samples > 1``: return a tuple of ``n_samples`` tensors
        (True) or one tensor concatenated along dim 1 (False)
    normarlizer:
        optional object with ``denormalize(np.ndarray) -> np.ndarray``, applied to
        each sample separately (parameter name kept for backward compatibility)
    add_rep:
        key of ``sim_data.obsm`` receiving the simulated latent
    use_spatial:
        key of ``sim_data.obsm`` receiving ``spatial_coord``
    device:
        `str`, "cpu" or "cuda"
    kwargs:
        accepted for backward compatibility; currently not used

    Returns
    --------
    denoised:
        If ``ref_data`` is given and ``n_samples == 1``: a copy of ``ref_data``
        with decoded expression in ``X``, the latent in ``obsm[add_rep]`` and the
        coordinates in ``obsm[use_spatial]``. Otherwise the simulated latent: a
        tensor (n_spots, p) for one sample, a tuple of tensors if ``return_tuple``,
        or a tensor of shape (n_spots, n_samples * p) if ``return_tuple`` is False.

    Raises
    ------
    ValueError
        If ``ref_data`` is given without ``autoencoder``, or without
        ``uns['edge_list']`` when decoding is needed.
    """
    # validate before the (expensive) diffusion loop
    if ref_data is not None:
        if autoencoder is None:
            raise ValueError("autoencoder must be provided when ref_data is not None")
        if n_samples <= 1 and "edge_list" not in ref_data.uns:
            raise ValueError("ref_data.uns['edge_list'] must exist. Call prepare_dataset first.")
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    denoiser = denoiser.to(device)
    spatial_coord = torch.from_numpy(spatial_coord).to(device)
    n_spots = spatial_coord.shape[0]
    n_features = denoiser.sample_size
    if init_x is None:
        sim_embed = torch.randn((n_spots * n_samples, n_features), device=denoiser.device).unsqueeze(1)
    else:
        sim_embed = init_x.to(device).unsqueeze(1)
    if noise_scheduler is None:
        noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    if labels is not None:
        print("Simulate with labels")
        labels = torch.from_numpy(labels).to(device).long()

    def _tile_rows(tensor, reps):
        # repeat the whole block `reps` times along dim 0, matching torch.chunk below
        return tensor.repeat(reps, *([1] * (tensor.dim() - 1)))

    # train_denoiser feeds one spatial (and label) row per input row; when several
    # samples are stacked along dim 0, tile the conditions to the same number of rows
    cond_spatial, cond_labels = spatial_coord, labels
    n_rows = sim_embed.shape[0]
    if n_spots > 0 and n_rows != n_spots and n_rows % n_spots == 0:
        reps = n_rows // n_spots
        cond_spatial = _tile_rows(spatial_coord, reps)
        if labels is not None:
            cond_labels = _tile_rows(labels, reps)
    with torch.no_grad():
        iterator = tqdm(noise_scheduler.timesteps) if progress else noise_scheduler.timesteps
        for t in iterator:
            model_output = denoiser(sim_embed, t, cond_spatial, class_labels=cond_labels).sample
            sim_embed = noise_scheduler.step(model_output, t, sim_embed).prev_sample
    sim_embed = torch.squeeze(sim_embed, 1)

    def _denormalize(tensor):
        values = normarlizer.denormalize(tensor.cpu().numpy())
        return torch.from_numpy(values).float().to(device)

    if n_samples > 1:
        samples = torch.chunk(sim_embed, n_samples, dim=0)
        # denormalize each (n_spots, p) sample before any concatenation
        if normarlizer is not None:
            samples = tuple(_denormalize(s) for s in samples)
        sim_embed = samples if return_tuple else torch.cat(samples, dim=1)
    elif normarlizer is not None:
        sim_embed = _denormalize(sim_embed)
    if ref_data is None or n_samples > 1:
        # return simulated x directly if n_samples > 1 or there is nothing to decode
        return sim_embed
    sim_data = ref_data.copy()
    sim_data.obsm[use_spatial] = spatial_coord.cpu().numpy()
    autoencoder = autoencoder.to(device)
    edge_list = torch.LongTensor(ref_data.uns["edge_list"]).to(device)
    with torch.no_grad():
        sim_count = autoencoder.decode(sim_embed, edge_list)
    sim_data.obsm[add_rep] = sim_embed.cpu().numpy()
    sim_data.X = sim_count.cpu().numpy()
    return sim_data