Pre-training COMPASS from Scratch

Pretraining is a critical component of the COMPASS framework, providing biologically meaningful representations of gene expression profiles before fine-tuning on specific clinical tasks. In this workflow, we demonstrate how COMPASS can be pretrained on TCGA transcriptomic data, using bulk RNA-seq TPM matrices as input. The model employs a contrastive learning strategy, where positive and negative pairs are constructed across patients to capture robust and generalizable gene–concept relationships.

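The exact objective is implemented inside COMPASS's PreTrainer (note the triplet_margin argument used later in this notebook). Purely as an illustration of the underlying idea, and not of the actual COMPASS loss, a generic triplet-margin objective over anchor, positive, and negative embeddings looks like the sketch below.

# Illustration only: a generic triplet-margin contrastive objective.
# This is NOT the COMPASS implementation; shapes and margin are arbitrary.
import torch
import torch.nn.functional as F

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Pull anchor and positive together, push anchor and negative apart."""
    d_pos = F.pairwise_distance(anchor, positive)  # same patient, augmented view
    d_neg = F.pairwise_distance(anchor, negative)  # a different patient
    return torch.clamp(d_pos - d_neg + margin, min=0.0).mean()

# Toy check with random 44-dimensional "concept" embeddings for 8 samples
a, p, n = (torch.randn(8, 44) for _ in range(3))
print(triplet_margin_loss(a, p, n))
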
COMPASS is designed as a concept-bottleneck model: instead of learning from tens of thousands of individual genes, the encoder learns to embed transcriptomic features into 132 intermediate gene sets, which are then projected into 44 high-level TIME concepts. This hierarchical design ensures both dimensionality reduction and interpretability, grounding the learned representations in known biological processes.

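The actual encoder is Performer-based, so the sketch below is only a structural picture of the bottleneck (hypothetical module and layer names, plain linear projections): genes are pooled into 132 gene-set activations, which are then mapped to 44 concept scores.

# Structural sketch of a concept bottleneck. Hypothetical module, NOT the COMPASS
# encoder (which uses a Performer backbone); it only shows the two-level projection.
import torch
import torch.nn as nn

class ConceptBottleneckSketch(nn.Module):
    def __init__(self, n_genes=1065, n_gene_sets=132, n_concepts=44):
        super().__init__()
        self.gene_to_geneset = nn.Linear(n_genes, n_gene_sets)        # genes -> 132 gene sets
        self.geneset_to_concept = nn.Linear(n_gene_sets, n_concepts)  # gene sets -> 44 TIME concepts

    def forward(self, x):
        geneset_scores = torch.relu(self.gene_to_geneset(x))
        return self.geneset_to_concept(geneset_scores)                # concept-level embedding

# Toy usage: a batch of 4 expression profiles
print(ConceptBottleneckSketch()(torch.randn(4, 1065)).shape)  # torch.Size([4, 44])
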
During pretraining, several configurations were chosen based on empirical results and prior studies:

  • Batch size: We recommend using the largest possible batch size (e.g., 2048 or 4096 on high-memory GPUs), as contrastive learning benefits from more negative samples per update. In practice, GPU memory often limits batch size when training with ~1,000–15,000 input genes, so values of 128–512 are typically used.
  • Learning rate: A default learning rate of 1e-4 to 5e-5 balances stable optimization and generalization.
  • Epochs and patience: Training for 30–200 epochs with an early-stopping patience of 50 epochs prevents overfitting while ensuring sufficient convergence.
  • Input features: Pretraining can be done with either the full gene space (e.g., 15,672 genes) or a restricted set of concept-related genes (e.g., 1,065 genes). Restricting to concept genes improves efficiency while maintaining competitive performance, making it a practical alternative when computational resources are limited (see the illustrative settings sketched after this list).

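As a quick reference, the recommendations above translate into argument dictionaries along the lines of the sketch below. The keys mirror the PreTrainer arguments used later in this notebook; the values are illustrative starting points for the two input-feature settings, not prescriptions.

# Illustrative pretraining settings for the two input-feature choices discussed above.
# Keys mirror the PreTrainer arguments used later in this notebook; values should be
# tuned to the available GPU memory and the dataset at hand.
args_concept_genes = {      # ~1,065 concept-related genes: larger batches fit in memory
    'lr': 1e-3,
    'batch_size': 1024,
    'epochs': 100,
    'patience': 10,
}
args_full_gene_space = {    # ~15,672 genes: smaller batches and a lower learning rate
    'lr': 1e-4,
    'batch_size': 128,
    'epochs': 100,
    'patience': 50,
}
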
The pretrained COMPASS model thus serves as a feature extractor, generating stable and interpretable concept-level embeddings. The model can then be fine-tuned on downstream tasks such as response prediction, survival modeling, or biomarker discovery, allowing COMPASS to connect large-scale transcriptomic data with clinical applications.

To perform the pretraining, please download the TCGA dataset from Figshare first:

Figshare
In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
COMPASS Pretraining Script
--------------------------
This script performs unsupervised pretraining of the COMPASS model 
on TCGA transcriptomic data (subset of genes for efficiency). 
It uses the Performer encoder and standard contrastive/masked 
representation learning objectives.

"""


# =========================================================
# 1. Environment setup
# =========================================================
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Specify which GPU to use
In [2]:
# =========================================================
# 2. Install and check COMPASS
# =========================================================
!pip install immuno-compass -U  # Ensure the latest COMPASS version is installed

import compass
print("Using COMPASS version:", compass.__version__)
Using COMPASS version: 2.2
In [3]:
# =========================================================
# 3. Define output directory
# =========================================================
save_dir = './results/PT'
os.makedirs(save_dir, exist_ok=True)  # Create the output directory if it does not exist
In [4]:
# =========================================================
# 4. Import core modules
# =========================================================
from compass import loadcompass
from compass import PreTrainer, FineTuner
from compass.tokenizer import CANCER_CODE
from compass.utils import plot_performance, score2
import pandas as pd
import numpy as np

%matplotlib inline
In [5]:
# =========================================================
# 5. Data loading function
# =========================================================
_GLOBALS = {}

def load_tcga(path):
    """
    Load TCGA transcriptomic and label tables.

    Parameters
    ----------
    path : str
        Path to the folder containing 'TCGA.TPM.TABLE' and 
        'TCGA.PATIENT.PROCESSED.TABLE' files.

    Returns
    -------
    df_tpm : pd.DataFrame
        TPM-normalized gene expression matrix.
    df_label : pd.DataFrame
        Corresponding patient-level metadata.
    """
    if "df_tpm" not in _GLOBALS:
        _GLOBALS["df_tpm"] = pd.read_pickle(os.path.join(path, "TCGA.TPM.TABLE"))
        _GLOBALS["df_label"] = pd.read_pickle(os.path.join(path, "TCGA.PATIENT.PROCESSED.TABLE"))
    return _GLOBALS["df_tpm"], _GLOBALS["df_label"]
In [6]:
# =========================================================
# 6. Prepare data for pretraining
# =========================================================
# Use a subset of genes for faster training (full TCGA gene set is large)
tcga_path = './data/TCGA/1065'
df_tpm, df_label = load_tcga(tcga_path)

# Encode cancer types numerically
df_cancer = df_label[["cancer_type"]]
dfcx = (
    df_cancer.cancer_type
    .apply(lambda x: x.replace("TCGA-", ""))
    .map(CANCER_CODE)
    .to_frame("cancer_code")
    .join(df_tpm)
)

# Stratified split by cancer type (90% train / 10% test)
test_idx = df_cancer.groupby("cancer_type").apply(
    lambda x: x.sample(frac=0.1, random_state=123).index.tolist()
).sum()
train_idx = df_cancer[~df_cancer.index.isin(test_idx)].index

dfcx_train, dfcx_test = dfcx.loc[train_idx], dfcx.loc[test_idx]
print("Training set:", dfcx_train.shape, "| Test set:", dfcx_test.shape)
Training set: (9166, 1066) | Test set: (1018, 1066)
In [7]:
# =========================================================
# 7. Define training configuration
# =========================================================

model_args = {
    'lr': 1e-3,              # Adjust to 1e-4 when using all genes
    'epochs': 100,
    'batch_size': 1024,      # Reduce if memory is limited
    'seed': 42,
    'patience': 10,
    'encoder': 'performer',  # 'performer' (memory-efficient attention approximation) or 'transformer'
    'weight_decay': 1e-7,
    'triplet_margin': 1.0,
    'batch_correction': 0.1,
}

data_args = {
    'no_augment_prob': 0.1,  # Probability of skipping augmentation; lower is generally better, 0 is allowed
    'mask_p_prob': 0.7,      # Masking probability for positive samples; higher values generally work better
    'mask_n_prob': 0.0,      # Masking probability for negative samples
    'mask_a_prob': 0.0,      # Masking probability for anchor samples
    'jitter_a_std': 0.0,     # Std of additive jitter on anchor samples
    'jitter_n_std': 0.0,     # Std of additive jitter on negative samples
    'jitter_p_std': 0.4,     # Std of additive jitter on positive samples (typical range 0-0.5)
}

pretrainer = PreTrainer(**model_args, work_dir=save_dir)
In [8]:
# =========================================================
# 8. Run pretraining
# =========================================================
pretrainer.train(
    dfcx_train=dfcx_train,
    dfcx_test=dfcx_test,
    **data_args
)

# Save pretrained model checkpoint
pretrainer.save(f'{save_dir}/pretrainer.pt')
print("Pretraining completed. Model saved to:", 
      f'{save_dir}/pretrainer.pt')
Epoch: 1/100 - Train Loss: 1.0144 - Test Loss: 0.9703
Epoch: 2/100 - Train Loss: 0.9342 - Test Loss: 0.8863
Epoch: 3/100 - Train Loss: 0.8494 - Test Loss: 0.8369
Epoch: 4/100 - Train Loss: 0.8072 - Test Loss: 0.7723
Epoch: 5/100 - Train Loss: 0.7645 - Test Loss: 0.7309
Epoch: 6/100 - Train Loss: 0.7170 - Test Loss: 0.6854
Epoch: 7/100 - Train Loss: 0.6752 - Test Loss: 0.6475
Epoch: 8/100 - Train Loss: 0.6477 - Test Loss: 0.6102
Epoch: 9/100 - Train Loss: 0.6157 - Test Loss: 0.6098
Epoch: 10/100 - Train Loss: 0.5834 - Test Loss: 0.5875
Epoch: 11/100 - Train Loss: 0.5676 - Test Loss: 0.5510
Epoch: 12/100 - Train Loss: 0.5646 - Test Loss: 0.5699
Epoch: 13/100 - Train Loss: 0.5603 - Test Loss: 0.5529
Epoch: 14/100 - Train Loss: 0.5542 - Test Loss: 0.5597
Epoch: 15/100 - Train Loss: 0.5490 - Test Loss: 0.5533
Epoch: 16/100 - Train Loss: 0.5421 - Test Loss: 0.5303
Epoch: 17/100 - Train Loss: 0.5423 - Test Loss: 0.5391
Epoch: 18/100 - Train Loss: 0.5359 - Test Loss: 0.5298
Epoch: 19/100 - Train Loss: 0.5476 - Test Loss: 0.5438
Epoch: 20/100 - Train Loss: 0.5392 - Test Loss: 0.5297
Epoch: 21/100 - Train Loss: 0.5375 - Test Loss: 0.5006
Epoch: 22/100 - Train Loss: 0.5295 - Test Loss: 0.5360
Epoch: 23/100 - Train Loss: 0.5238 - Test Loss: 0.5032
Epoch: 24/100 - Train Loss: 0.5314 - Test Loss: 0.5298
Epoch: 25/100 - Train Loss: 0.5227 - Test Loss: 0.5233
Epoch: 26/100 - Train Loss: 0.5254 - Test Loss: 0.5071
Epoch: 27/100 - Train Loss: 0.5147 - Test Loss: 0.4760
Epoch: 28/100 - Train Loss: 0.5201 - Test Loss: 0.5232
Epoch: 29/100 - Train Loss: 0.5035 - Test Loss: 0.5269
Epoch: 30/100 - Train Loss: 0.5123 - Test Loss: 0.5028
Epoch: 31/100 - Train Loss: 0.5108 - Test Loss: 0.4977
Epoch: 32/100 - Train Loss: 0.4923 - Test Loss: 0.5208
Epoch: 33/100 - Train Loss: 0.5009 - Test Loss: 0.4910
Epoch: 34/100 - Train Loss: 0.5022 - Test Loss: 0.4895
Epoch: 35/100 - Train Loss: 0.5101 - Test Loss: 0.4892
Epoch: 36/100 - Train Loss: 0.4989 - Test Loss: 0.4846
Epoch: 37/100 - Train Loss: 0.5025 - Test Loss: 0.5126
Stopping early at epoch 37. No improvement in validation loss for 10 consecutive epochs.
Saving final model...

Best validation loss: 0.4760492444038391

Saving best model on epoch: 27

Saving the model to ./results/PT/pretrainer.pt
Pretraining completed. Model saved to: ./results/PT/pretrainer.pt
(Figure: training and test loss curves over the pretraining epochs.)
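
From here, the saved checkpoint can be reloaded for downstream fine-tuning. The snippet below is a hypothetical sketch: loadcompass and FineTuner are imported earlier in this notebook, but their exact call signatures are assumptions and should be checked against the COMPASS documentation and the fine-tuning tutorial.

# Hypothetical follow-up (signatures are assumptions; see the COMPASS fine-tuning tutorial).
from compass import loadcompass, FineTuner

pretrained = loadcompass(f'{save_dir}/pretrainer.pt')  # reload the checkpoint saved above
# finetuner = FineTuner(pretrained, work_dir='./results/FT')  # assumed constructor; adapt as needed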