Mengyao00's picture
initial upload
c6e7d88 verified
from torch import Tensor
import torch
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Pool per-token hidden states into one vector per sequence.

    Positions where ``attention_mask`` is zero/False are zeroed out before any
    pooling is applied.

    Args:
        last_hidden_states: Token representations, expected shape
            ``(batch, seq_len, hidden)``.
        attention_mask: Mask of shape ``(batch, seq_len)``; a non-zero entry
            marks a real (non-padding) token.
        pool_type: One of
            ``"avg"`` -- mean over the non-padding tokens of each row;
            ``"weighted_avg"`` -- plain sum over the non-padding tokens
            (see the review note in the body);
            ``"cls"`` -- the state at position 0 (zeroed if it is padding);
            ``"last"`` -- the state of the last non-padding token.

    Returns:
        Tensor of shape ``(batch, hidden)``.

    Raises:
        ValueError: If ``pool_type`` is not one of the values listed above.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        # Clamp so a row with no valid tokens yields zeros instead of 0/0 = NaN.
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        emb = last_hidden.sum(dim=1) / token_counts[..., None]
    elif pool_type == "weighted_avg":
        # NOTE(review): despite the name, this is an unweighted, unnormalized
        # sum of the non-padding tokens. Kept unchanged because embeddings
        # produced with this option may depend on it; confirm intent before
        # turning it into a true (e.g. position-weighted) mean.
        emb = last_hidden.sum(dim=1)
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    elif pool_type == "last":
        # If every row attends to its final position, the batch is treated as
        # left-padded (or unpadded): the last column holds each row's last token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = last_hidden[:, -1]
        else:
            # Otherwise this assumes right padding with contiguous real tokens,
            # so the last real token is at index (count - 1). Cast to long so
            # float masks still give valid integer indices.
            sequence_lengths = attention_mask.sum(dim=1).long() - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")

    return emb