test: add unit tests for utils

- src/utils.py             +20 -13
- tests/src/test_utils.py  +151 -20
- tests/test_utils.py      +1 -12
src/utils.py
CHANGED

@@ -98,14 +98,7 @@ def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) ->
     return cols, types
 
 
-def select_columns(
-    df: pd.DataFrame,
-    domain_query: list,
-    language_query: list,
-    task: TaskType = TaskType.qa,
-    reset_ranking: bool = True,
-    version_slug: str = None,
-) -> pd.DataFrame:
+def get_selected_cols(task, version_slug, domains, languages):
     cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:

@@ -115,21 +108,35 @@ def select_columns(
             eval_col = LongDocBenchmarks[version_slug].value[c].value
         else:
             raise NotImplementedError
-        if eval_col.domain not in domain_query:
+        if eval_col.domain not in domains:
             continue
-        if eval_col.lang not in language_query:
+        if eval_col.lang not in languages:
             continue
         selected_cols.append(c)
     # We use COLS to maintain sorting
+    return selected_cols
+
+
+def select_columns(
+    df: pd.DataFrame,
+    domains: list,
+    languages: list,
+    task: TaskType = TaskType.qa,
+    reset_ranking: bool = True,
+    version_slug: str = None,
+) -> pd.DataFrame:
+    selected_cols = get_selected_cols(
+        task, version_slug, domains, languages)
     fixed_cols, _ = get_fixed_col_names_and_types()
     filtered_df = df[fixed_cols + selected_cols]
     filtered_df.replace({"": pd.NA}, inplace=True)
     if reset_ranking:
-        filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
-        filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
+        filtered_df[COL_NAME_AVG] = \
+            filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
+        filtered_df.sort_values(
+            by=[COL_NAME_AVG], ascending=False, inplace=True)
     filtered_df.reset_index(inplace=True, drop=True)
     filtered_df = reset_rank(filtered_df)
-
     return filtered_df
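For orientation, here is a self-contained sketch of the filtering idea behind the extracted get_selected_cols helper. The real code resolves each column through the benchmark enums (e.g. LongDocBenchmarks in the context above); the BENCHMARK_META dict below is a hypothetical stand-in mapping column name to (domain, language), so the sketch runs on its own.

# Runnable sketch of get_selected_cols' filtering logic (assumed
# simplification: a plain dict replaces the benchmark enums).
BENCHMARK_META = {
    "wiki_en": ("wiki", "en"),
    "wiki_zh": ("wiki", "zh"),
    "news_en": ("news", "en"),
    "news_zh": ("news", "zh"),
}


def get_selected_cols_sketch(domains: list, languages: list) -> list:
    selected_cols = []
    for col, (domain, lang) in BENCHMARK_META.items():
        if domain not in domains:   # same guard as in the diff above
            continue
        if lang not in languages:
            continue
        selected_cols.append(col)
    return selected_cols


# Mirrors the first test_get_selected_cols case below:
print(get_selected_cols_sketch(["wiki", "news"], ["zh"]))  # ['wiki_zh', 'news_zh']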
tests/src/test_utils.py
CHANGED
@@ -1,26 +1,157 @@
- [20 lines removed: previous version of this module; its content is not recoverable from this diff view]
+import pytest
+import pandas as pd
+
+from src.utils import remove_html, calculate_mean, filter_models, filter_queries, get_default_cols, select_columns, get_selected_cols
+from src.models import model_hyperlink, TaskType
+from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
+
+
+NUM_QA_BENCHMARKS_24_05 = 53
+NUM_DOC_BENCHMARKS_24_05 = 11
+NUM_QA_BENCHMARKS_24_04 = 13
+NUM_DOC_BENCHMARKS_24_04 = 15
+
+
+@pytest.fixture
+def toy_df():
+    return pd.DataFrame(
+        {
+            "Retrieval Method": [
+                "bge-m3",
+                "bge-m3",
+                "jina-embeddings-v2-base",
+                "jina-embeddings-v2-base"
+            ],
+            "Reranking Model": [
+                "bge-reranker-v2-m3",
+                "NoReranker",
+                "bge-reranker-v2-m3",
+                "NoReranker"
+            ],
+            "Rank 🏆": [1, 2, 3, 4],
+            "Revision": ["", "", "", ""],
+            "Submission Date": ["", "", "", ""],
+            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
+            "wiki_en": [0.8, 0.7, 0.2, 0.1],
+            "wiki_zh": [0.4, 0.1, 0.4, 0.3],
+            "news_en": [0.8, 0.7, 0.2, 0.1],
+            "news_zh": [0.4, 0.1, 0.4, 0.3],
+        }
+    )
+
+
+def test_remove_html():
+    model_name = "jina-embeddings-v3"
+    html_str = model_hyperlink(
+        "https://jina.ai", model_name)
+    output_str = remove_html(html_str)
+    assert output_str == model_name
 
 
+def test_calculate_mean():
+    valid_row = [1, 3]
+    invalid_row = [2, pd.NA]
+    df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
+    result = list(df.apply(calculate_mean, axis=1))
+    assert result[0] == sum(valid_row) / 2
+    assert result[1] == -1
 
 
+@pytest.mark.parametrize("models, expected", [
+    (["model1", "model3"], 2),
+    (["model1", "model_missing"], 1),
+    (["model1", "model2", "model3"], 3),
+    (["model1", ], 1),
+    ([], 3),
+])
+def test_filter_models(models, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RERANKING_MODEL: ["model1", "model2", "model3", ],
+            "col2": [1, 2, 3],
+        }
+    )
+    output_df = filter_models(df, models)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize("query, expected", [
+    ("model1;model3", 2),
+    ("model1;model4", 1),
+    ("model1;model2;model3", 3),
+    ("model1", 1),
+    ("", 3),
+])
+def test_filter_queries(query, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RETRIEVAL_MODEL: ["model1", "model2", "model3", ],
+            COL_NAME_RERANKING_MODEL: ["model4", "model5", "model6", ],
+        }
+    )
+    output_df = filter_queries(query, df)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, slug, expected",
+    [
+        (TaskType.qa, "air_bench_2404", NUM_QA_BENCHMARKS_24_04),
+        (TaskType.long_doc, "air_bench_2404", NUM_DOC_BENCHMARKS_24_04),
+        (TaskType.qa, "air_bench_2405", NUM_QA_BENCHMARKS_24_05),
+        (TaskType.long_doc, "air_bench_2405", NUM_DOC_BENCHMARKS_24_05),
+    ]
+)
+def test_get_default_cols(task_type, slug, expected):
+    attr_cols = ['Rank 🏆', 'Retrieval Method', 'Reranking Model', 'Revision', 'Submission Date', 'Average ⬆️']
+    cols, types = get_default_cols(task_type, slug)
+    benchmark_cols = list(frozenset(cols).difference(frozenset(attr_cols)))
+    assert len(benchmark_cols) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, domains, languages, expected",
+    [
+        (TaskType.qa, ["wiki", "news"], ["zh",], ["wiki_zh", "news_zh"]),
+        (TaskType.qa, ["law",], ["zh", "en"], ["law_en"]),
+        (
+            TaskType.long_doc,
+            ["healthcare"],
+            ["zh", "en"],
+            [
+                'healthcare_en_pubmed_100k_200k_1',
+                'healthcare_en_pubmed_100k_200k_2',
+                'healthcare_en_pubmed_100k_200k_3',
+                'healthcare_en_pubmed_40k_50k_5_merged',
+                'healthcare_en_pubmed_30k_40k_10_merged'
+            ]
+        )
+    ]
+)
+def test_get_selected_cols(task_type, domains, languages, expected):
+    slug = "air_bench_2404"
+    cols = get_selected_cols(task_type, slug, domains, languages)
+    assert sorted(cols) == sorted(expected)
 
 
+def test_select_columns(toy_df):
+    expected = [
+        'Rank 🏆',
+        'Retrieval Method',
+        'Reranking Model',
+        'Revision',
+        'Submission Date',
+        'Average ⬆️',
+        'news_zh']
+    df_result = select_columns(
+        toy_df,
+        [
+            "news",
+        ],
+        [
+            "zh",
+        ],
+        version_slug="air_bench_2404",
+    )
+    assert len(df_result.columns) == len(expected)
+    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
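test_calculate_mean pins down calculate_mean's contract: a complete row yields its arithmetic mean, and a row containing pd.NA yields the sentinel -1. Below is a minimal implementation consistent with that contract; it is a sketch, not necessarily the actual code in src/utils.py.

import pandas as pd


def calculate_mean_sketch(row: pd.Series):
    # Contract from test_calculate_mean: any missing cell maps the
    # row to the sentinel -1; complete rows map to their mean.
    if row.isna().any():
        return -1
    return row.mean()


df = pd.DataFrame([[1, 3], [2, pd.NA]], columns=["a", "b"])
print(list(df.apply(calculate_mean_sketch, axis=1)))  # [2.0, -1]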
tests/test_utils.py
CHANGED
@@ -75,18 +75,7 @@ def test_filter_queries(toy_df):
     assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
 
 
-def test_select_columns(toy_df):
-    df_result = select_columns(
-        toy_df,
-        [
-            "news",
-        ],
-        [
-            "zh",
-        ],
-    )
-    assert len(df_result.columns) == 4
-    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
+
 
 
 def test_update_table_long_doc(toy_df_long_doc):
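Besides moving test_select_columns under tests/src/, the new version drops the hard-coded column count (4) in favour of len(expected) and keeps the invariant that with a single selected benchmark column the recomputed average equals that column. A self-contained illustration of that invariant in plain pandas (not the project's code; values mirror the toy_df fixture):

import pandas as pd

# With one selected column, the row-wise mean over the selection is
# the column itself, so "Average ⬆️" must equal news_zh after
# select_columns recomputes and rounds it.
toy = pd.DataFrame({"news_zh": [0.4, 0.1, 0.4, 0.3]})
avg = toy[["news_zh"]].mean(axis=1).round(2)
assert avg.equals(toy["news_zh"])
print("single-column average equals the column")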