Spaces:
Running
Running
important bug fix for image selection
Browse files- Archive/optimization.py +37 -0
- Archive/optimization2.py +40 -0
- Archive/test_form.py +39 -0
- Home.py +2 -0
- pages/Gallery.py +121 -21
- pages/__pycache__/Gallery.cpython-39.pyc +0 -0
Archive/optimization.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.optimize import minimize, differential_evolution
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Define the function y = x_1*w_1 + x_2*w_2 + x_3*w_3
|
| 6 |
+
def objective_function(w_indices):
|
| 7 |
+
x_1 = x_1_values[int(w_indices[0])]
|
| 8 |
+
x_2 = x_2_values[int(w_indices[1])]
|
| 9 |
+
x_3 = x_3_values[int(w_indices[2])]
|
| 10 |
+
return - (x_1 * w_indices[3] + x_2 * w_indices[4] + x_3 * w_indices[5]) # Use w_indices to get w_1, w_2, w_3
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
if __name__ == '__main__':
|
| 14 |
+
# Given sets of discrete values for x_1, x_2, and x_3
|
| 15 |
+
x_1_values = [1, 2, 3, 5, 6]
|
| 16 |
+
x_2_values = [0, 5, 7, 2, 1]
|
| 17 |
+
x_3_values = [3, 7, 4, 5, 2]
|
| 18 |
+
|
| 19 |
+
# Perform differential evolution optimization with integer variables
|
| 20 |
+
# bounds = [(0, len(x_1_values) - 2), (0, len(x_2_values) - 1), (0, len(x_3_values) - 1), (-1, 1), (-1, 1), (-1, 1)]
|
| 21 |
+
bounds = [(3, 4), (3, 4), (3, 4), (-1, 1), (-1, 1), (-1, 1)]
|
| 22 |
+
result = differential_evolution(objective_function, bounds)
|
| 23 |
+
|
| 24 |
+
# Get the optimal indices of x_1, x_2, and x_3
|
| 25 |
+
x_1_index, x_2_index, x_3_index, w_1_opt, w_2_opt, w_3_opt = result.x
|
| 26 |
+
|
| 27 |
+
# Calculate the peak point (x_1, x_2, x_3) corresponding to the optimal indices
|
| 28 |
+
x_1_peak = x_1_values[int(x_1_index)]
|
| 29 |
+
x_2_peak = x_2_values[int(x_2_index)]
|
| 30 |
+
x_3_peak = x_3_values[int(x_3_index)]
|
| 31 |
+
|
| 32 |
+
# Print the results
|
| 33 |
+
print("Optimal w_1:", w_1_opt)
|
| 34 |
+
print("Optimal w_2:", w_2_opt)
|
| 35 |
+
print("Optimal w_3:", w_3_opt)
|
| 36 |
+
print("Peak Point (x_1, x_2, x_3):", (x_1_peak, x_2_peak, x_3_peak))
|
| 37 |
+
print("Maximum Value of y:", -result.fun) # Use negative sign as we previously used to maximize
|
Archive/optimization2.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.optimize import minimize
|
| 3 |
+
|
| 4 |
+
if __name__ == '__main__':
|
| 5 |
+
|
| 6 |
+
# Given subset of m values for x_1, x_2, and x_3
|
| 7 |
+
x1_subset = [2, 3, 4]
|
| 8 |
+
x2_subset = [0, 1]
|
| 9 |
+
x3_subset = [5, 6, 7]
|
| 10 |
+
|
| 11 |
+
# Full set of possible values for x_1, x_2, and x_3
|
| 12 |
+
x1_full = [1, 2, 3, 4, 5]
|
| 13 |
+
x2_full = [0, 1, 2, 3, 4, 5]
|
| 14 |
+
x3_full = [3, 5, 7]
|
| 15 |
+
|
| 16 |
+
# Define the objective function for quantile-based ranking
|
| 17 |
+
def objective_function(w):
|
| 18 |
+
y_subset = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1, x2, x3 in zip(x1_subset, x2_subset, x3_subset)]
|
| 19 |
+
y_full_set = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1 in x1_full for x2 in x2_full for x3 in x3_full]
|
| 20 |
+
|
| 21 |
+
# Calculate the 90th percentile of y values for the full set
|
| 22 |
+
y_full_set_90th_percentile = np.percentile(y_full_set, 90)
|
| 23 |
+
|
| 24 |
+
# Maximize the difference between the 90th percentile of the subset and the 90th percentile of the full set
|
| 25 |
+
return - min(y_subset) + y_full_set_90th_percentile
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Bounds for w_1, w_2, and w_3 (-1 to 1)
|
| 29 |
+
bounds = [(-1, 1), (-1, 1), (-1, 1)]
|
| 30 |
+
|
| 31 |
+
# Perform bounded optimization to find the values of w_1, w_2, and w_3 that maximize the objective function
|
| 32 |
+
result = minimize(objective_function, np.zeros(3), method='TNC', bounds=bounds)
|
| 33 |
+
|
| 34 |
+
# Get the optimal values of w_1, w_2, and w_3
|
| 35 |
+
w_1_opt, w_2_opt, w_3_opt = result.x
|
| 36 |
+
|
| 37 |
+
# Print the results
|
| 38 |
+
print("Optimal w_1:", w_1_opt)
|
| 39 |
+
print("Optimal w_2:", w_2_opt)
|
| 40 |
+
print("Optimal w_3:", w_3_opt)
|
Archive/test_form.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def grid(col=3, row=4, name='grid1'):
|
| 5 |
+
cols = st.columns(col)
|
| 6 |
+
for i in range(row):
|
| 7 |
+
for j in range(col):
|
| 8 |
+
with cols[j]:
|
| 9 |
+
value = st.session_state.checked_dic[name].get(f"{name}_{i*col+j}", False)
|
| 10 |
+
|
| 11 |
+
check = st.checkbox(f"{i*col+j}", key=f"{name}_{i*col+j}", value=value)
|
| 12 |
+
if check:
|
| 13 |
+
st.session_state.checked_dic[name][f"{name}_{i*col+j}"] = True
|
| 14 |
+
else:
|
| 15 |
+
st.session_state.checked_dic[name][f"{name}_{i*col+j}"] = False
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def on_click():
|
| 19 |
+
for key in st.session_state:
|
| 20 |
+
if st.session_state[key] and key[-1].isdigit():
|
| 21 |
+
st.write(key)
|
| 22 |
+
# for key in st.session_state.checked_dic[name]:
|
| 23 |
+
# if st.session_state.checked_dic[name][key]:
|
| 24 |
+
# st.write(key)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
if 'checked_dic' not in st.session_state:
|
| 30 |
+
st.session_state.checked_dic = {'grid1': {}, 'grid2': {}}
|
| 31 |
+
|
| 32 |
+
name = st.selectbox('Select a grid', ['grid1', 'grid2'])
|
| 33 |
+
|
| 34 |
+
with st.form(f"{name}_form"):
|
| 35 |
+
grid(name=name)
|
| 36 |
+
submit_button = st.form_submit_button("Submit", on_click=on_click)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
Home.py
CHANGED
|
@@ -38,6 +38,8 @@ def logout():
|
|
| 38 |
|
| 39 |
|
| 40 |
if __name__ == '__main__':
|
|
|
|
|
|
|
| 41 |
st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
|
| 42 |
st.write('A Research by MAPS Lab, NYU Shanghai')
|
| 43 |
st.title("Personalized Model Coffer")
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
if __name__ == '__main__':
|
| 41 |
+
# print(st.source_util.get_pages('Home.py'))
|
| 42 |
+
|
| 43 |
st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
|
| 44 |
st.write('A Research by MAPS Lab, NYU Shanghai')
|
| 45 |
st.title("Personalized Model Coffer")
|
pages/Gallery.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
-
import
|
|
|
|
|
|
|
| 5 |
from datasets import load_dataset, Dataset, load_from_disk
|
| 6 |
from huggingface_hub import login
|
| 7 |
-
import os
|
| 8 |
-
import requests
|
| 9 |
-
from bs4 import BeautifulSoup
|
| 10 |
-
import altair as alt
|
| 11 |
from streamlit_extras.switch_page_button import switch_page
|
|
|
|
| 12 |
|
| 13 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
|
| 14 |
|
|
@@ -62,20 +64,25 @@ class GalleryApp:
|
|
| 62 |
# handel checkbox information
|
| 63 |
prompt_id = items.iloc[idx + j]['prompt_id']
|
| 64 |
modelVersion_id = items.iloc[idx + j]['modelVersion_id']
|
|
|
|
| 65 |
check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
|
| 66 |
|
|
|
|
|
|
|
| 67 |
# show checkbox
|
| 68 |
-
checked = st.checkbox('Select', key=f'select_{
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# show selected info
|
| 81 |
for key in info:
|
|
@@ -186,7 +193,7 @@ class GalleryApp:
|
|
| 186 |
# select number of columns
|
| 187 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
| 188 |
|
| 189 |
-
return items, info, col_num
|
| 190 |
|
| 191 |
def sidebar(self):
|
| 192 |
with st.sidebar:
|
|
@@ -244,7 +251,7 @@ class GalleryApp:
|
|
| 244 |
st.title('Model Visualization and Retrieval')
|
| 245 |
st.write('This is a gallery of images generated by the models')
|
| 246 |
|
| 247 |
-
prompt_tags, tag, prompt_id, items
|
| 248 |
|
| 249 |
# add safety check for some prompts
|
| 250 |
safety_check = True
|
|
@@ -263,8 +270,23 @@ class GalleryApp:
|
|
| 263 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
|
| 264 |
|
| 265 |
if safety_check:
|
| 266 |
-
items, info, col_num = self.selection_panel(items)
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
with st.form(key=f'{prompt_id}'):
|
| 270 |
# buttons = st.columns([1, 1, 1])
|
|
@@ -293,20 +315,97 @@ class GalleryApp:
|
|
| 293 |
with st.spinner('Loading images...'):
|
| 294 |
self.gallery_standard(items, col_num, info)
|
| 295 |
|
|
|
|
|
|
|
|
|
|
| 296 |
def submit_actions(self, status, prompt_id):
|
| 297 |
if status == 'Select':
|
| 298 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
| 299 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
| 300 |
print(st.session_state.selected_dict, 'select')
|
|
|
|
| 301 |
elif status == 'Deselect':
|
| 302 |
st.session_state.selected_dict[prompt_id] = []
|
| 303 |
print(st.session_state.selected_dict, 'deselect')
|
|
|
|
| 304 |
# self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|
| 305 |
pass
|
| 306 |
elif status == 'Continue':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# switch_page("ranking")
|
|
|
|
| 308 |
pass
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
@st.cache_data
|
| 312 |
def load_hf_dataset():
|
|
@@ -342,6 +441,7 @@ def load_hf_dataset():
|
|
| 342 |
|
| 343 |
if __name__ == "__main__":
|
| 344 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
|
|
|
| 345 |
if 'user_id' not in st.session_state:
|
| 346 |
st.warning('Please log in first.')
|
| 347 |
home_btn = st.button('Go to Home Page')
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
|
| 4 |
+
import altair as alt
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
from bs4 import BeautifulSoup
|
| 10 |
from datasets import load_dataset, Dataset, load_from_disk
|
| 11 |
from huggingface_hub import login
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from streamlit_extras.switch_page_button import switch_page
|
| 13 |
+
from sklearn.svm import LinearSVC
|
| 14 |
|
| 15 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
|
| 16 |
|
|
|
|
| 64 |
# handel checkbox information
|
| 65 |
prompt_id = items.iloc[idx + j]['prompt_id']
|
| 66 |
modelVersion_id = items.iloc[idx + j]['modelVersion_id']
|
| 67 |
+
|
| 68 |
check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
|
| 69 |
|
| 70 |
+
st.write("Position: ", idx + j)
|
| 71 |
+
|
| 72 |
# show checkbox
|
| 73 |
+
checked = st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
|
| 74 |
+
|
| 75 |
+
#
|
| 76 |
+
# if checked:
|
| 77 |
+
# if prompt_id not in st.session_state.selected_dict:
|
| 78 |
+
# st.session_state.selected_dict[prompt_id] = []
|
| 79 |
+
# if modelVersion_id not in st.session_state.selected_dict[prompt_id]:
|
| 80 |
+
# st.session_state.selected_dict[prompt_id].append(modelVersion_id)
|
| 81 |
+
# else:
|
| 82 |
+
# try:
|
| 83 |
+
# st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
|
| 84 |
+
# except:
|
| 85 |
+
# pass
|
| 86 |
|
| 87 |
# show selected info
|
| 88 |
for key in info:
|
|
|
|
| 193 |
# select number of columns
|
| 194 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
| 195 |
|
| 196 |
+
return items, info, col_num, preprocessor
|
| 197 |
|
| 198 |
def sidebar(self):
|
| 199 |
with st.sidebar:
|
|
|
|
| 251 |
st.title('Model Visualization and Retrieval')
|
| 252 |
st.write('This is a gallery of images generated by the models')
|
| 253 |
|
| 254 |
+
prompt_tags, tag, prompt_id, items= self.sidebar()
|
| 255 |
|
| 256 |
# add safety check for some prompts
|
| 257 |
safety_check = True
|
|
|
|
| 270 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
|
| 271 |
|
| 272 |
if safety_check:
|
| 273 |
+
items, info, col_num, preprocessor = self.selection_panel(items)
|
| 274 |
+
|
| 275 |
+
# method = st.radio('Select a method to set dynamic weight', ['Grid Search', 'SVM', 'Greedy', 'Disable dynamic weight'], index=0, horizontal=True)
|
| 276 |
+
#
|
| 277 |
+
# if method != 'Disable dynamic weight':
|
| 278 |
+
# if len(st.session_state.selected_dict[prompt_id]) > 0:
|
| 279 |
+
# selected = items[
|
| 280 |
+
# items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
| 281 |
+
# drop=True)
|
| 282 |
+
# st.session_state.score_weights[0: 3] = self.dynamic_weight(selected, items, preprocessor,
|
| 283 |
+
# method=method)
|
| 284 |
+
# # st.experimental_rerun()
|
| 285 |
+
#
|
| 286 |
+
# else:
|
| 287 |
+
# print('no selected models')
|
| 288 |
+
#
|
| 289 |
+
# st.write(st.session_state.selected_dict.get(prompt_id, []))
|
| 290 |
|
| 291 |
with st.form(key=f'{prompt_id}'):
|
| 292 |
# buttons = st.columns([1, 1, 1])
|
|
|
|
| 315 |
with st.spinner('Loading images...'):
|
| 316 |
self.gallery_standard(items, col_num, info)
|
| 317 |
|
| 318 |
+
with st.sidebar:
|
| 319 |
+
st.write(str(st.session_state.selected_dict[prompt_id]))
|
| 320 |
+
|
| 321 |
def submit_actions(self, status, prompt_id):
|
| 322 |
if status == 'Select':
|
| 323 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
| 324 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
| 325 |
print(st.session_state.selected_dict, 'select')
|
| 326 |
+
st.experimental_rerun()
|
| 327 |
elif status == 'Deselect':
|
| 328 |
st.session_state.selected_dict[prompt_id] = []
|
| 329 |
print(st.session_state.selected_dict, 'deselect')
|
| 330 |
+
st.experimental_rerun()
|
| 331 |
# self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|
| 332 |
pass
|
| 333 |
elif status == 'Continue':
|
| 334 |
+
st.session_state.selected_dict[prompt_id] = []
|
| 335 |
+
for key in st.session_state:
|
| 336 |
+
keys = key.split('_')
|
| 337 |
+
if keys[0] == 'select' and keys[1] == str(prompt_id):
|
| 338 |
+
if st.session_state[key]:
|
| 339 |
+
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
| 340 |
# switch_page("ranking")
|
| 341 |
+
print(st.session_state.selected_dict, 'continue')
|
| 342 |
pass
|
| 343 |
|
| 344 |
+
def dynamic_weight(self, selected, items, preprocessor='crop', method='Grid Search'):
|
| 345 |
+
optimal_weight = [0, 0, 0]
|
| 346 |
+
if method == 'Grid Search':
|
| 347 |
+
# grid search method
|
| 348 |
+
top_ranking = len(items) * len(selected)
|
| 349 |
+
|
| 350 |
+
for clip_weight in np.arange(-1, 1, 0.1):
|
| 351 |
+
for mcos_weight in np.arange(-1, 1, 0.1):
|
| 352 |
+
for pop_weight in np.arange(-1, 1, 0.1):
|
| 353 |
+
weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
|
| 354 |
+
weight_all_sorted = weight_all.sort_values(ascending=False)
|
| 355 |
+
|
| 356 |
+
weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
|
| 357 |
+
|
| 358 |
+
# get the index of values of weight_selected in weight_all_sorted
|
| 359 |
+
rankings = []
|
| 360 |
+
for weight in weight_selected:
|
| 361 |
+
rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
|
| 362 |
+
if sum(rankings) <= top_ranking:
|
| 363 |
+
top_ranking = sum(rankings)
|
| 364 |
+
optimal_weight = [clip_weight, mcos_weight, pop_weight]
|
| 365 |
+
print('optimal weight:', optimal_weight)
|
| 366 |
+
|
| 367 |
+
elif method == 'SVM':
|
| 368 |
+
# svm method
|
| 369 |
+
print('start svm method')
|
| 370 |
+
# get residual dataframe that contains models not selected
|
| 371 |
+
residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
|
| 372 |
+
residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
|
| 373 |
+
residual = residual.to_numpy()
|
| 374 |
+
selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
|
| 375 |
+
selected = selected.to_numpy()
|
| 376 |
+
|
| 377 |
+
y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
|
| 378 |
+
X = np.concatenate((selected, residual), axis=0)
|
| 379 |
+
|
| 380 |
+
# fit svm model, and get parameters for the hyperplane
|
| 381 |
+
clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
|
| 382 |
+
clf.fit(X, y)
|
| 383 |
+
optimal_weight = clf.coef_[0].tolist()
|
| 384 |
+
print('optimal weight:', optimal_weight)
|
| 385 |
+
pass
|
| 386 |
+
|
| 387 |
+
elif method == 'Greedy':
|
| 388 |
+
for idx in selected.index:
|
| 389 |
+
# find which score is the highest, clip, mcos, or pop
|
| 390 |
+
clip_score = selected.loc[idx, 'norm_clip_crop']
|
| 391 |
+
mcos_score = selected.loc[idx, 'norm_mcos_crop']
|
| 392 |
+
pop_score = selected.loc[idx, 'norm_pop']
|
| 393 |
+
if clip_score >= mcos_score and clip_score >= pop_score:
|
| 394 |
+
optimal_weight[0] += 1
|
| 395 |
+
elif mcos_score >= clip_score and mcos_score >= pop_score:
|
| 396 |
+
optimal_weight[1] += 1
|
| 397 |
+
elif pop_score >= clip_score and pop_score >= mcos_score:
|
| 398 |
+
optimal_weight[2] += 1
|
| 399 |
+
|
| 400 |
+
# normalize optimal_weight
|
| 401 |
+
optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
|
| 402 |
+
print('optimal weight:', optimal_weight)
|
| 403 |
+
|
| 404 |
+
return optimal_weight
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
|
| 409 |
|
| 410 |
@st.cache_data
|
| 411 |
def load_hf_dataset():
|
|
|
|
| 441 |
|
| 442 |
if __name__ == "__main__":
|
| 443 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
| 444 |
+
|
| 445 |
if 'user_id' not in st.session_state:
|
| 446 |
st.warning('Please log in first.')
|
| 447 |
home_btn = st.button('Go to Home Page')
|
pages/__pycache__/Gallery.cpython-39.pyc
CHANGED
|
Binary files a/pages/__pycache__/Gallery.cpython-39.pyc and b/pages/__pycache__/Gallery.cpython-39.pyc differ
|
|
|