import torch
from torch import Tensor


def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor:
    """Reduce per-token hidden states to a single embedding per sequence.

    Positions where ``attention_mask`` is zero are zero-filled before any
    reduction, so they never contribute to the result.

    NOTE(review): the indexing below (``sum(dim=1)``, ``[:, 0]``,
    ``attention_mask[:, -1]``) suggests ``last_hidden_states`` is
    ``(batch, seq, hidden)`` and ``attention_mask`` is ``(batch, seq)`` --
    confirm against callers.

    Args:
        last_hidden_states: Token-level hidden states.
        attention_mask: Mask whose non-zero entries mark tokens to keep.
            Integer, bool and floating-point masks are accepted.
        pool_type: One of:

            * ``"avg"`` -- sum of the kept states divided by the per-row mask
              sum. Rows whose mask sums to zero yield an all-zero embedding
              rather than NaN.
            * ``"weighted_avg"`` -- plain sum of the masked states over dim 1;
              no weighting or normalisation happens here (presumably the
              caller pre-weights the states -- TODO confirm).
            * ``"cls"`` -- the state at position 0.
            * ``"last"`` -- the state at the last kept position of each row.

    Returns:
        The pooled embeddings, with dim 1 of the input reduced away.

    Raises:
        ValueError: If ``pool_type`` is not one of the supported values.
    """
    # Zero out masked positions up front so every strategy below sees them as 0.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        token_counts = attention_mask.sum(dim=1)
        # A row with no kept tokens would produce 0 / 0 = NaN; its numerator is
        # already all zeros, so dividing by 1 yields a clean zero embedding.
        token_counts = token_counts.masked_fill(token_counts == 0, 1)
        return last_hidden.sum(dim=1) / token_counts[..., None]

    if pool_type == "weighted_avg":
        return last_hidden.sum(dim=1)

    if pool_type == "cls":
        return last_hidden[:, 0]

    if pool_type == "last":
        # If every row keeps its final position, the batch is left-padded (or
        # unpadded) and the last kept token is simply the final position.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden[:, -1]

        # Right-padded: the last kept token sits at (number of kept tokens - 1).
        # Cast to long so a floating-point mask still produces valid indices.
        # A row with no kept tokens gives index -1, i.e. the final position,
        # which the mask has already zeroed.
        last_positions = attention_mask.sum(dim=1).long() - 1
        batch_indices = torch.arange(last_hidden.shape[0], device=last_hidden.device)
        return last_hidden[batch_indices, last_positions]

    raise ValueError(f"pool_type {pool_type} not supported")