Tutorial 5: TCR-to-Gene Expression Prediction

Fine-tune TCRfoundation for a cross-modal task: predicting gene expression from TCR sequences.

1. Setup

import warnings
warnings.filterwarnings('ignore')

import os
import json
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import tcrfoundation as tcrf

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)
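The setup cell seeds NumPy and PyTorch. It does not seed Python's built-in `random` module, and it leaves cuDNN's `deterministic` flag at its default of `False`. Below is a minimal sketch of a consolidated helper. The name `set_seed` is hypothetical and is not part of tcrfoundation. Even with these flags, some GPU operations stay nondeterministic, so bit-identical reruns are not guaranteed.

```python
import random

import numpy as np


def set_seed(seed: int = 0) -> None:
    """Seed Python, NumPy and (if installed) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:  # PyTorch not installed: seed only Python/NumPy
        return
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Re-seeding reproduces the same draws from both Python and NumPy
set_seed(0)
first = (random.random(), np.random.rand(3))
set_seed(0)
second = (random.random(), np.random.rand(3))
print(first[0] == second[0], np.array_equal(first[1], second[1]))
```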

2. Configuration

# Configuration
config = {
    "checkpoint_path": "../TCR_foundation_model/foundation_model_best.pt",
    "num_epochs": 2,
    "batch_size": 512,
    "modalities": None,  # None = train all three: tcr_only (α+β), tcra_only, tcrb_only
    "val_split": 0.2,
    "test_split": 0.2,
    "save_splits": False,
    "save_predictions": False,
    "task": "TCR2gene"
}

results_dir = f"../results/{config['task']}"
os.makedirs(results_dir, exist_ok=True)

# Save configuration
with open(f"{results_dir}/config.json", 'w') as f:
    json.dump(config, f, indent=2)

print(f"Task: {config['task']}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print(f"Results directory: {results_dir}")
Task: TCR2gene
Device: cuda
Results directory: ../results/TCR2gene
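`json.dump` writes Python's `None` as `null`, so the saved `config.json` records `modalities` as `null`. Reloading the file with `json.load` restores `None`, which means the file can be read back later to rerun with identical settings. A minimal round-trip sketch using only the standard library, with a temporary directory standing in for `results_dir`:

```python
import json
import os
import tempfile

config = {
    "checkpoint_path": "../TCR_foundation_model/foundation_model_best.pt",
    "num_epochs": 2,
    "batch_size": 512,
    "modalities": None,
    "val_split": 0.2,
    "test_split": 0.2,
    "save_splits": False,
    "save_predictions": False,
    "task": "TCR2gene",
}

with tempfile.TemporaryDirectory() as results_dir:
    path = os.path.join(results_dir, "config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)   # None -> null, False -> false
    with open(path) as f:
        text = f.read()

reloaded = json.loads(text)              # null -> None, false -> False
print('"modalities": null' in text, reloaded == config)
```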

3. Load Data

adata = sc.read("../data/adata_finetune.h5ad")
print(f"Dataset: {adata.n_obs} cells, {adata.n_vars} genes")
print("\nTCR sequences:")
print(f"  CDR3a: {adata.obs['CDR3a'].notna().sum()} available")
print(f"  CDR3b: {adata.obs['CDR3b'].notna().sum()} available")
Dataset: 444979 cells, 3000 genes

TCR sequences:
  CDR3a: 444979 available
  CDR3b: 444979 available
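In this dataset all 444,979 cells have both CDR3α and CDR3β, so no filtering is needed. On other datasets it may be worth checking chain availability before training. Note that `notna()` counts empty strings as available. This tutorial does not show whether `train_regressor` drops or rejects cells with a missing chain, so filtering beforehand is a conservative assumption. The sketch below uses a small hypothetical `obs` table in place of `adata.obs`:

```python
import pandas as pd

# Hypothetical stand-in for adata.obs (toy sequences for illustration only)
obs = pd.DataFrame(
    {
        "CDR3a": ["CAVRDSNYQLIW", None, "CALSEAGNKLTF", ""],
        "CDR3b": ["CASSLGQAYEQYF", "CASSPDRGNTEAFF", None, "CASSIRSSYEQYF"],
    },
    index=["cell1", "cell2", "cell3", "cell4"],
)


def chain_present(col: pd.Series) -> pd.Series:
    """True where a CDR3 is present and non-empty (notna() alone accepts "")."""
    return col.notna() & (col.astype(str).str.strip() != "")


has_a = chain_present(obs["CDR3a"])
has_b = chain_present(obs["CDR3b"])
paired = has_a & has_b

counts = pd.Series({"CDR3a": int(has_a.sum()), "CDR3b": int(has_b.sum()),
                    "paired": int(paired.sum())})
print(counts)

# With a real AnnData, compute the mask from adata.obs and subset:
# adata = adata[mask.to_numpy()].copy()
```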

4. Train Cross-Modal Regressor

This step trains one regressor per TCR input modality. Each regressor predicts gene expression:

  • TCR α+β → Gene expression (tcr_only)
  • TCR α only → Gene expression (tcra_only)
  • TCR β only → Gene expression (tcrb_only)
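The three modalities differ only in which CDR3 chains are fed to the model. This tutorial does not show TCRfoundation's actual tokenizer or encoder. The training log line `Loaded model with max_length: 30` suggests that sequences are limited to 30 positions. Under that assumption, a hypothetical encoding might look like the sketch below: one integer code per amino acid, 0 for padding, and truncation at `max_length`. `MODALITY_CHAINS`, `encode_cdr3` and `encode_cell` are illustrative names only.

```python
from typing import Dict, List, Tuple

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
PAD = 0                      # padding index
UNK = len(AMINO_ACIDS) + 1   # index for non-standard residues (e.g. "X")
AA_TO_IDX = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}  # A=1 ... Y=20

# Hypothetical: which CDR3 columns each modality feeds to the encoder
MODALITY_CHAINS: Dict[str, Tuple[str, ...]] = {
    "tcr_only": ("CDR3a", "CDR3b"),
    "tcra_only": ("CDR3a",),
    "tcrb_only": ("CDR3b",),
}


def encode_cdr3(seq: str, max_length: int = 30) -> List[int]:
    """Integer-encode one CDR3, truncating and right-padding to max_length."""
    ids = [AA_TO_IDX.get(aa, UNK) for aa in seq[:max_length]]
    return ids + [PAD] * (max_length - len(ids))


def encode_cell(cell: Dict[str, str], modality: str,
                max_length: int = 30) -> List[List[int]]:
    """Return one encoded sequence per chain used by `modality`."""
    return [encode_cdr3(cell[chain], max_length)
            for chain in MODALITY_CHAINS[modality]]


cell = {"CDR3a": "CAVRDSNYQLIW", "CDR3b": "CASSLGQAYEQYF"}
for modality in MODALITY_CHAINS:
    encoded = encode_cell(cell, modality)
    print(f"{modality}: {len(encoded)} chain(s) x {len(encoded[0])} positions")
```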

print(f"\n=== Training {config['task']} Regressor ===")

results, adata_with_predictions = tcrf.finetune.cross_modal.train_regressor(
    adata,
    checkpoint_path=config["checkpoint_path"],
    num_epochs=config["num_epochs"],
    batch_size=config["batch_size"],
    modalities=config["modalities"],
    val_split=config["val_split"],
    test_split=config["test_split"],
    save_splits=config["save_splits"],
    save_predictions=config["save_predictions"]
)

print("\n✓ Training complete")
=== Training TCR2gene Regressor ===
Loaded model with max_length: 30

=== Training tcr_only regressor ===
Epoch 1/2 - Train Loss: 0.040502, Val Loss: 0.039059, Train R²: -103.7333, Val R²: -0.2391
Saved best model checkpoint (Val Loss: 0.039059)
Epoch 2/2 - Train Loss: 0.039063, Val Loss: 0.038805, Train R²: -10.9746, Val R²: -0.6951
Saved best model checkpoint (Val Loss: 0.038805)

Evaluating on test set with best model...
Test Loss: 0.038711, Test MSE: 0.038711, Test R²: -0.7039

=== Training tcra_only regressor ===
Epoch 1/2 - Train Loss: 0.038970, Val Loss: 0.038550, Train R²: -17.3715, Val R²: -1.0270
Saved best model checkpoint (Val Loss: 0.038550)
Epoch 2/2 - Train Loss: 0.038423, Val Loss: 0.038114, Train R²: -3.4055, Val R²: -1.1113
Saved best model checkpoint (Val Loss: 0.038114)

Evaluating on test set with best model...
Test Loss: 0.038035, Test MSE: 0.038034, Test R²: -1.0393

=== Training tcrb_only regressor ===
Epoch 1/2 - Train Loss: 0.038810, Val Loss: 0.038457, Train R²: -11.7012, Val R²: -0.6569
Saved best model checkpoint (Val Loss: 0.038457)
Epoch 2/2 - Train Loss: 0.038336, Val Loss: 0.038090, Train R²: -2.8365, Val R²: -1.1950
Saved best model checkpoint (Val Loss: 0.038090)

Evaluating on test set with best model...
Test Loss: 0.038028, Test MSE: 0.038028, Test R²: -1.1855

✓ Training complete
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:21<00:00, 23.75it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:03<00:00, 45.97it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:22<00:00, 23.43it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 72.20it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:03<00:00, 44.76it/s]
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:16<00:00, 31.67it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 73.45it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.17it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.99it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 74.76it/s]
Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 35.80it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.57it/s]
Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.16it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 71.41it/s]
Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 71.89it/s]
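The batch counts in the progress bars follow from the configuration. With `val_split=0.2` and `test_split=0.2`, 60% of the 444,979 cells go to training. At `batch_size=512`, that gives 522 training batches and 174 validation or test batches per pass. The progress bars probably appear after `✓ Training complete` because tqdm writes to stderr. The sketch below reproduces the arithmetic. It assumes that the val/test sizes are floored and that training gets the remainder. The library's exact rounding rule is an assumption, but rounding up instead gives the same batch counts.

```python
import math

n_cells = 444_979  # adata.n_obs reported in section 3
val_split, test_split = 0.2, 0.2
batch_size = 512

# Assumption: val/test sizes are floored and training gets the remainder
n_val = int(n_cells * val_split)
n_test = int(n_cells * test_split)
n_train = n_cells - n_val - n_test

# Batches per pass, keeping the final partial batch (PyTorch's drop_last=False)
batches = {name: math.ceil(n / batch_size)
           for name, n in [("train", n_train), ("val", n_val), ("test", n_test)]}

print(f"cells:   train={n_train}, val={n_val}, test={n_test}")
print(f"batches: {batches}")
```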

5. Performance Summary

# Create summary dataframe
summary_rows = []
for mode in results:
    for split in ["train", "val", "test"]:
        metrics = results[mode][split]
        row = {
            "modality": mode,
            "split": split,
            "loss": metrics["loss"],
            "mse": metrics["mse"],
            "r2": metrics["r2"]
        }
        summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(f"{results_dir}/{config['task']}_summary.csv", index=False)

print("\n" + "="*60)
print(f"{config['task']} Results Summary")
print("="*60)
print(summary_df.to_string())
print(f"\n✓ Summary saved to: {results_dir}/{config['task']}_summary.csv")
============================================================
TCR2gene Results Summary
============================================================
    modality  split      loss       mse         r2
0   tcr_only  train  0.039063  0.039061 -10.974631
1   tcr_only    val  0.038805  0.038805  -0.695051
2   tcr_only   test  0.038711  0.038711  -0.703923
3  tcra_only  train  0.038423  0.038421  -3.405451
4  tcra_only    val  0.038114  0.038114  -1.111292
5  tcra_only   test  0.038035  0.038034  -1.039305
6  tcrb_only  train  0.038336  0.038335  -2.836521
7  tcrb_only    val  0.038090  0.038090  -1.195003
8  tcrb_only   test  0.038028  0.038028  -1.185543

✓ Summary saved to: ../results/TCR2gene/TCR2gene_summary.csv
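Every R² in the summary is negative. The predictions therefore explain less variance than a baseline that simply predicts each gene's mean expression, even though MSE barely differs across modalities and splits. This tutorial does not show how tcrfoundation computes R². If it averages R² across genes (an assumption), a few low-variance genes can push the average far below zero with almost no effect on MSE. That could also explain the very negative train R² values alongside train losses similar to validation. The synthetic sketch below illustrates the effect. A constant prediction that is off by 0.1 raises MSE by only 0.01, yet the gene-averaged R² becomes strongly negative, almost entirely because of the low-variance gene.

```python
import numpy as np


def per_gene_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """R^2 = 1 - SS_res / SS_tot, computed separately for each gene (column)."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


rng = np.random.default_rng(0)
y_true = rng.normal(size=(1000, 3))  # synthetic: 1000 cells x 3 genes
y_true[:, 2] *= 0.01                 # gene 3 barely varies across cells

# Baseline: predict each gene's mean expression for every cell
mean_pred = np.broadcast_to(y_true.mean(axis=0), y_true.shape)
# A constant prediction that is off by just 0.1 for every gene
offset_pred = mean_pred + 0.1

for name, pred in [("mean", mean_pred), ("offset", offset_pred)]:
    mse = ((y_true - pred) ** 2).mean()
    r2 = per_gene_r2(y_true, pred)
    print(f"{name:>6}: MSE={mse:.4f}  per-gene R2={np.round(r2, 3)}  mean R2={r2.mean():.2f}")
```

A variance-weighted average, as in scikit-learn's `r2_score(..., multioutput="variance_weighted")`, is less sensitive to this effect. With only two epochs, these results are better read as a check that the pipeline runs than as evidence of predictive skill.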