Introduction¶

To address different research needs and data regimes in immunotherapy response modeling, COMPASS provides a flexible and transparent framework for model training, fine-tuning, and feature extraction. While the pre-trained COMPASS model captures pan-cancer transcriptomic representations from TCGA, the fine-tuning module enables users to efficiently adapt these biological embeddings to cohort-specific or domain-shifted datasets.

To clarify the package’s functionality, we include a schematic diagram illustrating the possible training workflows and fine-tuning pipelines (see Figure X). This diagram summarizes how COMPASS can operate either as a pre-trained feature extractor or as a trainable model under four distinct fine-tuning configurations, ordered from most to least trainable parameters:

  1. Full Fine-Tuning (FFT) — all modules (encoder, projector, and decoder) are trainable, allowing full model adaptation. Use case: large, well-annotated cohorts where end-to-end optimization captures subtle, domain-specific immune variations.

  2. Partial Fine-Tuning (PFT) — the projector and decoder are fine-tuned while the transformer encoder remains frozen. Use case: cross-cohort transfer scenarios, where TCGA-learned gene embeddings remain valid and only high-level concept alignment is required.

  3. Linear-Probing Fine-Tuning (LFT) — trains a simple linear classifier on top of frozen concept embeddings. Use case: small or low-resource datasets emphasizing interpretability and fast benchmarking without re-training deep layers.

  4. Nonparametric Fine-Tuning (NFT) — performs similarity-based or label-representation learning, aligning patient concept embeddings with reference label vectors (e.g., responder vs. non-responder) without updating model weights. Use case: zero-shot or domain-specific settings, where meaningful prediction depends on semantic proximity rather than sample size; domain homogeneity is often more critical than data volume.

Together, these strategies define a continuum of adaptation options—from fully trainable (FFT) to nonparametric, zero-shot inference (NFT). The workflow diagram highlights how users can start from the pre-trained COMPASS backbone and select the most suitable fine-tuning strategy according to data scale, cohort similarity, and computational constraints.
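As a quick reference, the sketch below summarizes which COMPASS components each strategy updates. The module names (encoder, projector, decoder, classifier head) follow the descriptions above; the exact attribute names inside the package may differ.

# Quick reference: which parts of the model each fine-tuning strategy updates.
# Module names mirror the text above; actual attribute names in the package may differ.
FINETUNE_MODES = {
    'FFT': {'trainable': ['encoder', 'projector', 'decoder'], 'frozen': []},
    'PFT': {'trainable': ['projector', 'decoder'],            'frozen': ['encoder']},
    'LFT': {'trainable': ['linear classifier head'],          'frozen': ['encoder', 'projector', 'decoder']},
    'NFT': {'trainable': [],                                   'frozen': ['encoder', 'projector', 'decoder']},  # similarity to reference label embeddings; no weight updates
}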

By explicitly outlining these pipelines and their applications, this section clarifies what the COMPASS package offers in terms of training, pre-trained model utilization, and downstream customization, enhancing accessibility and interpretability for both computational and biomedical audiences.

In [ ]:
from compass import loadcompass, FineTuner
In [2]:
import os
from tqdm import tqdm
from itertools import chain
import pandas as pd
import numpy as np
import random, torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'white', font_scale=1.3)
import warnings
warnings.filterwarnings("ignore")
In [3]:
def onehot(S):
    """One-hot encode a pd.Series, keeping rows with a missing label as all-NaN."""
    assert type(S) == pd.Series, 'Input type should be pd.Series'
    dfd = pd.get_dummies(S, dummy_na=True)           # one column per label, plus a NaN indicator column
    nanidx = dfd[dfd[np.nan].astype(bool)].index     # rows whose original label was missing
    dfd.loc[nanidx, :] = np.nan                      # propagate missingness to every label column
    dfd = dfd.drop(columns=[np.nan])*1.              # drop the NaN indicator column and cast to float
    cols = dfd.sum().sort_values(ascending=False).index.tolist()
    dfd = dfd[cols]                                  # order columns by label frequency (most common first)
    return dfd
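As a quick illustration of this helper (a toy example, not part of the dataset): rows whose label is missing come back as all-NaN, and columns are ordered by label frequency.

# Toy example of onehot(): NaN labels propagate to every column; columns are sorted by frequency.
s = pd.Series(['R', 'NR', np.nan, 'NR'], name='response_label')
onehot(s)
##     NR    R
## 0  0.0  1.0
## 1  1.0  0.0
## 2  NaN  NaN
## 3  1.0  0.0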

Load the pretrainer and the datasets¶

In [4]:
## load pretrainer
pretrainer = loadcompass('https://www.immuno-compass.com/download/model/pretrainer.pt')

## read data
df_label = pd.read_pickle('./tmpignore/ITRP.PATIENT.TABLE')
df_tpm = pd.read_pickle('./tmpignore/ITRP.TPM.TABLE')
df_tpm.shape, df_label.shape
Out[4]:
((1133, 15672), (1133, 110))
In [5]:
train_idx = df_label[df_label.cohort != 'Gide'].index
test_idx = df_label[df_label.cohort == 'Gide'].index
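Here the Gide cohort is held out for testing and all remaining cohorts are used for fine-tuning. The same pattern extends to a full leave-one-cohort-out evaluation; a minimal sketch (not run in this notebook):

# Sketch: leave-one-cohort-out splits over every cohort in the label table.
for held_out in df_label.cohort.unique():
    tr_idx = df_label[df_label.cohort != held_out].index
    te_idx = df_label[df_label.cohort == held_out].index
    # ...fine-tune on tr_idx and evaluate on te_idx, as in the cells below...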
In [6]:
df_tpm.head()
Out[6]:
A1BG A1CF A2M A2ML1 A4GALT A4GNT AAAS AACS AADAC AADAT ... ZWILCH ZWINT ZXDA ZXDB ZXDC ZYG11A ZYG11B ZYX ZZEF1 ZZZ3
Index
IMVigor210-0257bb-ar-0257bbb 0.205851 2.155888 659.745279 20.704149 7.936608 0.000000 82.356025 6.818171 1.341996 8.806979 ... 19.827670 35.762746 3.052251 4.759638 23.932628 0.353733 53.545112 33.434797 63.913951 21.918333
IMVigor210-025b45-ar-025b45c 1.868506 0.000000 368.595425 7.356325 14.221725 0.012419 66.000702 16.410020 74.672523 9.551180 ... 21.562821 7.727498 2.840277 4.399035 10.118828 0.425108 30.963466 87.048508 50.694129 15.833533
IMVigor210-032c64-ar-032c642 0.074416 0.023730 194.673484 1.016972 58.998834 0.012352 105.698176 15.143666 0.028117 2.441625 ... 28.428787 29.953545 3.286946 4.307672 13.970757 1.582359 19.573847 94.128930 47.873491 10.933422
IMVigor210-0571f1-ar-0571f17 2.306056 0.000000 325.709796 18.747406 10.965047 0.018950 76.854569 7.491749 0.043138 7.001308 ... 23.462814 18.647978 5.777748 5.938934 12.687338 1.001439 20.971129 50.101555 78.684380 14.659834
IMVigor210-065890-ar-0658907 0.000000 0.024102 182.904400 23.246839 3.457102 0.000000 66.561993 14.851419 120.742181 25.713897 ... 30.468925 16.782164 4.356220 7.165276 17.453367 0.552250 33.347260 20.544651 41.852786 18.699320

5 rows × 15672 columns

In [7]:
## CANCER_CODE is assumed to be the cancer-type-to-integer mapping used by COMPASS (defined or imported elsewhere).
## The code is prepended as an extra 'cancer_code' column ahead of the expression features.
dfcx = df_label.cancer_type.map(CANCER_CODE).to_frame('cancer_code').join(df_tpm)
df_task = onehot(df_label.response_label)
dfcx.head()
Out[7]:
cancer_code A1BG A1CF A2M A2ML1 A4GALT A4GNT AAAS AACS AADAC ... ZWILCH ZWINT ZXDA ZXDB ZXDC ZYG11A ZYG11B ZYX ZZEF1 ZZZ3
Index
IMVigor210-0257bb-ar-0257bbb 1 0.205851 2.155888 659.745279 20.704149 7.936608 0.000000 82.356025 6.818171 1.341996 ... 19.827670 35.762746 3.052251 4.759638 23.932628 0.353733 53.545112 33.434797 63.913951 21.918333
IMVigor210-025b45-ar-025b45c 1 1.868506 0.000000 368.595425 7.356325 14.221725 0.012419 66.000702 16.410020 74.672523 ... 21.562821 7.727498 2.840277 4.399035 10.118828 0.425108 30.963466 87.048508 50.694129 15.833533
IMVigor210-032c64-ar-032c642 1 0.074416 0.023730 194.673484 1.016972 58.998834 0.012352 105.698176 15.143666 0.028117 ... 28.428787 29.953545 3.286946 4.307672 13.970757 1.582359 19.573847 94.128930 47.873491 10.933422
IMVigor210-0571f1-ar-0571f17 1 2.306056 0.000000 325.709796 18.747406 10.965047 0.018950 76.854569 7.491749 0.043138 ... 23.462814 18.647978 5.777748 5.938934 12.687338 1.001439 20.971129 50.101555 78.684380 14.659834
IMVigor210-065890-ar-0658907 1 0.000000 0.024102 182.904400 23.246839 3.457102 0.000000 66.561993 14.851419 120.742181 ... 30.468925 16.782164 4.356220 7.165276 17.453367 0.552250 33.347260 20.544651 41.852786 18.699320

5 rows × 15673 columns

In [8]:
df_task.head()
Out[8]:
NR R
Index
IMVigor210-0257bb-ar-0257bbb 1.0 0.0
IMVigor210-025b45-ar-025b45c 1.0 0.0
IMVigor210-032c64-ar-032c642 1.0 0.0
IMVigor210-0571f1-ar-0571f17 1.0 0.0
IMVigor210-065890-ar-0658907 0.0 1.0
In [10]:
dfcx_train = dfcx.loc[train_idx]
dfy_train = df_task.loc[train_idx]

dfcx_test = dfcx.loc[test_idx]
dfy_test = df_task.loc[test_idx]

print(len(dfcx_train), len(dfcx_test))
1060 73

Initialize and perform fine-tuning¶

Fine-tuning parameters for the three parametric fine-tuning methods (PFT, LFT, FFT):

In [ ]:
params = {'mode': 'PFT',            ## can change to 'LFT' or 'FFT'
          'seed': 42,
          'lr': 1e-2,
          'device': 'cuda',
          'weight_decay': 1e-3,
          'batch_size': 32,
          'max_epochs': 20,
          'with_wandb': False,
          'save_best_model': False,
          'verbose': False}
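Since only the 'mode' key differs between the parametric strategies, the three can be compared on the same split by sweeping that single key. The loop below is a sketch, not part of the original run; it reloads the pre-trained backbone for each mode so that weights updated by one configuration do not leak into the next, and it only reuses calls shown elsewhere in this notebook.

# Sketch: compare the parametric fine-tuning modes (FFT, PFT, LFT) on the same cohort split.
mode_predictions = {}
for mode in ['FFT', 'PFT', 'LFT']:
    backbone = loadcompass('https://www.immuno-compass.com/download/model/pretrainer.pt')  # fresh weights per mode
    ft = FineTuner(backbone, **dict(params, mode=mode))
    ft = ft.tune(dfcx_train=dfcx_train, dfy_train=dfy_train)
    _, mode_predictions[mode] = ft.predict(dfcx_test, batch_size=16)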
In [11]:
finetuner = FineTuner(pretrainer, **params)
finetuner = finetuner.tune(dfcx_train = dfcx_train, dfy_train = dfy_train)
100%|##########| 20/20 [10:32<00:00, 31.63s/it]
In [12]:
finetuner.save('./tmpignore/finetuner_without_gide.pt')
Saving the model to ./tmpignore/finetuner_without_gide.pt
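To reuse the saved model later without re-tuning, the checkpoint can be loaded back. Whether loadcompass accepts local fine-tuned checkpoints in addition to the pre-trained URL is an assumption here, so treat this as a sketch.

# Sketch: restore the saved fine-tuner for later inference (assumes loadcompass handles local checkpoint paths).
finetuner = loadcompass('./tmpignore/finetuner_without_gide.pt')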

Evaluate the model performance¶

In [13]:
dfe, df_pred = finetuner.predict(dfcx_test, batch_size = 16)
100%|##########| 5/5 [00:00<00:00,  6.67it/s]
In [14]:
dfp = dfy_test.join(df_pred)
## column 1 of df_pred is taken as the predicted probability of response (R); idxmax over [0, 1] gives the hard label
y_true, y_prob, y_pred = dfp['R'], dfp[1], dfp[[0, 1]].idxmax(axis=1)
## plot_performance is a plotting helper assumed to be defined or imported elsewhere (not shown in this notebook)
fig = plot_performance(y_true, y_prob, y_pred)
[Performance plot for the held-out Gide cohort]
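If plot_performance is not available in your environment, the same quantities can be checked numerically; the snippet below is a sketch using scikit-learn and is not part of the original notebook.

# Sketch: numeric performance metrics for the held-out cohort (requires scikit-learn).
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, matthews_corrcoef

print('ROC-AUC :', roc_auc_score(y_true, y_prob))
print('PR-AUC  :', average_precision_score(y_true, y_prob))
print('F1      :', f1_score(y_true, y_pred))
print('MCC     :', matthews_corrcoef(y_true, y_pred))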
In [15]:
pd.DataFrame(finetuner.performance,
             columns = ['epoch', 'f1', 'mcc', 'prc', 'roc', 'acc']).set_index('epoch').plot()
Out[15]:
<Axes: xlabel='epoch'>
[Line plot of the per-epoch fine-tuning metrics (f1, mcc, prc, roc, acc)]
In [16]:
finetuner.best_epoch
Out[16]:
17