Pre-training COMPASS from Scratch
Pretraining is a critical component of the COMPASS framework, providing biologically meaningful representations of gene expression profiles before fine-tuning on specific clinical tasks. In this workflow, we demonstrate how COMPASS can be pretrained on TCGA transcriptomic data, using bulk RNA-seq TPM matrices as input. The model employs a contrastive learning strategy, where positive and negative pairs are constructed across patients to capture robust and generalizable gene–concept relationships.
COMPASS is designed as a concept-bottleneck model: instead of learning from tens of thousands of individual genes, the encoder learns to embed transcriptomic features into 132 intermediate gene sets, which are then projected into 44 high-level TIME concepts. This hierarchical design ensures both dimensionality reduction and interpretability, grounding the learned representations in known biological processes.
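The hierarchical bottleneck described above can be sketched in a few lines of NumPy. This is an illustrative toy, not the actual COMPASS encoder: the dimensions (132 gene sets, 44 concepts) come from the text, while the projection matrices and the nonlinearity are placeholders for the learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# Dimensions from the text; 1,065 is the restricted concept-gene panel size
n_genes, n_gene_sets, n_concepts = 1065, 132, 44

# Hypothetical stand-ins for the learned encoder weights
W_gene_to_set = rng.normal(size=(n_genes, n_gene_sets))
W_set_to_concept = rng.normal(size=(n_gene_sets, n_concepts))

def concept_embedding(expr: np.ndarray) -> np.ndarray:
    """Map a (samples x genes) matrix to (samples x concepts) scores."""
    gene_set_scores = np.maximum(expr @ W_gene_to_set, 0)  # placeholder nonlinearity
    return gene_set_scores @ W_set_to_concept

expr = rng.random((8, n_genes))   # 8 mock samples
emb = concept_embedding(expr)
print(emb.shape)                  # (8, 44)
```

The two stacked projections are what make the model interpretable: each intermediate unit corresponds to a named gene set or TIME concept rather than an anonymous latent dimension.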
During pretraining, several configurations were chosen based on empirical results and prior studies:
- Batch size: We recommend using the largest possible batch size (e.g., 2048 or 4096 on high-memory GPUs), as contrastive learning benefits from more negative samples per update. In practice, GPU memory often limits batch size when training with ~1,000–15,000 input genes, so values of 128–512 are typically used.
- Learning rate: A default learning rate of 1e-4 to 5e-5 balances stable optimization and generalization.
- Epochs and patience: Training for 30–200 epochs with an early-stopping patience of 50 epochs prevents overfitting while ensuring sufficient convergence.
- Input features: Pretraining can be done with either the full gene space (e.g., 15,672 genes) or a restricted set of concept-related genes (e.g., 1,065 genes). Restricting to concept genes improves efficiency while maintaining competitive performance, making it a practical alternative when computational resources are limited.
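Restricting the input to a concept-gene panel, as the last point suggests, is a simple column selection in pandas. The gene names below are mock placeholders (a real panel would ship with the package), but the pattern of intersecting the panel with the matrix columns is the same.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
all_genes = [f"GENE{i}" for i in range(100)]   # stand-in for the ~15,672-gene space
concept_genes = all_genes[:10]                 # stand-in for the ~1,065-gene panel

# Mock samples-by-genes TPM matrix
full_tpm = pd.DataFrame(rng.random((4, 100)),
                        index=[f"sample{i}" for i in range(4)],
                        columns=all_genes)

# Keep only panel genes actually present in the matrix, preserving panel order
present = [g for g in concept_genes if g in full_tpm.columns]
panel_tpm = full_tpm[present]
print(panel_tpm.shape)  # (4, 10)
```

Intersecting rather than indexing directly avoids a KeyError when some panel genes are missing from a given cohort.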
The pretrained COMPASS model thus serves as a feature extractor, generating stable and interpretable concept-level embeddings. These embeddings can then be fine-tuned on downstream tasks such as response prediction, survival modeling, or biomarker discovery, allowing COMPASS to integrate large-scale transcriptomic data with clinical applications.
from compass import PreTrainer, FineTuner, loadcompass
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Load the example dataset for pretraining
# We provide sample datasets containing gene expression data for training and testing
# Ensure the data is preprocessed appropriately before use
tcga_train_sample = pd.read_csv('https://www.immuno-compass.com/download/other/tcga_example_train.tsv', sep='\t', index_col=0)
tcga_test_sample = pd.read_csv('https://www.immuno-compass.com/download/other/tcga_example_test.tsv', sep='\t', index_col=0)
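Before training, it is worth sanity-checking the matrices, since the comment above only says to preprocess them "appropriately". The checks below run on a mock frame standing in for the downloaded TSVs; whether to log-transform depends on what COMPASS expects, so log2(TPM + 1) is shown only as a common choice, not a requirement of the framework.

```python
import numpy as np
import pandas as pd

# Mock samples-by-genes TPM matrix standing in for tcga_train_sample
tpm = pd.DataFrame(np.random.default_rng(2).random((3, 5)) * 100,
                   index=["s1", "s2", "s3"],
                   columns=[f"G{i}" for i in range(5)])

# Basic sanity checks on a TPM matrix
assert not tpm.isna().any().any(), "missing values should be imputed first"
assert (tpm.values >= 0).all(), "TPM values must be non-negative"

# Optional: log-space input, only if that is what the model expects
log_tpm = np.log2(tpm + 1)
```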
# Define pretraining hyperparameters
pt_args = {'lr': 5e-5, 'batch_size': 128, 'epochs': 200, 'seed':42, 'patience':50}
pretrainer = PreTrainer(**pt_args)
# Train the model using the provided training and test datasets
# - dfcx_train: Training dataset
# - dfcx_test: Validation dataset to monitor performance
pretrainer.train(dfcx_train=tcga_train_sample,
                 dfcx_test=tcga_test_sample)
# Save the trained pretrainer model for future use
pretrainer.save('./results/pretrainer.pt')
Epoch: 1/200 - Train Loss: 0.9182 - Test Loss: 0.9010
Epoch: 2/200 - Train Loss: 0.9064 - Test Loss: 0.9033
Epoch: 3/200 - Train Loss: 0.8853 - Test Loss: 0.8917
...
Epoch: 199/200 - Train Loss: 0.1769 - Test Loss: 0.1119
Epoch: 200/200 - Train Loss: 0.1537 - Test Loss: 0.1628
Saving final model...
Best validation loss: 0.09671530872583389
Saving best model on epoch: 182
Saving the model to ./results/pretrainer.pt
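Since matplotlib was imported above but no plot is drawn, one way to visualize convergence is to parse the printed log into loss curves. This assumes the log format shown above; PreTrainer may also expose loss histories directly, but parsing stdout avoids relying on an API not shown in this tutorial. The snippet below uses only the first three epochs as a small sample.

```python
import re
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headlessly
import matplotlib.pyplot as plt

# Sample of the log format printed during training
log = """Epoch: 1/200 - Train Loss: 0.9182 - Test Loss: 0.9010
Epoch: 2/200 - Train Loss: 0.9064 - Test Loss: 0.9033
Epoch: 3/200 - Train Loss: 0.8853 - Test Loss: 0.8917"""

# Extract (train, test) loss pairs from each epoch line
pattern = re.compile(r"Train Loss: ([\d.]+) - Test Loss: ([\d.]+)")
train, test = zip(*[(float(a), float(b)) for a, b in pattern.findall(log)])

plt.plot(train, label="train")
plt.plot(test, label="test")
plt.xlabel("epoch")
plt.ylabel("contrastive loss")
plt.legend()
plt.savefig("loss_curve.png")
```

In this run the best validation loss (0.0967) was reached at epoch 182, so the early-stopping patience of 50 was never triggered and the best checkpoint, not the final one, is what gets saved.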