Tutorial 5: TCR-to-Gene Expression Prediction
Fine-tune TCRfoundation for cross-modal prediction: predicting gene expression from TCR sequences.
1. Setup
import warnings
warnings.filterwarnings('ignore')
import os
import json
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import tcrfoundation as tcrf
# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)
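The seeding calls above can be gathered into one helper so that every tutorial run starts from the same state. The sketch below is illustrative only: `seed_everything` is a hypothetical name, not part of `tcrfoundation`, and the cuDNN flags are an assumption about what a reader may want (they favour determinism over speed and still do not guarantee bit-identical GPU results).

```python
import random


def seed_everything(seed: int = 0) -> None:
    """Seed the Python RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Assumption: prefer deterministic cuDNN kernels over speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Re-seeding reproduces the same random draw.
seed_everything(0)
first_draw = random.random()
seed_everything(0)
second_draw = random.random()
print(first_draw == second_draw)
```

Python's built-in `random` module is seeded as well because data-splitting or augmentation code may rely on it even when the notebook itself does not.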
2. Configuration
# Configuration
config = {
    "checkpoint_path": "../TCR_foundation_model/foundation_model_best.pt",
    "num_epochs": 2,
    "batch_size": 512,
    "modalities": None,  # Train all modalities: tcr_only, tcra_only, tcrb_only
    "val_split": 0.2,
    "test_split": 0.2,
    "save_splits": False,
    "save_predictions": False,
    "task": "TCR2gene"
}
results_dir = f"../results/{config['task']}"
os.makedirs(results_dir, exist_ok=True)
# Save configuration
with open(f"{results_dir}/config.json", 'w') as f:
    json.dump(config, f, indent=2)
print(f"Task: {config['task']}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print(f"Results directory: {results_dir}")
Task: TCR2gene
Device: cuda
Results directory: ../results/TCR2gene
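Because the configuration is written to `config.json`, it can be reloaded later to reproduce a run. The sketch below round-trips a reduced copy of the config through a temporary directory (a stand-in for `results_dir`) and shows that `None` survives as JSON `null`. Treating the training fraction as the remainder after the validation and test splits is an assumption, not something this tutorial states.

```python
import json
import os
import tempfile

# Reduced copy of the tutorial config, for illustration only.
config = {
    "task": "TCR2gene",
    "modalities": None,
    "val_split": 0.2,
    "test_split": 0.2,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    with open(path) as f:
        restored = json.load(f)

# None is written as JSON null and is read back as None.
print(restored["modalities"] is None)

# Assumption: cells not assigned to val or test are used for training.
train_fraction = 1.0 - restored["val_split"] - restored["test_split"]
print(f"Train fraction: {train_fraction:.1f}")
```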
3. Load Data
adata = sc.read("../data/adata_finetune.h5ad")
print(f"Dataset: {adata.n_obs} cells, {adata.n_vars} genes")
print(f"\nTCR sequences:")
print(f" CDR3a: {adata.obs['CDR3a'].notna().sum()} available")
print(f" CDR3b: {adata.obs['CDR3b'].notna().sum()} available")
Dataset: 444979 cells, 3000 genes
TCR sequences:
CDR3a: 444979 available
CDR3b: 444979 available
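With the dataset size known, the number of batches per epoch can be estimated from the configuration. The sketch below assumes that the validation and test sets are each taken as a truncated fraction of all cells, that the remainder is used for training, and that the last partial batch is kept. `batches_per_epoch` is a hypothetical helper, and the library's real splitting logic may differ. Under these assumptions the result matches the 522 training steps and 174 validation and test steps shown in the progress bars of section 4.

```python
import math


def batches_per_epoch(n_cells, val_split, test_split, batch_size):
    """Estimate batch counts per split (assumed splitting scheme, see text)."""
    n_val = int(n_cells * val_split)
    n_test = int(n_cells * test_split)
    n_train = n_cells - n_val - n_test
    sizes = {"train": n_train, "val": n_val, "test": n_test}
    # ceil() keeps the final, smaller batch instead of dropping it.
    return {name: math.ceil(n / batch_size) for name, n in sizes.items()}


counts = batches_per_epoch(444979, val_split=0.2, test_split=0.2, batch_size=512)
print(counts)  # {'train': 522, 'val': 174, 'test': 174}
```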
4. Train Cross-Modal Regressor
This step trains three regressors, one per input modality, each predicting gene expression from TCR sequence:
TCR α+β → Gene expression (tcr_only)
TCR α only → Gene expression (tcra_only)
TCR β only → Gene expression (tcrb_only)
print(f"\n=== Training {config['task']} Regressor ===")
results, adata_with_predictions = tcrf.finetune.cross_modal.train_regressor(
    adata,
    checkpoint_path=config["checkpoint_path"],
    num_epochs=config["num_epochs"],
    batch_size=config["batch_size"],
    modalities=config["modalities"],
    val_split=config["val_split"],
    test_split=config["test_split"],
    save_splits=config["save_splits"],
    save_predictions=config["save_predictions"]
)
print("\n✓ Training complete")
=== Training TCR2gene Regressor ===
Loaded model with max_length: 30
=== Training tcr_only regressor ===
Epoch 1/2 - Train Loss: 0.040502, Val Loss: 0.039059, Train R²: -103.7333, Val R²: -0.2391
Saved best model checkpoint (Val Loss: 0.039059)
Epoch 2/2 - Train Loss: 0.039063, Val Loss: 0.038805, Train R²: -10.9746, Val R²: -0.6951
Saved best model checkpoint (Val Loss: 0.038805)
Evaluating on test set with best model...
Test Loss: 0.038711, Test MSE: 0.038711, Test R²: -0.7039
=== Training tcra_only regressor ===
Epoch 1/2 - Train Loss: 0.038970, Val Loss: 0.038550, Train R²: -17.3715, Val R²: -1.0270
Saved best model checkpoint (Val Loss: 0.038550)
Epoch 2/2 - Train Loss: 0.038423, Val Loss: 0.038114, Train R²: -3.4055, Val R²: -1.1113
Saved best model checkpoint (Val Loss: 0.038114)
Evaluating on test set with best model...
Test Loss: 0.038035, Test MSE: 0.038034, Test R²: -1.0393
=== Training tcrb_only regressor ===
Epoch 1/2 - Train Loss: 0.038810, Val Loss: 0.038457, Train R²: -11.7012, Val R²: -0.6569
Saved best model checkpoint (Val Loss: 0.038457)
Epoch 2/2 - Train Loss: 0.038336, Val Loss: 0.038090, Train R²: -2.8365, Val R²: -1.1950
Saved best model checkpoint (Val Loss: 0.038090)
Evaluating on test set with best model...
Test Loss: 0.038028, Test MSE: 0.038028, Test R²: -1.1855
✓ Training complete
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:21<00:00, 23.75it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:03<00:00, 45.97it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:22<00:00, 23.43it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 72.20it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:03<00:00, 44.76it/s]
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:16<00:00, 31.67it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 73.45it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.17it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.99it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 74.76it/s]
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 35.80it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.57it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.16it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 71.41it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 71.89it/s]
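The log above reports a small loss alongside negative R² values. This tutorial does not state how `train_regressor` computes or aggregates R², so the following is only a toy illustration of the general definition, R² = 1 - SS_res / SS_tot. It shows that R² can be far below zero while the MSE stays small whenever the target itself varies very little. The numbers are made up for the example and are not taken from the dataset.

```python
def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy target with very little variance across four cells.
y_true = [0.10, 0.11, 0.09, 0.10]
# Predictions that are close in absolute terms but miss the mean.
y_pred = [0.20, 0.20, 0.20, 0.20]

toy_mse = mse(y_true, y_pred)
toy_r2 = r2_score(y_true, y_pred)
# Predicting the mean of the target gives an R² of about zero by construction.
baseline_r2 = r2_score(y_true, [0.10] * 4)

print(f"MSE: {toy_mse:.5f}, R²: {toy_r2:.1f}, mean-baseline R²: {baseline_r2:.1f}")
```

Here the MSE is about 0.01 while R² is about -200, because the residual sum of squares is divided by a tiny total variance. A negative R² therefore means the predictions fit worse than a constant mean prediction under whichever aggregation is used.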
5. Performance Summary
# Create summary dataframe
summary_rows = []
for mode in results:
    for split in ["train", "val", "test"]:
        metrics = results[mode][split]
        row = {
            "modality": mode,
            "split": split,
            "loss": metrics["loss"],
            "mse": metrics["mse"],
            "r2": metrics["r2"]
        }
        summary_rows.append(row)
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(f"{results_dir}/{config['task']}_summary.csv", index=False)
print("\n" + "="*60)
print(f"{config['task']} Results Summary")
print("="*60)
print(summary_df.to_string())
print(f"\n✓ Summary saved to: {results_dir}/{config['task']}_summary.csv")
============================================================
TCR2gene Results Summary
============================================================
    modality  split      loss       mse         r2
0   tcr_only  train  0.039063  0.039061 -10.974631
1   tcr_only    val  0.038805  0.038805  -0.695051
2   tcr_only   test  0.038711  0.038711  -0.703923
3  tcra_only  train  0.038423  0.038421  -3.405451
4  tcra_only    val  0.038114  0.038114  -1.111292
5  tcra_only   test  0.038035  0.038034  -1.039305
6  tcrb_only  train  0.038336  0.038335  -2.836521
7  tcrb_only    val  0.038090  0.038090  -1.195003
8  tcrb_only   test  0.038028  0.038028  -1.185543
✓ Summary saved to: ../results/TCR2gene/TCR2gene_summary.csv
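To compare modalities, the test-split rows can be ranked by each metric. The sketch below is dependency-free and uses the test rows copied from the summary table above. It is one possible way to read the table, not a step the tutorial prescribes.

```python
# Test-split rows copied from the summary table above.
test_rows = [
    {"modality": "tcr_only", "mse": 0.038711, "r2": -0.703923},
    {"modality": "tcra_only", "mse": 0.038034, "r2": -1.039305},
    {"modality": "tcrb_only", "mse": 0.038028, "r2": -1.185543},
]

# Lower MSE is better; higher R² is better.
best_by_mse = min(test_rows, key=lambda row: row["mse"])
best_by_r2 = max(test_rows, key=lambda row: row["r2"])

print(f"Lowest test MSE: {best_by_mse['modality']} ({best_by_mse['mse']:.6f})")
print(f"Highest test R²: {best_by_r2['modality']} ({best_by_r2['r2']:.4f})")
```

The two metrics rank the modalities differently in this short two-epoch run, so it is worth deciding which metric matters before picking a model. With `summary_df` in memory, `summary_df[summary_df["split"] == "test"].sort_values("mse")` gives the same MSE ranking.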