CLS pooling with ESM2

#1 opened by nestorgonzalovich

Hi everyone,

First of all, thanks a lot for making these amazingly fast versions of the ESM models!
I'm not sure this is the right place for this, but I wanted to raise a point about the use of CLS pooling in ESM models, which all use RoPE attention.

I noticed that if I use CLS pooling for a fine-tuning experiment (protein-level classification, using LoRA adapters) and then analyse the individual contributions of tokens to the final decision using the LIG (Layer Integrated Gradients) method from the Captum library, I get a bias towards strong attributions at the N-terminus of the protein (essentially an exponential decay, with the initial methionine containing most of the information).
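
For context, my attribution setup looks roughly like this (a minimal sketch; the forward function, the target class index, and the exact path to the embedding layer are assumptions and may differ from the actual code):

import torch
from captum.attr import LayerIntegratedGradients

# model: the fine-tuned ESM2 classifier, tokenizer: its tokenizer (both assumed to exist)
def forward_func(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask).logits

enc = tokenizer('MALWMRLLPLLALLALWGPDPAAA', return_tensors='pt')
baseline_ids = torch.full_like(enc['input_ids'], tokenizer.pad_token_id)

# attribute the prediction for class 1 to the input embedding layer, token by token
lig = LayerIntegratedGradients(forward_func, model.esm.embeddings)  # module path is an assumption
attributions = lig.attribute(
    enc['input_ids'],
    baselines=baseline_ids,
    additional_forward_args=(enc['attention_mask'],),
    target=1,
)
per_token_scores = attributions.sum(dim=-1).squeeze(0)  # one score per token, N-terminus first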

That does not make much sense biologically, and I suspected this could be an artifact of combining CLS pooling (the default method in your implementation) with RoPE attention. Since the CLS token is artificially placed at the beginning of the protein, and RoPE attention emphasizes the transfer of information over short-range distances, I think this makes the model pay too much attention to the N-terminus, limiting both potential performance and interpretability.

I suggest offering an option to switch to a properly implemented mean pooling (ideally excluding all the special tokens), or to a fancier global attention pooling in the classification head only; I think either could improve the model. I'd be happy to contribute if needed!
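
For reference, the kind of mean pooling I have in mind simply averages the residue embeddings while masking out padding and special tokens (a rough sketch, assuming the tokenizer's special_tokens_mask is available):

import torch

def mean_pool_excluding_special(hidden_states, attention_mask, special_tokens_mask):
    # hidden_states: (batch, seq_len, hidden); both masks: (batch, seq_len)
    keep = (attention_mask * (1 - special_tokens_mask)).unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * keep).sum(dim=1)
    counts = keep.sum(dim=1).clamp(min=1)
    return summed / counts  # (batch, hidden)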

Synthyra org

Hi @nestorgonzalovich ,

Thanks for the comments!

We completely agree that CLS pooling is not ideal; we use it in the examples because it is a popular method. We've evaluated many composite pooling methods internally and have found that mean + variance pooling is extremely effective. We also recommend keeping the special tokens in mean pooling, because the CLS token acts as a natural attention sink (there is some literature on this topic). Keep in mind that protein embeddings are not biological objects, they are arrays of numbers, so summarizing them in the way that is most useful does not need to coincide with any biological intuition.
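
Conceptually, mean + variance pooling concatenates the per-dimension mean and variance over the sequence, so each protein embedding ends up hidden_size * 2 long (a minimal sketch of the idea, not the exact internal implementation):

import torch

def mean_var_pool(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    counts = mask.sum(dim=1).clamp(min=1)
    mean = (hidden_states * mask).sum(dim=1) / counts
    var = ((hidden_states - mean.unsqueeze(1)) ** 2 * mask).sum(dim=1) / counts
    return torch.cat([mean, var], dim=-1)  # (batch, hidden * 2)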

If you'd like to use mean + variance pooling to embed protein datasets, you can change the options in our default settings:

import torch

# model is an ESM model loaded earlier (not shown here)
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2, # adjust for your GPU memory
    max_len=512, # adjust for your needs
    full_embeddings=False, # if True, no pooling is performed
    embed_dtype=torch.float32, # cast to what dtype you want
    pooling_types=['mean', 'var'], # MEAN + VAR here, so each embedding will be hidden_size * 2 long
    num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
    sql=False, # if True, embeddings will be stored in SQLite database
    sql_db_path='embeddings.db',
    save=True, # if True, embeddings will be saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
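
Once saved, you can reload and index the dictionary directly, for example (assuming the save_path above):

import torch

embedding_dict = torch.load('embeddings.pth')
vec = embedding_dict['MALWMRLLPLLALLALWGPDPAAA']
print(vec.shape)  # hidden_size * 2 with pooling_types=['mean', 'var']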

You can add as many pooling types as you'd like there. We also include max, min, std, etc.

If you'd like to train models with pool parti pooling (a weighted mean based on global attention), feel free to leverage our Protify project.
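
If you just want a drop-in learned pooling head for your own fine-tuning runs, a generic attention-weighted mean looks like this (a sketch of the general idea, not the Protify implementation):

import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    """Weighted mean over tokens using learned attention scores."""
    def __init__(self, hidden_size):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        scores = self.score(hidden_states).squeeze(-1)
        scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
        return (weights * hidden_states).sum(dim=1)  # (batch, hidden)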

Hope this is helpful! Please let us know if you have any questions.
Best,
Logan
