Tutorial 4: Binding Avidity Regression

Fine-tune TCRfoundation to predict TCR-antigen binding avidity, a quantitative measure of binding strength, from per-antigen binding counts.

1. Setup

import warnings
warnings.filterwarnings('ignore')

import os
import scanpy as sc
import tcrfoundation as tcrf
from tcrfoundation.finetune.avidity import (
    train_binding_counts_regressor,
    build_regression_results_dataframe,
    plot_regression_metrics_charts
)

2. Configuration

checkpoint_path = "../TCR_foundation_model/foundation_model_best.pt"
results_dir = "../results/binding_counts"
num_epochs = 2  # For demonstration only; set the number of epochs to 50 for a full training run.
batch_size = 128

os.makedirs(results_dir, exist_ok=True)
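A mistyped checkpoint path would otherwise only surface once training starts, so it can help to validate the settings first. The following is a minimal sketch: `check_config` is a hypothetical helper written for this tutorial, not part of tcrfoundation, and the demonstration uses a temporary stand-in file rather than the real checkpoint.

```python
import os
import tempfile

def check_config(checkpoint_path, results_dir, num_epochs, batch_size):
    """Hypothetical helper: validate the tutorial's settings before training."""
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if num_epochs < 1 or batch_size < 1:
        raise ValueError("num_epochs and batch_size must be positive integers")
    # Create the output directory only after the inputs look sane
    os.makedirs(results_dir, exist_ok=True)
    return True

# Demonstration with an empty temporary file standing in for the checkpoint
with tempfile.TemporaryDirectory() as tmp:
    fake_ckpt = os.path.join(tmp, "foundation_model_best.pt")
    open(fake_ckpt, "wb").close()
    config_ok = check_config(
        fake_ckpt, os.path.join(tmp, "results"), num_epochs=2, batch_size=128
    )
    print(config_ok)
```

In the tutorial itself the call would be `check_config(checkpoint_path, results_dir, num_epochs, batch_size)`.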

3. Load Data

We need to combine data from two files:

  • adata_avidity.h5ad: Contains binding counts

  • speci_adata.h5ad: Contains TCR sequences and gene expression

# Load binding counts data
adata_avidity = sc.read("../data/adata_avidity.h5ad")
print(f"Avidity data: {adata_avidity.n_obs} cells")

# Load TCR and gene expression data
adata = sc.read("../data/speci_adata.h5ad")
print(f"Specificity data: {adata.n_obs} cells, {adata.n_vars} genes")

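# Transfer binding counts and splits to main adata.
# The binding_counts matrix is copied row by row, so both files must list the same cells in the same order.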
adata.obsm['binding_counts'] = adata_avidity.obsm['binding_counts']
adata.obs['set'] = adata_avidity.obs['set']

# Check data
n_antigens = adata.obsm['binding_counts'].shape[1]
print(f"\nBinding counts for {n_antigens} antigens")
print(f"Data splits: {adata.obs['set'].value_counts().to_dict()}")

Avidity data: 60114 cells
Specificity data: 60114 cells, 3000 genes

Binding counts for 8 antigens
Data splits: {'train': 38472, 'test': 12023, 'val': 9619}
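Both files report 60114 cells, but equal counts do not guarantee equal ordering. Copying a matrix between objects is only safe when row i refers to the same cell in both. The sketch below shows one way to reorder a matrix by cell name before copying it. `align_rows` is a hypothetical helper, and the cell names and counts are toy values rather than data from the tutorial files.

```python
import numpy as np
import pandas as pd

def align_rows(target_names, source_names, source_matrix):
    """Hypothetical helper: reorder source_matrix rows to follow target_names."""
    target_names = pd.Index(target_names)
    source_names = pd.Index(source_names)
    if not source_names.is_unique:
        raise ValueError("source cell names must be unique")
    # Position of every target cell inside the source (-1 if absent)
    positions = source_names.get_indexer(target_names)
    if (positions < 0).any():
        missing = target_names[positions < 0]
        raise KeyError(f"{len(missing)} cells missing from source, e.g. {missing[0]}")
    return np.asarray(source_matrix)[positions]

# Toy example: the source lists the same three cells in a different order
target = ["cell_a", "cell_b", "cell_c"]
source = ["cell_c", "cell_a", "cell_b"]
counts = np.array([[3, 0], [1, 5], [2, 2]])  # rows follow `source`

aligned = align_rows(target, source, counts)
print(aligned)
```

With the tutorial's objects, the equivalent call would be `align_rows(adata.obs_names, adata_avidity.obs_names, adata_avidity.obsm['binding_counts'])`. If the two files already share the same order, the result is identical to the direct assignment.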

4. Train Avidity Regressor

print("\n=== Training Binding Avidity Regressor ===")

results = train_binding_counts_regressor(
    adata,
    checkpoint_path=checkpoint_path,
    num_epochs=num_epochs,
    batch_size=batch_size,
    return_training_history=False
)

print("\n✓ Training complete")

=== Training Binding Avidity Regressor ===
Output dimension: 8

==================== Training binding_counts regressor: rna_only ====================
Loaded model with max_length: 30
Mode rna_only Epoch 1/2: Train Loss = 2868.7943 | Val Loss = 2214.6270 | Val R² = -0.1805
--> Best model saved with Val R² = -0.1805
Mode rna_only Epoch 2/2: Train Loss = 2833.1833 | Val Loss = 2176.2131 | Val R² = -0.1682
--> Best model saved with Val R² = -0.1682

==================== Training binding_counts regressor: tcr_only ====================
Loaded model with max_length: 30
Mode tcr_only Epoch 1/2: Train Loss = 2810.2117 | Val Loss = 2118.3756 | Val R² = -0.1730
--> Best model saved with Val R² = -0.1730
Mode tcr_only Epoch 2/2: Train Loss = 2724.8559 | Val Loss = 2033.9279 | Val R² = -0.1436
--> Best model saved with Val R² = -0.1436

==================== Training binding_counts regressor: tcra_only ====================
Loaded model with max_length: 30
Mode tcra_only Epoch 1/2: Train Loss = 2841.1611 | Val Loss = 2168.9397 | Val R² = -0.1930
--> Best model saved with Val R² = -0.1930
Mode tcra_only Epoch 2/2: Train Loss = 2788.4668 | Val Loss = 2117.1266 | Val R² = -0.1762
--> Best model saved with Val R² = -0.1762

==================== Training binding_counts regressor: tcrb_only ====================
Loaded model with max_length: 30
Mode tcrb_only Epoch 1/2: Train Loss = 2847.3957 | Val Loss = 2179.0847 | Val R² = -0.1914
--> Best model saved with Val R² = -0.1914
Mode tcrb_only Epoch 2/2: Train Loss = 2801.2477 | Val Loss = 2130.1765 | Val R² = -0.1722
--> Best model saved with Val R² = -0.1722

==================== Training binding_counts regressor: rna_tcr ====================
Loaded model with max_length: 30
Mode rna_tcr Epoch 1/2: Train Loss = 2800.0972 | Val Loss = 2115.3714 | Val R² = -0.1498
--> Best model saved with Val R² = -0.1498
Mode rna_tcr Epoch 2/2: Train Loss = 2710.2179 | Val Loss = 2033.8319 | Val R² = -0.1123
--> Best model saved with Val R² = -0.1123

✓ Training complete
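Every validation R² in the log is negative. An R² below zero means the predictions have a larger squared error than always predicting the per-antigen mean, which is plausible after only two demonstration epochs. The NumPy sketch below illustrates this together with the other metrics reported in this tutorial (MSE, MAE, MSLE). How tcrfoundation computes them internally is not shown here, so the per-antigen averaging and the clipping before the logarithm are assumptions of this sketch.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Metrics averaged over antigens (the averaging scheme is an assumption)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    # R² per antigen: 1 - residual sum of squares / total sum of squares
    ss_res = (err ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)  # must be non-zero
    # Clip negative predictions so log1p is defined (a choice made for this sketch)
    log_err = np.log1p(np.clip(y_pred, 0, None)) - np.log1p(y_true)
    return {
        "avg_r2": float((1.0 - ss_res / ss_tot).mean()),
        "avg_mse": float((err ** 2).mean()),
        "avg_mae": float(np.abs(err).mean()),
        "avg_msle": float((log_err ** 2).mean()),
    }

# Toy counts: 3 cells x 2 antigens
y_true = np.array([[0, 10], [2, 20], [4, 30]])

# Predicting the per-antigen mean gives R² = 0 by construction
mean_pred = np.tile(y_true.mean(axis=0), (y_true.shape[0], 1))
baseline = regression_metrics(y_true, mean_pred)

# Predicting zero everywhere is worse than the mean, so R² is negative
zero = regression_metrics(y_true, np.zeros_like(y_true))

print(baseline["avg_r2"], zero["avg_r2"])
```

The mean-prediction baseline is a useful reference point: a fine-tuned model is only informative once its R² rises above zero.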

5. Performance Summary

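# Build the summary table printed below.
# Assumption: the helper accepts the results dict returned by train_binding_counts_regressor.
df = build_regression_results_dataframe(results)

print("\n" + "="*60)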
print("Regression Performance by Modality")
print("="*60)
print(df)

print("\n" + "="*60)
print("Test Set Performance")
print("="*60)
for mode in results.keys():
    test_metrics = results[mode]['test']
    r2 = test_metrics['avg_r2']
    mse = test_metrics['avg_mse']
    mae = test_metrics['avg_mae']
    print(f"{mode:15s}: R²={r2:.3f}, MSE={mse:.4f}, MAE={mae:.4f}")

============================================================
Regression Performance by Modality
============================================================
         Mode  Split        R²          MSE        MAE      MSLE
0    rna_only  train -0.164948  2807.111328  12.578028  2.224296
1    rna_only    val -0.168183  2176.213135  12.488777  2.215775
2    rna_only   test -0.165725  2006.111572  12.432177  2.236494
3    tcr_only  train -0.143940  2666.379150  12.329477  2.273770
4    tcr_only    val -0.143627  2033.927979  12.240131  2.261877
5    tcr_only   test -0.141133  1874.746582  12.190807  2.291583
6   tcra_only  train -0.174499  2751.299072  12.522174  2.580707
7   tcra_only    val -0.176214  2117.126465  12.430248  2.561954
8   tcra_only   test -0.174020  1952.712524  12.381844  2.603213
9   tcrb_only  train -0.171604  2762.735596  12.452492  2.475381
10  tcrb_only    val -0.172159  2130.176514  12.357960  2.460135
11  tcrb_only   test -0.169991  1964.404053  12.308035  2.493986
12    rna_tcr  train -0.111450  2665.060791  12.147854  1.886841
13    rna_tcr    val -0.112304  2033.831909  12.066460  1.887167
14    rna_tcr   test -0.110094  1873.524170  12.010181  1.897295

============================================================
Test Set Performance
============================================================
rna_only       : R²=-0.166, MSE=2006.1116, MAE=12.4322
tcr_only       : R²=-0.141, MSE=1874.7466, MAE=12.1908
tcra_only      : R²=-0.174, MSE=1952.7125, MAE=12.3818
tcrb_only      : R²=-0.170, MSE=1964.4041, MAE=12.3080
rna_tcr        : R²=-0.110, MSE=1873.5242, MAE=12.0102
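In this two-epoch run the combined `rna_tcr` mode has the highest R² on every split, although all values remain negative. When comparing modalities after a full run, it is common practice to rank them on the validation split and keep the test split for the final report. The sketch below ranks the modes using the validation R² values copied from the table above. `results_demo` is a hand-built stand-in, and the assumption that the real `results` dict exposes a `'val'` entry shaped like `'test'` is not confirmed by this tutorial.

```python
# Validation R² values copied from the summary table.
# Assumption: results[mode]['val'] mirrors results[mode]['test'] in the real dict.
results_demo = {
    "rna_only": {"val": {"avg_r2": -0.168183}},
    "tcr_only": {"val": {"avg_r2": -0.143627}},
    "tcra_only": {"val": {"avg_r2": -0.176214}},
    "tcrb_only": {"val": {"avg_r2": -0.172159}},
    "rna_tcr": {"val": {"avg_r2": -0.112304}},
}

# Sort modes from highest to lowest validation R²
ranked = sorted(
    results_demo,
    key=lambda mode: results_demo[mode]["val"]["avg_r2"],
    reverse=True,
)
best_mode = ranked[0]

for mode in ranked:
    print(f"{mode:10s} val R² = {results_demo[mode]['val']['avg_r2']:.3f}")
print(f"Best mode on validation: {best_mode}")
```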