init
Browse files
- README.md +5 -0
- app.py +8 -4
- requirements.txt +2 -1
README.md
CHANGED
|
@@ -8,6 +8,11 @@ sdk_version: 5.42.0
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
short_description: HPSv3 demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
short_description: HPSv3 demo
|
| 11 |
+
preload_from_hub:
|
| 12 |
+
- MizzenAI/HPSv3 HPSv3.safetensors
|
| 13 |
+
- xswu/HPSv2 HPS_v2.1_compressed.pt
|
| 14 |
+
- yuvalkirstain/PickScore_v1
|
| 15 |
+
- laion/CLIP-ViT-H-14-laion2B-s32B-b79K
|
| 16 |
---
|
| 17 |
|
| 18 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import os
|
|
| 4 |
import sys
|
| 5 |
from PIL import Image
|
| 6 |
import uuid
|
|
|
|
| 7 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 8 |
|
| 9 |
from hpsv3.inference import HPSv3RewardInferencer
|
|
@@ -30,7 +31,7 @@ MODEL_CONFIGS = {
|
|
| 30 |
},
|
| 31 |
"HPSv2": {
|
| 32 |
"name": "HPSv2",
|
| 33 |
-
"checkpoint_path": "
|
| 34 |
"type": "hpsv2"
|
| 35 |
},
|
| 36 |
"ImageReward": {
|
|
@@ -40,12 +41,12 @@ MODEL_CONFIGS = {
|
|
| 40 |
},
|
| 41 |
"PickScore": {
|
| 42 |
"name": "PickScore",
|
| 43 |
-
"checkpoint_path": "
|
| 44 |
"type": "pickscore"
|
| 45 |
},
|
| 46 |
"CLIP": {
|
| 47 |
"name": "CLIP ViT-H-14",
|
| 48 |
-
"checkpoint_path": "/
|
| 49 |
"type": "clip"
|
| 50 |
}
|
| 51 |
}
|
|
@@ -73,8 +74,10 @@ def load_model(model_key, update_status_fn=None):
|
|
| 73 |
|
| 74 |
try:
|
| 75 |
if config["type"] == "hpsv3":
|
|
|
|
| 76 |
model = HPSv3RewardInferencer(
|
| 77 |
device=DEVICE,
|
|
|
|
| 78 |
)
|
| 79 |
elif config["type"] == "hpsv2":
|
| 80 |
model_obj, preprocess_train, preprocess_val = create_model_and_transforms(
|
|
@@ -96,7 +99,8 @@ def load_model(model_key, update_status_fn=None):
|
|
| 96 |
with_score_predictor=False,
|
| 97 |
with_region_predictor=False
|
| 98 |
)
|
| 99 |
-
|
|
|
|
| 100 |
model_obj.load_state_dict(checkpoint['state_dict'])
|
| 101 |
model_obj = model_obj.to(DEVICE).eval()
|
| 102 |
tokenizer = get_tokenizer('ViT-H-14')
|
|
|
|
| 4 |
import sys
|
| 5 |
from PIL import Image
|
| 6 |
import uuid
|
| 7 |
+
import huggingface_hub
|
| 8 |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 9 |
|
| 10 |
from hpsv3.inference import HPSv3RewardInferencer
|
|
|
|
| 31 |
},
|
| 32 |
"HPSv2": {
|
| 33 |
"name": "HPSv2",
|
| 34 |
+
"checkpoint_path": "xswu/HPSv2/HPS_v2.1_compressed.pt",
|
| 35 |
"type": "hpsv2"
|
| 36 |
},
|
| 37 |
"ImageReward": {
|
|
|
|
| 41 |
},
|
| 42 |
"PickScore": {
|
| 43 |
"name": "PickScore",
|
| 44 |
+
"checkpoint_path": "yuvalkirstain/PickScore_v1",
|
| 45 |
"type": "pickscore"
|
| 46 |
},
|
| 47 |
"CLIP": {
|
| 48 |
"name": "CLIP ViT-H-14",
|
| 49 |
+
"checkpoint_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 50 |
"type": "clip"
|
| 51 |
}
|
| 52 |
}
|
|
|
|
| 74 |
|
| 75 |
try:
|
| 76 |
if config["type"] == "hpsv3":
|
| 77 |
+
checkpoint_path = huggingface_hub.hf_hub_download("MizzenAI/HPSv3", 'HPSv3.safetensors', repo_type='model')
|
| 78 |
model = HPSv3RewardInferencer(
|
| 79 |
device=DEVICE,
|
| 80 |
+
checkpoint_path=checkpoint_path
|
| 81 |
)
|
| 82 |
elif config["type"] == "hpsv2":
|
| 83 |
model_obj, preprocess_train, preprocess_val = create_model_and_transforms(
|
|
|
|
| 99 |
with_score_predictor=False,
|
| 100 |
with_region_predictor=False
|
| 101 |
)
|
| 102 |
+
checkpoint_path = huggingface_hub.hf_hub_download("xswu/HPSv2", 'HPS_v2.1_compressed.pt', repo_type='model')
|
| 103 |
+
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
|
| 104 |
model_obj.load_state_dict(checkpoint['state_dict'])
|
| 105 |
model_obj = model_obj.to(DEVICE).eval()
|
| 106 |
tokenizer = get_tokenizer('ViT-H-14')
|
requirements.txt
CHANGED
|
@@ -191,4 +191,5 @@ xxhash==3.5.0
|
|
| 191 |
yarl==1.20.1
|
| 192 |
zipp==3.22.0
|
| 193 |
# flash-attn==2.7.4.post1
|
| 194 |
-
hpsv3==1.0.0
|
|
|
|
|
|
| 191 |
yarl==1.20.1
|
| 192 |
zipp==3.22.0
|
| 193 |
# flash-attn==2.7.4.post1
|
| 194 |
+
hpsv3==1.0.0
|
| 195 |
+
hpsv2
|