CLS pooling with ESM2
Hi everyone,
first of all thanks a lot for making these amazingly fast versions of ESM models!
Not sure this is the right place to do this, but I just wanted to raise a point about the use of CLS pooling in ESM models, which all use RoPE attention.
I noticed that if I use CLS pooling in a finetuning experiment (protein-level classification with LoRA adapters) and then analyse the attributions of individual tokens to the final decision using the Layer Integrated Gradients (LIG) method from the captum library, I get a strong bias towards the N-terminus of the protein (essentially an exponential decay, with the initial methionine carrying most of the attribution).
That does not make much sense biologically, and I suspect it is an artifact of combining CLS pooling (the default method in your implementation) with RoPE attention. Since the CLS token is artificially placed at the beginning of the protein, and RoPE emphasizes the transfer of information over short distances, the model ends up paying too much attention to the N-terminus, which limits both performance and interpretability.
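For context, my attribution setup looks roughly like this (a sketch, not the exact script: `model` is the LoRA-finetuned classifier, `tokenizer` its tokenizer, the embedding-layer handle and the target class index are just placeholders):

    import torch
    from captum.attr import LayerIntegratedGradients

    def forward_func(input_ids, attention_mask):
        # return the classification logits of the finetuned model
        return model(input_ids=input_ids, attention_mask=attention_mask).logits

    # attribute with respect to the input embedding layer
    lig = LayerIntegratedGradients(forward_func, model.get_input_embeddings())

    enc = tokenizer('MALWMRLLPLLALLALWGPDPAAA', return_tensors='pt')
    baselines = torch.full_like(enc['input_ids'], tokenizer.pad_token_id)

    attributions = lig.attribute(
        inputs=enc['input_ids'],
        baselines=baselines,
        additional_forward_args=(enc['attention_mask'],),
        target=1,  # class of interest
    )
    token_scores = attributions.sum(dim=-1).squeeze(0)  # one attribution score per token

The per-token scores are what show the near-exponential decay from the N-terminus.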
I suggest offering an option to switch to a properly implemented mean pooling (ideally excluding all special tokens), or to a fancier global attention pooling in the classification head only; either could improve the model. I'd be happy to contribute if needed!
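By "properly implemented mean pooling" I mean something along these lines (just a sketch with placeholder names; `special_tokens_mask` is what the tokenizer returns with return_special_tokens_mask=True):

    import torch

    def mean_pool(hidden_states, attention_mask, special_tokens_mask):
        # hidden_states: (batch, seq_len, hidden); masks: (batch, seq_len)
        # keep only real residues: attended positions that are not special tokens
        keep = (attention_mask * (1 - special_tokens_mask)).unsqueeze(-1).float()
        summed = (hidden_states * keep).sum(dim=1)
        counts = keep.sum(dim=1).clamp(min=1.0)
        return summed / counts  # (batch, hidden)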
Hi @nestorgonzalovich ,
Thanks for the comments!
We completely agree that CLS pooling is not ideal; we use it in the examples because it is a popular method. We've evaluated many composite pooling methods internally and have found that mean + variance pooling is extremely effective. We also recommend keeping the special tokens when mean pooling, because the CLS token acts as a natural attention sink (there is some literature on this topic). Keep in mind that protein embeddings are not biological objects, they are arrays of numbers, so the most useful way of summarizing them does not need to coincide with biological intuition.
If you'd like to use mean + variance pooling to embed protein datasets, you can change the options in our default settings:
import torch  # embed_dtype below is a torch dtype

# assumes `model` is an already-loaded ESM2 model with the embed_dataset helper
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ...  # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2,  # adjust for your GPU memory
    max_len=512,  # adjust for your needs
    full_embeddings=False,  # if True, no pooling is performed
    embed_dtype=torch.float32,  # cast to the dtype you want
    pooling_types=['mean', 'var'],  # mean + var here, so each embedding will be hidden_size * 2 long
    num_workers=0,  # if you have many CPU cores, we find that num_workers=4 is fast for large datasets
    sql=False,  # if True, embeddings will be stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True,  # if True, embeddings will be saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps each sequence to its embedding: torch tensors for .pth, numpy arrays for SQL
You can add as many pooling types as you'd like there. We also include max, min, std, etc.
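For intuition, composite pooling just concatenates the per-sequence summaries along the hidden dimension. A simplified sketch (not our exact implementation) of mean + var (and max) pooling over the token hidden states:

    import torch

    def composite_pool(hidden_states, attention_mask, pooling_types=('mean', 'var')):
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        mask = attention_mask.unsqueeze(-1).float()
        n = mask.sum(dim=1).clamp(min=1.0)
        mean = (hidden_states * mask).sum(dim=1) / n
        pooled = []
        for p in pooling_types:
            if p == 'mean':
                pooled.append(mean)
            elif p == 'var':
                pooled.append((((hidden_states - mean.unsqueeze(1)) ** 2) * mask).sum(dim=1) / n)
            elif p == 'max':
                pooled.append(hidden_states.masked_fill(mask == 0, float('-inf')).max(dim=1).values)
        return torch.cat(pooled, dim=-1)  # mean + var -> (batch, hidden_size * 2)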
If you'd like to train models with pool parti pooling (a weighted mean based on global attention), feel free to leverage our Protify project.
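For context, a generic learned attention-pooling head (a weighted mean over tokens) can be as simple as the following; this is a sketch in that spirit, not the Protify implementation:

    import torch
    import torch.nn as nn

    class AttentionPool(nn.Module):
        """Learned weighted mean over token embeddings."""
        def __init__(self, hidden_size):
            super().__init__()
            self.scorer = nn.Linear(hidden_size, 1)

        def forward(self, hidden_states, attention_mask):
            # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
            scores = self.scorer(hidden_states).squeeze(-1)
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
            weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
            return (hidden_states * weights).sum(dim=1)  # (batch, hidden)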
Hope this is helpful! Please let us know if you have any questions.
Best,
Logan