{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 4: Binding Avidity Regression\n",
    "\n",
    "Fine-tune TCRfoundation to predict TCR-antigen binding avidity (quantitative binding strength).\n",
    "\n",
    "The regression target is the per-antigen binding-count matrix `adata.obsm['binding_counts']`. One regressor is trained for each input modality (`rna_only`, `tcr_only`, `tcra_only`, `tcrb_only`, `rna_tcr`), and their R², MSE, MAE and MSLE are compared on the train/val/test splits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "import torch\n",
    "import tcrfoundation as tcrf\n",
    "from tcrfoundation.finetune.avidity import (\n",
    "    train_binding_counts_regressor,\n",
    "    build_regression_results_dataframe,\n",
    "    plot_regression_metrics_charts\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = \"../TCR_foundation_model/foundation_model_best.pt\"\n",
    "results_dir = \"../results/binding_counts\"\n",
    "num_epochs = 2  # demonstration only; set to 50 for a full training run\n",
    "batch_size = 128\n",
    "seed = 42\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch RNGs so reruns are comparable\n",
    "# (GPU kernels may still introduce small run-to-run differences).\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load Data\n",
    "\n",
    "The inputs are split across two files that describe the same cells:\n",
    "- `adata_avidity.h5ad`: per-antigen binding counts (`obsm['binding_counts']`) and the train/val/test split (`obs['set']`)\n",
    "- `speci_adata.h5ad`: TCR sequences and gene expression\n",
    "\n",
    "The binding counts are copied onto the cells of `speci_adata.h5ad` row by row, so the cell below first checks that both files list the same cells in the same order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avidity data: 60114 cells\n",
      "Specificity data: 60114 cells, 3000 genes\n",
      "\n",
      "Binding counts for 8 antigens\n",
      "Data splits: {'train': 38472, 'test': 12023, 'val': 9619}\n"
     ]
    }
   ],
   "source": [
    "# Load binding counts data\n",
    "adata_avidity = sc.read(\"../data/adata_avidity.h5ad\")\n",
    "print(f\"Avidity data: {adata_avidity.n_obs} cells\")\n",
    "\n",
    "# Load TCR and gene expression data\n",
    "adata = sc.read(\"../data/speci_adata.h5ad\")\n",
    "print(f\"Specificity data: {adata.n_obs} cells, {adata.n_vars} genes\")\n",
    "\n",
    "# Rows of obsm['binding_counts'] are paired with cells by position, so both\n",
    "# objects must hold the same cells in the same order before transferring.\n",
    "assert adata.obs_names.equals(adata_avidity.obs_names), (\n",
    "    \"Cell barcodes or cell order differ between the two files\"\n",
    ")\n",
    "\n",
    "# Transfer binding counts and splits to main adata\n",
    "adata.obsm['binding_counts'] = adata_avidity.obsm['binding_counts']\n",
    "adata.obs['set'] = adata_avidity.obs['set']\n",
    "\n",
    "# Check data\n",
    "n_antigens = adata.obsm['binding_counts'].shape[1]\n",
    "print(f\"\\nBinding counts for {n_antigens} antigens\")\n",
    "print(f\"Data splits: {adata.obs['set'].value_counts().to_dict()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Train Avidity Regressor\n",
    "\n",
    "`train_binding_counts_regressor` fine-tunes one regressor per input modality, starting from the foundation-model checkpoint. It returns `results`, which is keyed by modality and holds the metrics for each data split (e.g. `results['rna_tcr']['test']`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Training Binding Avidity Regressor ===\n",
      "Output dimension: 8\n",
      "\n",
      "==================== Training binding_counts regressor: rna_only ====================\n",
      "Loaded model with max_length: 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode rna_only Epoch 1/2: Train Loss = 2868.7943 | Val Loss = 2214.6270 | Val R² = -0.1805\n",
      "--> Best model saved with Val R² = -0.1805\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode rna_only Epoch 2/2: Train Loss = 2833.1833 | Val Loss = 2176.2131 | Val R² = -0.1682\n",
      "--> Best model saved with Val R² = -0.1682\n",
      "\n",
      "==================== Training binding_counts regressor: tcr_only ====================\n",
      "Loaded model with max_length: 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcr_only Epoch 1/2: Train Loss = 2810.2117 | Val Loss = 2118.3756 | Val R² = -0.1730\n",
      "--> Best model saved with Val R² = -0.1730\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcr_only Epoch 2/2: Train Loss = 2724.8559 | Val Loss = 2033.9279 | Val R² = -0.1436\n",
      "--> Best model saved with Val R² = -0.1436\n",
      "\n",
      "==================== Training binding_counts regressor: tcra_only ====================\n",
      "Loaded model with max_length: 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcra_only Epoch 1/2: Train Loss = 2841.1611 | Val Loss = 2168.9397 | Val R² = -0.1930\n",
      "--> Best model saved with Val R² = -0.1930\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcra_only Epoch 2/2: Train Loss = 2788.4668 | Val Loss = 2117.1266 | Val R² = -0.1762\n",
      "--> Best model saved with Val R² = -0.1762\n",
      "\n",
      "==================== Training binding_counts regressor: tcrb_only ====================\n",
      "Loaded model with max_length: 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcrb_only Epoch 1/2: Train Loss = 2847.3957 | Val Loss = 2179.0847 | Val R² = -0.1914\n",
      "--> Best model saved with Val R² = -0.1914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode tcrb_only Epoch 2/2: Train Loss = 2801.2477 | Val Loss = 2130.1765 | Val R² = -0.1722\n",
      "--> Best model saved with Val R² = -0.1722\n",
      "\n",
      "==================== Training binding_counts regressor: rna_tcr ====================\n",
      "Loaded model with max_length: 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode rna_tcr Epoch 1/2: Train Loss = 2800.0972 | Val Loss = 2115.3714 | Val R² = -0.1498\n",
      "--> Best model saved with Val R² = -0.1498\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode rna_tcr Epoch 2/2: Train Loss = 2710.2179 | Val Loss = 2033.8319 | Val R² = -0.1123\n",
      "--> Best model saved with Val R² = -0.1123\n",
      "\n",
      "✓ Training complete\n"
     ]
    }
   ],
   "source": [
    "print(\"\\n=== Training Binding Avidity Regressor ===\")\n",
    "\n",
    "results = train_binding_counts_regressor(\n",
    "    adata,\n",
    "    checkpoint_path=checkpoint_path,\n",
    "    num_epochs=num_epochs,\n",
    "    batch_size=batch_size,\n",
    "    return_training_history=False\n",
    ")\n",
    "\n",
    "print(\"\\n✓ Training complete\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Performance Summary\n",
    "\n",
    "Metrics for every modality and split are gathered into one table, followed by the test-set scores.\n",
    "\n",
    "**Interpreting the demo run:** after only 2 epochs every modality still has a negative R². A negative R² means the predictions fit worse than a constant mean prediction. These numbers therefore only show that the pipeline runs end to end. Train for the full number of epochs (see `num_epochs`) before comparing modalities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "============================================================\n",
      "Regression Performance by Modality\n",
      "============================================================\n",
      "         Mode  Split        R²          MSE        MAE      MSLE\n",
      "0    rna_only  train -0.164948  2807.111328  12.578028  2.224296\n",
      "1    rna_only    val -0.168183  2176.213135  12.488777  2.215775\n",
      "2    rna_only   test -0.165725  2006.111572  12.432177  2.236494\n",
      "3    tcr_only  train -0.143940  2666.379150  12.329477  2.273770\n",
      "4    tcr_only    val -0.143627  2033.927979  12.240131  2.261877\n",
      "5    tcr_only   test -0.141133  1874.746582  12.190807  2.291583\n",
      "6   tcra_only  train -0.174499  2751.299072  12.522174  2.580707\n",
      "7   tcra_only    val -0.176214  2117.126465  12.430248  2.561954\n",
      "8   tcra_only   test -0.174020  1952.712524  12.381844  2.603213\n",
      "9   tcrb_only  train -0.171604  2762.735596  12.452492  2.475381\n",
      "10  tcrb_only    val -0.172159  2130.176514  12.357960  2.460135\n",
      "11  tcrb_only   test -0.169991  1964.404053  12.308035  2.493986\n",
      "12    rna_tcr  train -0.111450  2665.060791  12.147854  1.886841\n",
      "13    rna_tcr    val -0.112304  2033.831909  12.066460  1.887167\n",
      "14    rna_tcr   test -0.110094  1873.524170  12.010181  1.897295\n",
      "\n",
      "============================================================\n",
      "Test Set Performance\n",
      "============================================================\n",
      "rna_only       : R²=-0.166, MSE=2006.1116, MAE=12.4322\n",
      "tcr_only       : R²=-0.141, MSE=1874.7466, MAE=12.1908\n",
      "tcra_only      : R²=-0.174, MSE=1952.7125, MAE=12.3818\n",
      "tcrb_only      : R²=-0.170, MSE=1964.4041, MAE=12.3080\n",
      "rna_tcr        : R²=-0.110, MSE=1873.5242, MAE=12.0102\n"
     ]
    }
   ],
   "source": [
    "# Tabulate the per-split metrics for every modality.\n",
    "# TODO(review): assumes the helper takes the `results` dict returned by training; confirm its signature.\n",
    "metrics_df = build_regression_results_dataframe(results)\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Regression Performance by Modality\")\n",
    "print(\"=\"*60)\n",
    "print(metrics_df)\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Test Set Performance\")\n",
    "print(\"=\"*60)\n",
    "for mode, split_metrics in results.items():\n",
    "    test_metrics = split_metrics['test']\n",
    "    r2 = test_metrics['avg_r2']\n",
    "    mse = test_metrics['avg_mse']\n",
    "    mae = test_metrics['avg_mae']\n",
    "    print(f\"{mode:15s}: R²={r2:.3f}, MSE={mse:.4f}, MAE={mae:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}