Tutorial 4: Binding Avidity Regression
Fine-tune TCRfoundation to predict TCR-antigen binding avidity, a quantitative measure of binding strength.
1. Setup
import warnings
warnings.filterwarnings('ignore')
import os
import scanpy as sc
import tcrfoundation as tcrf
from tcrfoundation.finetune.avidity import (
train_binding_counts_regressor,
build_regression_results_dataframe,
plot_regression_metrics_charts
)
2. Configuration
checkpoint_path = "../TCR_foundation_model/foundation_model_best.pt"
results_dir = "../results/binding_counts"
num_epochs = 2  # Demonstration only. For a real training run, set the number of epochs to 50.
batch_size = 128
os.makedirs(results_dir, exist_ok=True)
3. Load Data
We need to combine data from two files:
- adata_avidity.h5ad: contains binding counts
- speci_adata.h5ad: contains TCR sequences and gene expression
# Load binding counts data
adata_avidity = sc.read("../data/adata_avidity.h5ad")
print(f"Avidity data: {adata_avidity.n_obs} cells")
# Load TCR and gene expression data
adata = sc.read("../data/speci_adata.h5ad")
print(f"Specificity data: {adata.n_obs} cells, {adata.n_vars} genes")
# Transfer binding counts and splits to the main adata.
# Note: this assumes both files hold the same cells in the same order.
adata.obsm['binding_counts'] = adata_avidity.obsm['binding_counts']
adata.obs['set'] = adata_avidity.obs['set']
# Check data
n_antigens = adata.obsm['binding_counts'].shape[1]
print(f"\nBinding counts for {n_antigens} antigens")
print(f"Data splits: {adata.obs['set'].value_counts().to_dict()}")
Avidity data: 60114 cells
Specificity data: 60114 cells, 3000 genes
Binding counts for 8 antigens
Data splits: {'train': 38472, 'test': 12023, 'val': 9619}
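The transfer above copies `binding_counts` from one object to the other, which is only safe if both files list the same cells in the same order. A minimal sketch of a pre-flight check follows. `check_same_cells` is a hypothetical helper, not part of TCRfoundation, and it accepts any two sequences of cell barcodes.

```python
def check_same_cells(names_a, names_b):
    """Raise ValueError unless both sequences hold identical names in identical order."""
    names_a, names_b = list(names_a), list(names_b)
    if len(names_a) != len(names_b):
        raise ValueError(f"Cell count mismatch: {len(names_a)} vs {len(names_b)}")
    for i, (a, b) in enumerate(zip(names_a, names_b)):
        if a != b:
            raise ValueError(f"Order mismatch at position {i}: {a!r} vs {b!r}")
    return True

# Toy barcodes standing in for adata.obs_names and adata_avidity.obs_names
same = check_same_cells(["AAAC-1", "AAAG-1"], ["AAAC-1", "AAAG-1"])
```

With the objects loaded above, the call would be `check_same_cells(adata.obs_names, adata_avidity.obs_names)`, placed before the `obsm` assignment. If it fails, one option is to reorder with `adata_avidity[adata.obs_names]`, assuming every barcode is present in both files.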
4. Train Avidity Regressor
print("\n=== Training Binding Avidity Regressor ===")
results = train_binding_counts_regressor(
adata,
checkpoint_path=checkpoint_path,
num_epochs=num_epochs,
batch_size=batch_size,
return_training_history=False
)
print("\n✓ Training complete")
=== Training Binding Avidity Regressor ===
Output dimension: 8
==================== Training binding_counts regressor: rna_only ====================
Loaded model with max_length: 30
Mode rna_only Epoch 1/2: Train Loss = 2868.7943 | Val Loss = 2214.6270 | Val R² = -0.1805
--> Best model saved with Val R² = -0.1805
Mode rna_only Epoch 2/2: Train Loss = 2833.1833 | Val Loss = 2176.2131 | Val R² = -0.1682
--> Best model saved with Val R² = -0.1682
==================== Training binding_counts regressor: tcr_only ====================
Loaded model with max_length: 30
Mode tcr_only Epoch 1/2: Train Loss = 2810.2117 | Val Loss = 2118.3756 | Val R² = -0.1730
--> Best model saved with Val R² = -0.1730
Mode tcr_only Epoch 2/2: Train Loss = 2724.8559 | Val Loss = 2033.9279 | Val R² = -0.1436
--> Best model saved with Val R² = -0.1436
==================== Training binding_counts regressor: tcra_only ====================
Loaded model with max_length: 30
Mode tcra_only Epoch 1/2: Train Loss = 2841.1611 | Val Loss = 2168.9397 | Val R² = -0.1930
--> Best model saved with Val R² = -0.1930
Mode tcra_only Epoch 2/2: Train Loss = 2788.4668 | Val Loss = 2117.1266 | Val R² = -0.1762
--> Best model saved with Val R² = -0.1762
==================== Training binding_counts regressor: tcrb_only ====================
Loaded model with max_length: 30
Mode tcrb_only Epoch 1/2: Train Loss = 2847.3957 | Val Loss = 2179.0847 | Val R² = -0.1914
--> Best model saved with Val R² = -0.1914
Mode tcrb_only Epoch 2/2: Train Loss = 2801.2477 | Val Loss = 2130.1765 | Val R² = -0.1722
--> Best model saved with Val R² = -0.1722
==================== Training binding_counts regressor: rna_tcr ====================
Loaded model with max_length: 30
Mode rna_tcr Epoch 1/2: Train Loss = 2800.0972 | Val Loss = 2115.3714 | Val R² = -0.1498
--> Best model saved with Val R² = -0.1498
Mode rna_tcr Epoch 2/2: Train Loss = 2710.2179 | Val Loss = 2033.8319 | Val R² = -0.1123
--> Best model saved with Val R² = -0.1123
✓ Training complete
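The log reports a negative validation R² for every mode after two epochs. Under the standard definition, R² = 1 - SS_res / SS_tot, a negative value means the predictions have a larger squared error than always predicting the mean of the targets, which is plausible this early in training. The sketch below implements the textbook forms of R², MSE, MAE and MSLE in plain Python for a single antigen. It illustrates the definitions only: it is not TCRfoundation's code, and how the library averages over the 8 antigens is not shown here.

```python
import math

def mse(y_true, y_pred):
    # Mean squared error
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mae(y_true, y_pred):
    # Mean absolute error
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def msle(y_true, y_pred):
    # Mean squared log error; log1p keeps zero counts finite (inputs must be >= 0)
    return sum((math.log1p(t) - math.log1p(p)) ** 2
               for t, p in zip(y_true, y_pred)) / len(y_true)

def r2(y_true, y_pred):
    # Coefficient of determination: 1 - SS_res / SS_tot
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Toy binding counts for one antigen (made-up numbers for illustration)
y_true = [0.0, 2.0, 4.0, 10.0]
mean_baseline = [4.0] * 4   # always predict the mean of y_true
all_zero = [0.0] * 4        # an under-trained model that predicts 0 everywhere

r2_baseline = r2(y_true, mean_baseline)  # exactly the mean predictor, so R² = 0
r2_zero = r2(y_true, all_zero)           # worse than the mean predictor, so R² < 0
```

The mean predictor scores R² = 0 by construction, so any model with R² below zero has not yet beaten that baseline.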
5. Performance Summary
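# Build the summary table. Passing `results` directly is an assumed call signature.
df = build_regression_results_dataframe(results)
print("\n" + "="*60)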
print("Regression Performance by Modality")
print("="*60)
print(df)
print("\n" + "="*60)
print("Test Set Performance")
print("="*60)
for mode in results.keys():
    test_metrics = results[mode]['test']
    r2 = test_metrics['avg_r2']
    mse = test_metrics['avg_mse']
    mae = test_metrics['avg_mae']
    print(f"{mode:15s}: R²={r2:.3f}, MSE={mse:.4f}, MAE={mae:.4f}")
============================================================
Regression Performance by Modality
============================================================
         Mode  Split        R²          MSE        MAE      MSLE
0    rna_only  train -0.164948  2807.111328  12.578028  2.224296
1    rna_only    val -0.168183  2176.213135  12.488777  2.215775
2    rna_only   test -0.165725  2006.111572  12.432177  2.236494
3    tcr_only  train -0.143940  2666.379150  12.329477  2.273770
4    tcr_only    val -0.143627  2033.927979  12.240131  2.261877
5    tcr_only   test -0.141133  1874.746582  12.190807  2.291583
6   tcra_only  train -0.174499  2751.299072  12.522174  2.580707
7   tcra_only    val -0.176214  2117.126465  12.430248  2.561954
8   tcra_only   test -0.174020  1952.712524  12.381844  2.603213
9   tcrb_only  train -0.171604  2762.735596  12.452492  2.475381
10  tcrb_only    val -0.172159  2130.176514  12.357960  2.460135
11  tcrb_only   test -0.169991  1964.404053  12.308035  2.493986
12    rna_tcr  train -0.111450  2665.060791  12.147854  1.886841
13    rna_tcr    val -0.112304  2033.831909  12.066460  1.887167
14    rna_tcr   test -0.110094  1873.524170  12.010181  1.897295
============================================================
Test Set Performance
============================================================
rna_only       : R²=-0.166, MSE=2006.1116, MAE=12.4322
tcr_only       : R²=-0.141, MSE=1874.7466, MAE=12.1908
tcra_only      : R²=-0.174, MSE=1952.7125, MAE=12.3818
tcrb_only      : R²=-0.170, MSE=1964.4041, MAE=12.3080
rna_tcr        : R²=-0.110, MSE=1873.5242, MAE=12.0102
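To compare modalities programmatically rather than by eye, the nested layout used in the loop above (`results[mode]['test']['avg_r2']`) can be sorted directly. The sketch below uses a hand-copied stand-in dictionary, `demo_results`, filled with the rounded test values printed above, so it runs without the trained models. Only the key layout is taken from the tutorial; `demo_results`, `ranked` and `best_mode` are illustrative names.

```python
# Stand-in for `results`, using the rounded test metrics printed above
demo_results = {
    "rna_only":  {"test": {"avg_r2": -0.166, "avg_mae": 12.4322}},
    "tcr_only":  {"test": {"avg_r2": -0.141, "avg_mae": 12.1908}},
    "tcra_only": {"test": {"avg_r2": -0.174, "avg_mae": 12.3818}},
    "tcrb_only": {"test": {"avg_r2": -0.170, "avg_mae": 12.3080}},
    "rna_tcr":   {"test": {"avg_r2": -0.110, "avg_mae": 12.0102}},
}

# Sort modes from highest to lowest test R²
ranked = sorted(
    demo_results,
    key=lambda mode: demo_results[mode]["test"]["avg_r2"],
    reverse=True,
)
best_mode = ranked[0]

for rank, mode in enumerate(ranked, start=1):
    metrics = demo_results[mode]["test"]
    print(f"{rank}. {mode:10s} R²={metrics['avg_r2']:.3f}  MAE={metrics['avg_mae']:.4f}")
```

On this two-epoch run the combined `rna_tcr` mode ranks first, but every R² is still negative, so the ordering should be re-checked after a full training run.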