lhallee committed
Commit 5790905 · verified · 1 parent: 2b3b1f4

Upload modeling_fastesm.py with huggingface_hub

Files changed (1): modeling_fastesm.py (+77, −31)
modeling_fastesm.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import os
+import networkx as nx
 from torch.nn import functional as F
 from torch.utils.data import Dataset as TorchDataset
 from torch.utils.data import DataLoader as DataLoader
@@ -500,79 +501,124 @@ class Pooler:
         self.pooling_options = {
             'mean': self.mean_pooling,
             'max': self.max_pooling,
-            'min': self.min_pooling,
             'norm': self.norm_pooling,
-            'prod': self.prod_pooling,
             'median': self.median_pooling,
             'std': self.std_pooling,
             'var': self.var_pooling,
             'cls': self.cls_pooling,
+            'parti': self._pool_parti,
         }
 
-    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
+        maxed_attentions = torch.max(attentions, dim=1)[0]
+        return maxed_attentions
+
+    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
+        # Run PageRank on the attention matrix converted to a graph.
+        # Raises exceptions if the graph doesn't match the token sequence or has no edges.
+        # Returns the PageRank scores for each token node.
+        G = self._convert_to_graph(attention_matrix)
+        if G.number_of_nodes() != attention_matrix.shape[0]:
+            raise Exception(
+                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
+        if G.number_of_edges() == 0:
+            raise Exception(f"You don't seem to have any attention edges left in the graph.")
+
+        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
+
+    def _convert_to_graph(self, matrix):
+        # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
+        # Each element in the matrix represents a directed edge with a weight.
+        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
+        return G
+
+    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
+        # Remove keys where attention_mask is 0
+        if attention_mask is not None:
+            for k in list(dict_importance.keys()):
+                if attention_mask[k] == 0:
+                    del dict_importance[k]
+
+        #dict_importance[0] # remove cls
+        #dict_importance[-1] # remove eos
+        total = sum(dict_importance.values())
+        return np.array([v / total for _, v in dict_importance.items()])
+
+    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
+        # emb is (b, L, d), maxed_attentions is (b, L, L)
+        emb_pooled = []
+        for e, a, mask in zip(emb, maxed_attentions, attention_mask):
+            dict_importance = self._page_rank(a)
+            importance_weights = self._calculate_importance_weights(dict_importance, mask)
+            num_tokens = int(mask.sum().item())
+            emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
+        pooled = torch.tensor(np.array(emb_pooled))
+        return pooled
+
+    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.mean(dim=1)
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
 
-    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.max(dim=1).values
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).max(dim=1).values
-
-    def min_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
-        if attention_mask is None:
-            return emb.min(dim=1).values
-        else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return (emb * attention_mask).min(dim=1).values
 
-    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.norm(dim=1, p=2)
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).norm(dim=1, p=2)
 
-    def prod_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
-        length = emb.shape[1]
-        if attention_mask is None:
-            return emb.prod(dim=1) / length
-        else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return ((emb * attention_mask).prod(dim=1) / attention_mask.sum(dim=1)) / length
-
-    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.median(dim=1).values
         else:
             attention_mask = attention_mask.unsqueeze(-1)
             return (emb * attention_mask).median(dim=1).values
 
-    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.std(dim=1)
         else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return (emb * attention_mask).std(dim=1)
+            # Compute variance correctly over non-masked positions, then take sqrt
+            var = self.var_pooling(emb, attention_mask, **kwargs)
+            return torch.sqrt(var)
 
-    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         if attention_mask is None:
             return emb.var(dim=1)
         else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return (emb * attention_mask).var(dim=1)
-
-    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+            # Correctly compute variance over only non-masked positions
+            attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
+            # Compute mean over non-masked positions
+            mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
+            mean = mean.unsqueeze(1) # (b, 1, d)
+            # Compute squared differences from mean, only over non-masked positions
+            squared_diff = (emb - mean) ** 2 # (b, L, d)
+            # Sum squared differences over non-masked positions and divide by count
+            var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
+            return var
+
+    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
         return emb[:, 0, :]
 
-    def __call__(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # [mean, max]
+    def __call__(
+        self,
+        emb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        attentions: Optional[torch.Tensor] = None
+    ): # [mean, max]
         final_emb = []
         for pooling_type in self.pooling_types:
-            final_emb.append(self.pooling_options[pooling_type](emb, attention_mask)) # (b, d)
+            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
         return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
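The new 'parti' option pools token embeddings by treating the attention map (max-reduced across the stacked heads/layers dimension) as a directed graph, scoring tokens with PageRank, and averaging the token embeddings with those scores as weights. Below is a minimal, self-contained sketch of that idea using networkx and numpy; `parti_pool` and the random inputs are illustrative stand-ins, not the class from modeling_fastesm.py, and padding is handled by slicing to the unmasked prefix rather than by the diff's `_calculate_importance_weights` helper.

import networkx as nx
import numpy as np
import torch

def parti_pool(emb: torch.Tensor, attentions: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # emb: (b, L, d), attentions: (b, n_heads_or_layers, L, L), attention_mask: (b, L)
    maxed = attentions.max(dim=1).values.cpu().numpy()  # element-wise max over heads/layers -> (b, L, L)
    pooled = []
    for e, a, mask in zip(emb, maxed, attention_mask):
        n = int(mask.sum().item())                                   # number of real (unpadded) tokens
        g = nx.from_numpy_array(a[:n, :n], create_using=nx.DiGraph)  # token-to-token attention graph
        scores = nx.pagerank(g, alpha=0.85, weight='weight')         # {token_index: importance}
        w = np.array([scores[i] for i in range(n)])
        w = w / w.sum()                                              # normalize importances to weights
        pooled.append(np.average(e[:n].cpu().numpy(), weights=w, axis=0))
    return torch.tensor(np.array(pooled), dtype=emb.dtype)           # (b, d)

b, h, L, d = 2, 4, 6, 8
emb = torch.randn(b, L, d)
attn = torch.rand(b, h, L, L)
mask = torch.ones(b, L)
mask[1, 4:] = 0  # pad the second sequence to 4 real tokens
print(parti_pool(emb, attn, mask).shape)  # torch.Size([2, 8])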
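The rewritten `var_pooling` drops the old `(emb * attention_mask).var(dim=1)`, which wrongly treats masked positions as zero-valued tokens, and instead computes the mean and squared deviations over unmasked positions only, dividing by the token count (i.e., a population variance). A quick self-contained check of that formula against `torch.var(unbiased=False)` on the unpadded prefix; the helper below mirrors the diff's logic but is illustrative, not imported from the repo.

import torch

def masked_var(emb: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1)                        # (b, L, 1)
    mean = (emb * mask).sum(dim=1) / mask.sum(dim=1)           # mean over real tokens, (b, d)
    sq_diff = (emb - mean.unsqueeze(1)) ** 2                   # (b, L, d)
    return (sq_diff * mask).sum(dim=1) / mask.sum(dim=1)       # population variance, (b, d)

emb = torch.randn(1, 5, 3)
mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float)
ref = emb[0, :3].var(dim=0, unbiased=False)                    # variance over the 3 real tokens
print(torch.allclose(masked_var(emb, mask)[0], ref, atol=1e-6))  # True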
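The reworked `__call__` now forwards `emb`, `attention_mask`, and `attentions` to every selected pooling function by keyword (which is why the simple poolers gained `**kwargs`) and concatenates the per-type results along the feature dimension. The sketch below illustrates that dispatch pattern with a simplified, hypothetical `MiniPooler` implementing only three of the options; the real Pooler's constructor is not shown in this diff.

from typing import Optional
import torch

class MiniPooler:
    def __init__(self, pooling_types):
        self.pooling_types = pooling_types
        self.pooling_options = {'mean': self.mean_pooling, 'max': self.max_pooling, 'cls': self.cls_pooling}

    def mean_pooling(self, emb, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        if attention_mask is None:
            return emb.mean(dim=1)
        mask = attention_mask.unsqueeze(-1)
        return (emb * mask).sum(dim=1) / mask.sum(dim=1)

    def max_pooling(self, emb, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        if attention_mask is None:
            return emb.max(dim=1).values
        return (emb * attention_mask.unsqueeze(-1)).max(dim=1).values

    def cls_pooling(self, emb, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        return emb[:, 0, :]

    def __call__(self, emb, attention_mask=None, attentions=None):
        # each pooling op maps (b, L, d) -> (b, d); results are concatenated on the feature dim
        pooled = [self.pooling_options[p](emb=emb, attention_mask=attention_mask, attentions=attentions)
                  for p in self.pooling_types]
        return torch.cat(pooled, dim=-1)  # (b, n_pooling_types * d)

emb = torch.randn(2, 10, 16)
mask = torch.ones(2, 10)
print(MiniPooler(['mean', 'max', 'cls'])(emb, mask).shape)  # torch.Size([2, 48])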