Pre-training COMPASS from Scratch
Pretraining is a critical component of the COMPASS framework, providing biologically meaningful representations of gene expression profiles before fine-tuning on specific clinical tasks. In this workflow, we demonstrate how COMPASS can be pretrained on TCGA transcriptomic data, using bulk RNA-seq TPM matrices as input. The model employs a contrastive learning strategy, where positive and negative pairs are constructed across patients to capture robust and generalizable gene–concept relationships.
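To give a concrete picture of this strategy, the sketch below builds a triplet-style contrastive loss over augmented views of a toy batch. The encoder, dimensions, and augmentation here are placeholders chosen for illustration only, not the COMPASS implementation; the margin value simply mirrors the triplet_margin setting used later in this notebook.

# Illustrative triplet-style contrastive objective; NOT the COMPASS implementation.
import torch
import torch.nn as nn

n_genes, emb_dim = 1065, 64
encoder = nn.Sequential(nn.Linear(n_genes, 256), nn.ReLU(), nn.Linear(256, emb_dim))

x = torch.rand(32, n_genes)                        # a batch of 32 expression profiles (anchors)
keep = (torch.rand_like(x) > 0.7).float()          # keep ~30% of features (mask ~70%)
positive = x * keep                                 # masked view of the same patients
negative = x[torch.randperm(x.size(0))]             # profiles drawn from other patients

loss_fn = nn.TripletMarginLoss(margin=1.0)          # margin mirrors 'triplet_margin' below
loss = loss_fn(encoder(x), encoder(positive), encoder(negative))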
COMPASS is designed as a concept-bottleneck model: instead of learning from tens of thousands of individual genes, the encoder learns to embed transcriptomic features into 132 intermediate gene sets, which are then projected into 44 high-level TIME concepts. This hierarchical design ensures both dimensionality reduction and interpretability, grounding the learned representations in known biological processes.
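The hierarchy can be pictured as two stacked projections. The toy module below uses plain linear layers and invented names purely to illustrate the gene → gene-set → concept flow; the actual COMPASS encoder is the Performer/transformer model configured later in this notebook, not this sketch.

# Schematic of the concept bottleneck only; layer names and plain linear layers
# are illustrative assumptions, not the COMPASS architecture.
import torch
import torch.nn as nn

class ConceptBottleneckSketch(nn.Module):
    def __init__(self, n_genes=1065, n_genesets=132, n_concepts=44):
        super().__init__()
        self.gene_to_geneset = nn.Linear(n_genes, n_genesets)        # genes -> 132 intermediate gene sets
        self.geneset_to_concept = nn.Linear(n_genesets, n_concepts)  # gene sets -> 44 TIME concepts

    def forward(self, x):
        genesets = torch.relu(self.gene_to_geneset(x))
        return self.geneset_to_concept(genesets)

concepts = ConceptBottleneckSketch()(torch.rand(8, 1065))
print(concepts.shape)  # torch.Size([8, 44])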
During pretraining, several configurations were chosen based on empirical results and prior studies (an illustrative mapping of these recommendations onto PreTrainer arguments is sketched after the list):
- Batch size: We recommend using the largest possible batch size (e.g., 2048 or 4096 on high-memory GPUs), as contrastive learning benefits from more negative samples per update. In practice, GPU memory often limits batch size when training with ~1,000–15,000 input genes, so values of 128–512 are typically used.
- Learning rate: A default learning rate of 1e-4 to 5e-5 balances stable optimization and generalization.
- Epochs and patience: Training for 30–200 epochs with an early-stopping patience of 50 epochs prevents overfitting while ensuring sufficient convergence.
- Input features: Pretraining can be done with either the full gene space (e.g., 15,672 genes) or a restricted set of concept-related genes (e.g., 1,065 genes). Restricting to concept genes improves efficiency while maintaining competitive performance, making it a practical alternative when computational resources are limited.
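The dictionary below simply restates these recommendations as PreTrainer-style keyword arguments (the same names used in the configuration cell later in this notebook); the specific values are illustrative points within the ranges above, not tuned settings.

# Illustrative defaults only, restating the recommendations above;
# exact values should be tuned per dataset and hardware.
recommended_args = {
    'batch_size': 512,   # as large as GPU memory allows (128-512 is typical)
    'lr': 1e-4,          # 1e-4 to 5e-5; use the lower end with the full gene space
    'epochs': 200,       # 30-200 epochs depending on cohort size
    'patience': 50,      # early-stopping patience on the validation loss
}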
The pretrained COMPASS model thus serves as a feature extractor, generating stable and interpretable concept-level embeddings. These embeddings can then be fine-tuned on downstream tasks such as response prediction, survival modeling, or biomarker discovery, allowing COMPASS to integrate large-scale transcriptomic data with clinical applications.
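As a pointer to downstream use, the fragment below reloads the checkpoint saved at the end of this notebook, assuming loadcompass accepts the checkpoint path. The FineTuner lines are commented out because their exact signatures are assumptions and should be checked against the COMPASS documentation.

# Hypothetical downstream usage; the FineTuner arguments hinted at in the
# comments are assumptions, not the documented API.
from compass import loadcompass, FineTuner

pretrained = loadcompass('./results/PT/pretrainer.pt')  # checkpoint saved below
# finetuner = FineTuner(pretrained, ...)   # configure for, e.g., response or survival prediction
# finetuner.train(...)                     # fine-tune on a labeled clinical cohort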
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
COMPASS Pretraining Script
--------------------------
This script performs unsupervised pretraining of the COMPASS model
on TCGA transcriptomic data (subset of genes for efficiency).
It uses the Performer encoder and standard contrastive/masked
representation learning objectives.
"""
# =========================================================
# 1. Environment setup
# =========================================================
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2" # Specify which GPU to use
# =========================================================
# 2. Install and check COMPASS
# =========================================================
get_ipython().system('pip install immuno-compass -U') # Ensure latest COMPASS version
import compass
print("Using COMPASS version:", compass.__version__)
Using COMPASS version: 2.2
# =========================================================
# 3. Define output directory
# =========================================================
save_dir = './results/PT'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
# =========================================================
# 4. Import core modules
# =========================================================
from compass import loadcompass
from compass import PreTrainer, FineTuner
from compass.tokenizer import CANCER_CODE
from compass.utils import plot_performance, score2
import pandas as pd
import numpy as np
get_ipython().run_line_magic('matplotlib', 'inline')
# =========================================================
# 5. Data loading function
# =========================================================
_GLOBALS = {}
def load_tcga(path):
    """
    Load TCGA transcriptomic and label tables.

    Parameters
    ----------
    path : str
        Path to the folder containing 'TCGA.TPM.TABLE' and
        'TCGA.PATIENT.PROCESSED.TABLE' files.

    Returns
    -------
    df_tpm : pd.DataFrame
        TPM-normalized gene expression matrix.
    df_label : pd.DataFrame
        Corresponding patient-level metadata.
    """
    if "df_tpm" not in _GLOBALS:
        _GLOBALS["df_tpm"] = pd.read_pickle(os.path.join(path, "TCGA.TPM.TABLE"))
        _GLOBALS["df_label"] = pd.read_pickle(os.path.join(path, "TCGA.PATIENT.PROCESSED.TABLE"))
    return _GLOBALS["df_tpm"], _GLOBALS["df_label"]
# =========================================================
# 6. Prepare data for pretraining
# =========================================================
# Use a subset of genes for faster training (full TCGA gene set is large)
tcga_path = './data/TCGA/1065'
df_tpm, df_label = load_tcga(tcga_path)
# Encode cancer types numerically
df_cancer = df_label[["cancer_type"]]
dfcx = (
    df_cancer.cancer_type
    .apply(lambda x: x.replace("TCGA-", ""))
    .map(CANCER_CODE)
    .to_frame("cancer_code")
    .join(df_tpm)
)
# Stratified split by cancer type (90% train / 10% test)
test_idx = df_cancer.groupby("cancer_type").apply(
    lambda x: x.sample(frac=0.1, random_state=123).index.tolist()
).sum()
train_idx = df_cancer[~df_cancer.index.isin(test_idx)].index
dfcx_train, dfcx_test = dfcx.loc[train_idx], dfcx.loc[test_idx]
print("Training set:", dfcx_train.shape, "| Test set:", dfcx_test.shape)
Training set: (9166, 1066) | Test set: (1018, 1066)
# =========================================================
# 7. Define training configuration
# =========================================================
model_args = {
    'lr': 1e-3,               # Adjust to 1e-4 when using all genes
    'epochs': 100,
    'batch_size': 1024,       # Reduce if memory is limited
    'seed': 42,
    'patience': 10,
    'encoder': 'performer',   # Performer or transformer encoder for efficient attention
    'weight_decay': 1e-7,
    'triplet_margin': 1.0,
    'batch_correction': 0.1,
}
data_args = {
    'no_augment_prob': 0.1,   # Probability of skipping augmentation; lower is better, can be zero
    'mask_p_prob': 0.7,       # Probability of masking positive samples; higher is better
    'mask_n_prob': 0.0,
    'mask_a_prob': 0.0,
    'jitter_a_std': 0.0,
    'jitter_n_std': 0.0,
    'jitter_p_std': 0.4,      # Additive noise for feature perturbation, 0-0.5
}
pretrainer = PreTrainer(**model_args, work_dir=save_dir)
# =========================================================
# 8. Run pretraining
# =========================================================
pretrainer.train(
    dfcx_train=dfcx_train,
    dfcx_test=dfcx_test,
    **data_args
)
# Save pretrained model checkpoint
pretrainer.save(f'{save_dir}/pretrainer.pt')
print("Pretraining completed. Model saved to:",
f'{save_dir}/pretrainer.pt')
Epoch: 1/100 - Train Loss: 1.0144 - Test Loss: 0.9703
Epoch: 2/100 - Train Loss: 0.9342 - Test Loss: 0.8863
Epoch: 3/100 - Train Loss: 0.8494 - Test Loss: 0.8369
Epoch: 4/100 - Train Loss: 0.8072 - Test Loss: 0.7723
Epoch: 5/100 - Train Loss: 0.7645 - Test Loss: 0.7309
Epoch: 6/100 - Train Loss: 0.7170 - Test Loss: 0.6854
Epoch: 7/100 - Train Loss: 0.6752 - Test Loss: 0.6475
Epoch: 8/100 - Train Loss: 0.6477 - Test Loss: 0.6102
Epoch: 9/100 - Train Loss: 0.6157 - Test Loss: 0.6098
Epoch: 10/100 - Train Loss: 0.5834 - Test Loss: 0.5875
Epoch: 11/100 - Train Loss: 0.5676 - Test Loss: 0.5510
Epoch: 12/100 - Train Loss: 0.5646 - Test Loss: 0.5699
Epoch: 13/100 - Train Loss: 0.5603 - Test Loss: 0.5529
Epoch: 14/100 - Train Loss: 0.5542 - Test Loss: 0.5597
Epoch: 15/100 - Train Loss: 0.5490 - Test Loss: 0.5533
Epoch: 16/100 - Train Loss: 0.5421 - Test Loss: 0.5303
Epoch: 17/100 - Train Loss: 0.5423 - Test Loss: 0.5391
Epoch: 18/100 - Train Loss: 0.5359 - Test Loss: 0.5298
Epoch: 19/100 - Train Loss: 0.5476 - Test Loss: 0.5438
Epoch: 20/100 - Train Loss: 0.5392 - Test Loss: 0.5297
Epoch: 21/100 - Train Loss: 0.5375 - Test Loss: 0.5006
Epoch: 22/100 - Train Loss: 0.5295 - Test Loss: 0.5360
Epoch: 23/100 - Train Loss: 0.5238 - Test Loss: 0.5032
Epoch: 24/100 - Train Loss: 0.5314 - Test Loss: 0.5298
Epoch: 25/100 - Train Loss: 0.5227 - Test Loss: 0.5233
Epoch: 26/100 - Train Loss: 0.5254 - Test Loss: 0.5071
Epoch: 27/100 - Train Loss: 0.5147 - Test Loss: 0.4760
Epoch: 28/100 - Train Loss: 0.5201 - Test Loss: 0.5232
Epoch: 29/100 - Train Loss: 0.5035 - Test Loss: 0.5269
Epoch: 30/100 - Train Loss: 0.5123 - Test Loss: 0.5028
Epoch: 31/100 - Train Loss: 0.5108 - Test Loss: 0.4977
Epoch: 32/100 - Train Loss: 0.4923 - Test Loss: 0.5208
Epoch: 33/100 - Train Loss: 0.5009 - Test Loss: 0.4910
Epoch: 34/100 - Train Loss: 0.5022 - Test Loss: 0.4895
Epoch: 35/100 - Train Loss: 0.5101 - Test Loss: 0.4892
Epoch: 36/100 - Train Loss: 0.4989 - Test Loss: 0.4846
Epoch: 37/100 - Train Loss: 0.5025 - Test Loss: 0.5126
Stopping early at epoch 37. No improvement in validation loss for 10 consecutive epochs.
Saving final model...
Best validation loss: 0.4760492444038391
Saving best model on epoch: 27
Saving the model to ./results/PT/pretrainer.pt
Pretraining completed. Model saved to: ./results/PT/pretrainer.pt