{ "cells": [ { "cell_type": "markdown", "source": [ "# T4: 3D slice modeling of Drosophila embryo\n" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "import scanpy as sc\n", "import torch\n", "import urllib.request\n", "import warnings\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from diffusers import DDPMScheduler\n", "from torch_geometric.loader import NeighborLoader\n", "from stadiffuser import pipeline\n", "from stadiffuser.vae import SpaAE\n", "from stadiffuser.models import SpaUNet1DModel\n", "from stadiffuser import utils as sutils\n", "from stadiffuser import metrics\n", "from stadiffuser.dataset import get_slice_loader\n", "warnings.filterwarnings(\"ignore\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T08:09:39.547603100Z", "start_time": "2024-06-06T08:09:39.529605300Z" } } }, { "cell_type": "markdown", "source": [ "## Load data" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "data": { "text/plain": "AnnData object with n_obs × n_vars = 14634 × 2000\n obs: 'slice_ID', 'raw_x', 'raw_y', 'new_x', 'new_y', 'new_z', 'annotation'\n var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'\n uns: 'hvg', 'log1p', 'spatial_net'\n obsm: 'X_umap', 'spatial'\n layers: 'raw_counts'" }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Please manually download the file from https://drive.google.com/file/d/1zyZKeZljbsEqo3YqVc_2-quU1Esm55E1/view?usp=drive_link\n", "# It's ~200 MB.\n", "# Load the downloaded, processed Stereo-seq data\n", "adata = sc.read_h5ad(\"adata_processed.h5ad\")\n", "adata" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T08:09:01.875594700Z", "start_time": "2024-06-06T08:09:01.698084100Z" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "data": { "text/plain": 
"slice_ID\nE16-18h_a_S11 1193\nE16-18h_a_S04 1189\nE16-18h_a_S05 1181\nE16-18h_a_S08 1131\nE16-18h_a_S09 1113\nE16-18h_a_S10 1111\nE16-18h_a_S07 1096\nE16-18h_a_S06 1076\nE16-18h_a_S12 1049\nE16-18h_a_S13 1022\nE16-18h_a_S03 1021\nE16-18h_a_S01 985\nE16-18h_a_S02 965\nE16-18h_a_S14 502\nName: count, dtype: int64" }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata.obs[\"slice_ID\"].value_counts()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T08:09:04.475569400Z", "start_time": "2024-06-06T08:09:04.428644800Z" } } }, { "cell_type": "code", "execution_count": 13, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "------Calculating spatial network for each batch...\n", "Calculating spatial network for batch E16-18h_a_S01...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 3790 edges, 985 cells, 3.8477 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S02...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 3718 edges, 965 cells, 3.8528 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S03...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 3932 edges, 1021 cells, 3.8511 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S04...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4594 edges, 1189 cells, 3.8638 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S05...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4560 edges, 1181 cells, 3.8611 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S06...\n", "------Calculating spatial graph...\n", "------Spatial 
graph calculated.\n", "The graph contains 4144 edges, 1076 cells, 3.8513 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S07...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4224 edges, 1096 cells, 3.8540 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S08...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4368 edges, 1131 cells, 3.8621 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S09...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4294 edges, 1113 cells, 3.8580 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S10...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4286 edges, 1111 cells, 3.8578 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S11...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4610 edges, 1193 cells, 3.8642 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S12...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 4044 edges, 1049 cells, 3.8551 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S13...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 3936 edges, 1022 cells, 3.8513 neighbors per cell on average.\n", "Calculating spatial network for batch E16-18h_a_S14...\n", "------Calculating spatial graph...\n", "------Spatial graph calculated.\n", "The graph contains 1910 edges, 502 cells, 3.8048 neighbors per cell on average.\n", "------Calculating spatial bipartite network...\n", "------Spatial network calculated.\n", "Quantizing 
spatial coordinates...\n", "Quantize 0th dimension of spatial coordinates to 1.25, mean deviation: 0.24953966933810226, pearson correlation: 0.9998366984519741\n", "Quantize 1th dimension of spatial coordinates to 1.25, mean deviation: 0.24992933410836693, pearson correlation: 0.9994593611796051\n", "Quantize 2th dimension of spatial coordinates to 2.857142857142857, mean deviation: 2.953008132466147e-16, pearson correlation: 1.0\n" ] } ], "source": [ "adata = sutils.cal_spatial_net3D(adata, iter_comb=None, batch_id=\"slice_ID\", rad_cutoff=1.4,\n", " add_key=\"spatial_net\")\n", "new_spatial = adata.obsm[\"spatial\"].copy()\n", "new_spatial = sutils.quantize_coordination(new_spatial, methods=[(\"division\", 0.8), (\"division\", 0.8), (\"division\", 0.35)])\n", "adata.obsm[\"new_spatial\"] = new_spatial" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T08:17:54.754339900Z", "start_time": "2024-06-06T08:17:54.555238100Z" } } }, { "cell_type": "code", "execution_count": 21, "outputs": [ { "data": { "text/plain": "{'CNS': 0,\n 'carcass': 1,\n 'epidermis': 2,\n 'fat body': 3,\n 'foregut': 4,\n 'hemolymph': 5,\n 'midgut': 6,\n 'muscle': 7,\n 'salivary gland': 8,\n 'trachea': 9}" }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "label_name = \"annotation\"\n", "num_class_embeds = len(np.unique(adata.obs[label_name]))\n", "class_dict = dict(zip(np.unique(adata.obs[label_name]), range(num_class_embeds)))\n", "class_dict" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-09T13:57:48.992062300Z", "start_time": "2024-06-09T13:57:48.978060200Z" } } }, { "cell_type": "markdown", "source": [ "## Training autoencoder" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 17, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "autoencoder = SpaAE(input_dim=adata.shape[1],\n", " 
block_list=[\"AttnBlock\"],\n", " gat_dim=[512, 32],\n", " block_out_dims=[32, 32])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-06-06T08:19:07.736106600Z", "start_time": "2024-06-06T08:19:07.696029200Z" } } }, { "cell_type": "markdown", "source": [ "### Pretraining on each slice" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 22, "outputs": [ { "data": { "text/plain": " 0%| | 0/200 [00:00