Spaces:
Sleeping
Sleeping
add manual prompt filter
Browse files- data/curation.json +14 -0
- pages/Gallery.py +8 -1
data/curation.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"abstract": [1,3],
|
| 3 |
+
"animal": [],
|
| 4 |
+
"architecture": [],
|
| 5 |
+
"art": [],
|
| 6 |
+
"artifact": [],
|
| 7 |
+
"food": [],
|
| 8 |
+
"illustration": [39],
|
| 9 |
+
"people": [49,50,51,62,54,56,48,60],
|
| 10 |
+
"produce & plant": [],
|
| 11 |
+
"scenery": [],
|
| 12 |
+
"vehicle": [],
|
| 13 |
+
"world knowledge": [84,83,85]
|
| 14 |
+
}
|
pages/Gallery.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import requests
|
|
@@ -187,7 +188,7 @@ class GalleryApp:
|
|
| 187 |
|
| 188 |
# set focus tag and prompt index if exists
|
| 189 |
if st.session_state.gallery_focus['tag'] is None:
|
| 190 |
-
tag_focus_idx =
|
| 191 |
else:
|
| 192 |
tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
|
| 193 |
|
|
@@ -591,6 +592,12 @@ def load_hf_dataset(show_NSFW=False):
|
|
| 591 |
# add column to record current row index
|
| 592 |
promptBook.loc[:, 'row_idx'] = promptBook.index
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
# apply a nsfw filter
|
| 595 |
if not show_NSFW:
|
| 596 |
promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
import requests
|
|
|
|
| 188 |
|
| 189 |
# set focus tag and prompt index if exists
|
| 190 |
if st.session_state.gallery_focus['tag'] is None:
|
| 191 |
+
tag_focus_idx = 0
|
| 192 |
else:
|
| 193 |
tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
|
| 194 |
|
|
|
|
| 592 |
# add column to record current row index
|
| 593 |
promptBook.loc[:, 'row_idx'] = promptBook.index
|
| 594 |
|
| 595 |
+
# apply curation filter
|
| 596 |
+
prompt_to_hide = json.load(open('./data/curation.json', 'r'))
|
| 597 |
+
prompt_to_hide = list(itertools.chain.from_iterable(prompt_to_hide.values()))
|
| 598 |
+
print('prompt to hide: ', prompt_to_hide)
|
| 599 |
+
promptBook = promptBook[~promptBook['prompt_id'].isin(prompt_to_hide)].reset_index(drop=True)
|
| 600 |
+
|
| 601 |
# apply a nsfw filter
|
| 602 |
if not show_NSFW:
|
| 603 |
promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)
|