Tutorial 3: Antigen Specificity Prediction

Fine-tune TCRfoundation to predict antigen specificity from TCR sequences and gene expression.

1. Setup

import warnings
warnings.filterwarnings('ignore')

import os
import scanpy as sc
import tcrfoundation as tcrf
from tcrfoundation.finetune.utils import build_results_dataframe, plot_all_metrics_bubble_chart

2. Configuration

checkpoint_path = "../TCR_foundation_model/foundation_model_best.pt"
results_dir = "../results/specificity"
num_epochs = 2  # For demonstration only. For a real training run, set the number of epochs to 50.
batch_size = 512

os.makedirs(results_dir, exist_ok=True)
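Since training depends on the pretrained checkpoint, it may be worth failing early if the file is missing. The following is a minimal sketch using only the standard library; `require_file` is a hypothetical helper (not part of tcrfoundation), and the demo uses a temporary stand-in file rather than the real checkpoint.

```python
import os
import tempfile

def require_file(path):
    """Raise early if `path` is not an existing file; return its size in bytes.

    Hypothetical helper for illustration, not a tcrfoundation function.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return os.path.getsize(path)

# Demo with a temporary stand-in file instead of the real checkpoint.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_checkpoint.pt")
    with open(demo_path, "wb") as fh:
        fh.write(b"\x00" * 16)
    size = require_file(demo_path)
    print(f"Found checkpoint stand-in ({size} bytes)")

    # A missing file is reported before any training starts.
    try:
        require_file(os.path.join(tmp, "missing.pt"))
    except FileNotFoundError as err:
        missing_message = str(err)
        print(missing_message)
```

In the notebook itself, the same check would be called as `require_file(checkpoint_path)` right after the configuration cell.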

3. Load Data

adata = sc.read("../data/speci_adata.h5ad")
print(f"Dataset: {adata.n_obs} cells, {adata.n_vars} genes")
print(f"Antigens: {adata.obs['binding_name'].nunique()} unique")
Dataset: 60114 cells, 3000 genes
Antigens: 8 unique
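Before training, it can help to see how cells are distributed across the antigen classes. On the real object this would typically be `adata.obs['binding_name'].value_counts()`. The sketch below shows the same idea with the standard library on a small made-up label list; the labels and counts are illustrative and are not taken from the dataset.

```python
from collections import Counter

# Hypothetical toy labels standing in for adata.obs['binding_name'];
# the real per-class counts are not shown in this tutorial.
labels = ["Flu-MP"] * 6 + ["BMLF1"] * 3 + ["MART-1"] * 1

def class_fractions(labels):
    """Return {label: fraction of cells}, ordered from most to least frequent."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {name: n / total for name, n in counts.most_common()}

fractions = class_fractions(labels)
for name, frac in fractions.items():
    print(f"{name:10s} {frac:.1%}")
```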

4. Train Specificity Classifier

results, adata_updated = tcrf.finetune.specificity.train_specificity_classifier(
    adata, 
    label_column="binding_name",
    checkpoint_path=checkpoint_path,
    num_epochs=num_epochs, 
    batch_size=batch_size
)
Number of classes: 8, Classes: ['A0201_ELAGIGILTV_MART-1_Cancer_binder', 'A0201_GILGFVFTL_Flu-MP_Influenza_binder', 'A0201_GLCTLVAML_BMLF1_EBV_binder', 'A0301_KLGGALQAK_IE-1_CMV_binder', 'A0301_RLRAEAQVK_EMNA-3A_EBV_binder', 'A1101_AVFDRKSDAK_EBNA-3B_EBV_binder', 'A1101_IVTDFSVIK_EBNA-3B_EBV_binder', 'B0801_RAKFKQLL_BZLF1_EBV_binder']
Loaded model with max_length: 30

==================== Training binding_name classifier: rna_only ====================
Loaded model with max_length: 30
Mode rna_only Epoch 1/2: Train Loss = 0.8529 | Val Loss = 0.6481 | Val Acc = 0.8000
--> Best model saved with Val Acc = 0.8000
Mode rna_only Epoch 2/2: Train Loss = 0.6360 | Val Loss = 0.6245 | Val Acc = 0.8049
--> Best model saved with Val Acc = 0.8049

==================== Training binding_name classifier: tcr_only ====================
Loaded model with max_length: 30
Mode tcr_only Epoch 1/2: Train Loss = 0.6180 | Val Loss = 0.3954 | Val Acc = 0.8919
--> Best model saved with Val Acc = 0.8919
Mode tcr_only Epoch 2/2: Train Loss = 0.3810 | Val Loss = 0.3347 | Val Acc = 0.9160
--> Best model saved with Val Acc = 0.9160

==================== Training binding_name classifier: tcra_only ====================
Loaded model with max_length: 30
Mode tcra_only Epoch 1/2: Train Loss = 0.7298 | Val Loss = 0.5022 | Val Acc = 0.8660
--> Best model saved with Val Acc = 0.8660
Mode tcra_only Epoch 2/2: Train Loss = 0.4746 | Val Loss = 0.4163 | Val Acc = 0.8968
--> Best model saved with Val Acc = 0.8968

==================== Training binding_name classifier: tcrb_only ====================
Loaded model with max_length: 30
Mode tcrb_only Epoch 1/2: Train Loss = 0.8064 | Val Loss = 0.5314 | Val Acc = 0.8597
--> Best model saved with Val Acc = 0.8597
Mode tcrb_only Epoch 2/2: Train Loss = 0.5356 | Val Loss = 0.4554 | Val Acc = 0.8839
--> Best model saved with Val Acc = 0.8839

==================== Training binding_name classifier: rna_tcr ====================
Loaded model with max_length: 30
Mode rna_tcr Epoch 1/2: Train Loss = 0.4879 | Val Loss = 0.2819 | Val Acc = 0.9296
--> Best model saved with Val Acc = 0.9296
Mode rna_tcr Epoch 2/2: Train Loss = 0.2722 | Val Loss = 0.2525 | Val Acc = 0.9357
--> Best model saved with Val Acc = 0.9357

=== Extracting embeddings and predictions for binding_name ===
Loaded model with max_length: 30

Extracting embeddings and predictions for rna_only modality...
Stored embedding in adata.obsm['X_binding_name_rna_only_emb'] with shape (60114, 128)
Stored predictions in adata.obs['binding_name_rna_only_pred']

Extracting embeddings and predictions for tcr_only modality...
Stored embedding in adata.obsm['X_binding_name_tcr_only_emb'] with shape (60114, 256)
Stored predictions in adata.obs['binding_name_tcr_only_pred']

Extracting embeddings and predictions for tcra_only modality...
Stored embedding in adata.obsm['X_binding_name_tcra_only_emb'] with shape (60114, 128)
Stored predictions in adata.obs['binding_name_tcra_only_pred']

Extracting embeddings and predictions for tcrb_only modality...
Stored embedding in adata.obsm['X_binding_name_tcrb_only_emb'] with shape (60114, 128)
Stored predictions in adata.obs['binding_name_tcrb_only_pred']

Extracting embeddings and predictions for rna_tcr modality...
Stored embedding in adata.obsm['X_binding_name_rna_tcr_emb'] with shape (60114, 256)
Stored predictions in adata.obs['binding_name_rna_tcr_pred']
Processing rna_only: 100%|██████████████████████████████████| 59/59 [00:02<00:00, 20.75it/s]
Processing tcr_only: 100%|██████████████████████████████████| 59/59 [00:02<00:00, 20.25it/s]
Processing tcra_only: 100%|█████████████████████████████████| 59/59 [00:02<00:00, 20.24it/s]
Processing tcrb_only: 100%|█████████████████████████████████| 59/59 [00:02<00:00, 20.24it/s]
Processing rna_tcr: 100%|███████████████████████████████████| 59/59 [00:02<00:00, 19.81it/s]
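The log above shows that predictions are written to columns such as `adata.obs['binding_name_rna_tcr_pred']`. One way to inspect them is to tabulate true labels against predicted labels. The sketch below does this with the standard library on made-up lists. It assumes the prediction column holds values directly comparable to `binding_name`, which this tutorial does not confirm. With pandas, `pd.crosstab` on the two columns would give a similar table.

```python
from collections import Counter

# Hypothetical stand-ins for adata_updated.obs['binding_name'] and
# adata_updated.obs['binding_name_rna_tcr_pred'] (values are made up).
y_true = ["Flu-MP", "Flu-MP", "Flu-MP", "BMLF1", "BMLF1", "MART-1"]
y_pred = ["Flu-MP", "Flu-MP", "BMLF1", "BMLF1", "Flu-MP", "Flu-MP"]

# Count each (true label, predicted label) pair.
confusion = Counter(zip(y_true, y_pred))

def recall_per_class(y_true, y_pred):
    """Fraction of cells of each true class that were predicted correctly."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: correct[label] / totals[label] for label in totals}

recalls = recall_per_class(y_true, y_pred)
for (true_label, pred_label), n in sorted(confusion.items()):
    print(f"true={true_label:7s} pred={pred_label:7s} n={n}")
for label, value in recalls.items():
    print(f"recall[{label}] = {value:.2f}")
```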

5. Save Results

# Save adata
adata_updated.write(f"{results_dir}/adata_binding_name.h5ad")

# Save metrics
df = build_results_dataframe(results)
df.to_csv(f"{results_dir}/binding_name_results.csv", index=False)

# Create visualization
plot_all_metrics_bubble_chart(df, results_dir, "binding_name")

print("\nResults saved to:", results_dir)
[Figure: bubble chart of all metrics for binding_name]
Results saved to: ../results/specificity
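Once the metrics are on disk, the CSV can be read back to rank modalities. The sketch below is a minimal illustration with the standard library. It assumes the file has the same columns as the table printed in the Performance Summary section, and it inlines a few of those rows as text so that it runs without the saved file.

```python
import csv
import io

# Inline stand-in for binding_name_results.csv; the values are copied from
# the Performance Summary table in this tutorial (test rows plus one val row).
csv_text = """Mode,Split,Accuracy,F1 Weighted,F1 Macro
rna_only,test,0.807053,0.777181,0.430111
tcr_only,test,0.914837,0.904417,0.653300
tcra_only,test,0.896984,0.887002,0.637773
tcrb_only,test,0.885451,0.875628,0.630690
rna_tcr,val,0.935677,0.926523,0.692007
rna_tcr,test,0.935462,0.925458,0.694275
"""

# Keep only the test split, then sort by macro F1 from best to worst.
rows = [r for r in csv.DictReader(io.StringIO(csv_text)) if r["Split"] == "test"]
ranked = sorted(rows, key=lambda r: float(r["F1 Macro"]), reverse=True)

for r in ranked:
    print(f"{r['Mode']:10s} F1-macro={float(r['F1 Macro']):.3f}")
```

To use the saved file instead, replace `io.StringIO(csv_text)` with an open handle to `binding_name_results.csv` in `results_dir`.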

6. Performance Summary

print("\n" + "="*60)
print("Performance by Modality")
print("="*60)
print(df)

print("\n" + "="*60)
for mode, metrics in results.items():
    # metrics['test'] unpacks as (accuracy, macro F1, weighted F1)
    test_acc, test_f1_macro, test_f1_weighted = metrics['test']
    print(f"{mode:15s}: Acc={test_acc:.3f}, F1-macro={test_f1_macro:.3f}")
============================================================
Performance by Modality
============================================================
         Mode  Split  Accuracy  F1 Weighted  F1 Macro
0    rna_only  train  0.810856     0.782192  0.440004
1    rna_only    val  0.804924     0.776947  0.436638
2    rna_only   test  0.807053     0.777181  0.430111
3    tcr_only  train  0.918392     0.907944  0.659940
4    tcr_only    val  0.916047     0.906349  0.659977
5    tcr_only   test  0.914837     0.904417  0.653300
6   tcra_only  train  0.897431     0.887603  0.638441
7   tcra_only    val  0.896751     0.887648  0.639294
8   tcra_only   test  0.896984     0.887002  0.637773
9   tcrb_only  train  0.885477     0.876032  0.632752
10  tcrb_only    val  0.883886     0.875196  0.633631
11  tcrb_only   test  0.885451     0.875628  0.630690
12    rna_tcr  train  0.938734     0.929175  0.723491
13    rna_tcr    val  0.935677     0.926523  0.692007
14    rna_tcr   test  0.935462     0.925458  0.694275

============================================================
rna_only       : Acc=0.807, F1-macro=0.430
tcr_only       : Acc=0.915, F1-macro=0.653
tcra_only      : Acc=0.897, F1-macro=0.638
tcrb_only      : Acc=0.885, F1-macro=0.631
rna_tcr        : Acc=0.935, F1-macro=0.694
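In the table above, macro F1 is noticeably lower than accuracy and weighted F1 for every modality. A gap of this kind is what you would see if some classes are rare and poorly predicted, although the tutorial does not show per-class results, so this is only a possible explanation. The sketch below computes the three metrics from scratch on a made-up imbalanced example to show how they can diverge. The labels are hypothetical, and the function is an illustration, not the metric code used by tcrfoundation.

```python
def f1_report(y_true, y_pred):
    """Return (accuracy, macro F1, weighted F1) for two equal-length label lists."""
    labels = sorted(set(y_true) | set(y_pred))
    n = len(y_true)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n

    f1, support = {}, {}
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1[c] = 2 * tp / denom if denom else 0.0
        support[c] = tp + fn  # number of true examples of class c

    macro = sum(f1.values()) / len(labels)                  # every class counts equally
    weighted = sum(f1[c] * support[c] for c in labels) / n  # classes weighted by size
    return accuracy, macro, weighted

# Toy data: 9 cells of a common class, 1 of a rare class; the classifier
# always predicts the common class.
y_true = ["common"] * 9 + ["rare"]
y_pred = ["common"] * 10

acc, macro_f1, weighted_f1 = f1_report(y_true, y_pred)
print(f"Acc={acc:.3f}, F1-macro={macro_f1:.3f}, F1-weighted={weighted_f1:.3f}")
```

Here accuracy and weighted F1 stay high because the common class dominates, while macro F1 is pulled down by the rare class that is never predicted.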