{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 5: TCR-to-Gene Expression Prediction\n", "\n", "Fine-tune TCRfoundation for cross-modal prediction: predicting gene expression from TCR sequences." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "import os\n", "import json\n", "import random\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import scanpy as sc\n", "import tcrfoundation as tcrf\n", "\n", "# Seed every RNG this notebook can reach from one constant, so repeated runs\n", "# start from the same RNG state and the seed is changed in a single place.\n", "SEED = 0\n", "random.seed(SEED)  # Python's stdlib RNG, in addition to NumPy and PyTorch\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED)\n", "if torch.cuda.is_available():\n", "    torch.cuda.manual_seed_all(SEED)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Configuration" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Task: TCR2gene\n", "Device: cuda\n", "Results directory: ../results/TCR2gene\n" ] } ], "source": [ "# Configuration\n", "config = {\n", " \"checkpoint_path\": \"../TCR_foundation_model/foundation_model_best.pt\",\n", " \"num_epochs\": 2,\n", " \"batch_size\": 512,\n", " \"modalities\": None, # Train all modalities: tcr_only, tcra_only, tcrb_only\n", " \"val_split\": 0.2,\n", " \"test_split\": 0.2,\n", " \"save_splits\": False,\n", " \"save_predictions\": False,\n", " \"task\": \"TCR2gene\"\n", "}\n", "\n", "results_dir = f\"../results/{config['task']}\"\n", "os.makedirs(results_dir, exist_ok=True)\n", "\n", "# Save configuration\n", "with open(f\"{results_dir}/config.json\", 'w') as f:\n", " json.dump(config, f, indent=2)\n", "\n", "print(f\"Task: {config['task']}\")\n", "print(f\"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}\")\n", "print(f\"Results directory: {results_dir}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Load Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset: 444979 cells, 3000 genes\n", "\n", "TCR sequences:\n", " CDR3a: 444979 available\n", " CDR3b: 444979 available\n" ] } ], "source": [ "# Load the fine-tuning dataset and report its size\n", "adata = sc.read(\"../data/adata_finetune.h5ad\")\n", "n_cells, n_genes = adata.n_obs, adata.n_vars\n", "print(f\"Dataset: {n_cells} cells, {n_genes} genes\")\n", "\n", "# Count the non-missing entries in each CDR3 column of adata.obs\n", "print(\"\\nTCR sequences:\")\n", "for chain_col in (\"CDR3a\", \"CDR3b\"):\n", "    n_available = adata.obs[chain_col].notna().sum()\n", "    print(f\" {chain_col}: {n_available} available\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Train Cross-Modal Regressor\n", "\n", "This trains models to predict gene expression from TCR sequences:\n", "- TCR α+β → Gene expression\n", "- TCR α only → Gene expression\n", "- TCR β only → Gene expression" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "=== Training TCR2gene Regressor ===\n", "Loaded model with max_length: 30\n", "\n", "=== Training tcr_only regressor ===\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:21<00:00, 23.75it/s]\n", "Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:03<00:00, 45.97it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2 - Train Loss: 0.040502, Val Loss: 0.039059, Train R²: -103.7333, Val R²: -0.2391\n", "Saved best model checkpoint (Val Loss: 0.039059)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:22<00:00, 23.43it/s]\n", "Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 72.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2 - Train Loss: 0.039063, Val Loss: 0.038805, Train R²: -10.9746, Val R²: -0.6951\n", 
"Saved best model checkpoint (Val Loss: 0.038805)\n", "\n", "Evaluating on test set with best model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Testing: 100%|████████████████████████████████████████████| 174/174 [00:03<00:00, 44.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.038711, Test MSE: 0.038711, Test R²: -0.7039\n", "\n", "=== Training tcra_only regressor ===\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:16<00:00, 31.67it/s]\n", "Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 73.45it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2 - Train Loss: 0.038970, Val Loss: 0.038550, Train R²: -17.3715, Val R²: -1.0270\n", "Saved best model checkpoint (Val Loss: 0.038550)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.17it/s]\n", "Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.99it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2 - Train Loss: 0.038423, Val Loss: 0.038114, Train R²: -3.4055, Val R²: -1.1113\n", "Saved best model checkpoint (Val Loss: 0.038114)\n", "\n", "Evaluating on test set with best model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 74.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.038035, Test MSE: 0.038034, Test R²: -1.0393\n", "\n", "=== Training tcrb_only regressor ===\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 35.80it/s]\n", "Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.57it/s]\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2 - Train Loss: 0.038810, Val Loss: 0.038457, Train R²: -11.7012, Val R²: -0.6569\n", "Saved best model checkpoint (Val Loss: 0.038457)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.16it/s]\n", "Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 71.41it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2 - Train Loss: 0.038336, Val Loss: 0.038090, Train R²: -2.8365, Val R²: -1.1950\n", "Saved best model checkpoint (Val Loss: 0.038090)\n", "\n", "Evaluating on test set with best model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 71.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.038028, Test MSE: 0.038028, Test R²: -1.1855\n", "\n", "✓ Training complete\n" ] } ], "source": [ "print(f\"\\n=== Training {config['task']} Regressor ===\")\n", "\n", "# Forward these config entries to the trainer as keyword arguments, in this\n", "# order; \"task\" is not among them.\n", "regressor_option_names = (\n", "    \"checkpoint_path\",\n", "    \"num_epochs\",\n", "    \"batch_size\",\n", "    \"modalities\",\n", "    \"val_split\",\n", "    \"test_split\",\n", "    \"save_splits\",\n", "    \"save_predictions\",\n", ")\n", "regressor_options = {name: config[name] for name in regressor_option_names}\n", "\n", "# The call returns a pair, unpacked into `results` and `adata_with_predictions`\n", "results, adata_with_predictions = tcrf.finetune.cross_modal.train_regressor(\n", "    adata,\n", "    **regressor_options,\n", ")\n", "\n", "print(\"\\n✓ Training complete\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Performance Summary" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "============================================================\n", "TCR2gene Results Summary\n", "============================================================\n", " modality split loss mse r2\n", "0 tcr_only train 0.039063 0.039061 -10.974631\n", "1 tcr_only val 0.038805 0.038805 -0.695051\n", "2 tcr_only test 0.038711 0.038711 -0.703923\n", "3 tcra_only train 0.038423 0.038421 -3.405451\n", "4 tcra_only val 0.038114 0.038114 -1.111292\n", "5 tcra_only test 0.038035 0.038034 -1.039305\n", "6 tcrb_only train 0.038336 0.038335 -2.836521\n", "7 tcrb_only val 0.038090 0.038090 -1.195003\n", "8 tcrb_only test 0.038028 0.038028 -1.185543\n", "\n", "✓ Summary saved to: ../results/TCR2gene/TCR2gene_summary.csv\n" ] } ], "source": [ "# Flatten the nested results into one record per (modality, split) pair\n", "metric_names = (\"loss\", \"mse\", \"r2\")\n", "summary_rows = [\n", "    {\n", "        \"modality\": mode,\n", "        \"split\": split,\n", "        **{name: results[mode][split][name] for name in metric_names},\n", "    }\n", "    for mode in results\n", "    for split in (\"train\", \"val\", \"test\")\n", "]\n", "\n", "# Write the table to CSV, then echo it under a banner\n", "summary_df = pd.DataFrame(summary_rows)\n", "summary_path = f\"{results_dir}/{config['task']}_summary.csv\"\n", "summary_df.to_csv(summary_path, index=False)\n", "\n", "banner = \"=\" * 60\n", "print(\"\\n\" + banner)\n", "print(f\"{config['task']} Results Summary\")\n", "print(banner)\n", "print(summary_df.to_string())\n", "print(f\"\\n✓ Summary saved to: {summary_path}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 4 }