compress normalization func
Browse files
README.md
CHANGED
|
@@ -25240,8 +25240,7 @@ outputs = session.run(None, inputs)[0]
|
|
| 25240 |
|
| 25241 |
# Apply mean pooling to 'outputs' to get a single representation of each text
|
| 25242 |
embeddings = mean_pooling(outputs, input_text["attention_mask"])
|
| 25243 |
-
|
| 25244 |
-
embeddings = embeddings / norm
|
| 25245 |
```
|
| 25246 |
|
| 25247 |
</p>
|
|
|
|
| 25240 |
|
| 25241 |
# Apply mean pooling to 'outputs' to get a single representation of each text
|
| 25242 |
embeddings = mean_pooling(outputs, input_text["attention_mask"])
|
| 25243 |
+
embeddings = embeddings / np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
|
|
|
|
| 25244 |
```
|
| 25245 |
|
| 25246 |
</p>
|