yangzhitao commited on
Commit
58bbf33
·
1 Parent(s): f10ff89

refactor: update model key extraction and improve model dtype handling in create_submit_tab function for enhanced clarity

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. src/backend/schemas.py +25 -11
  3. src/submission/submit.py +16 -10
app.py CHANGED
@@ -544,13 +544,13 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
544
  raise ValueError("Model name is required")
545
 
546
  # Extract model_key from model_name (simple conversion)
547
- model_key = model_name.lower().replace("/", "_").replace("-", "_")
548
 
549
  # Build config
550
  config = {
551
  "model_name": model_name,
552
  "model_key": model_key,
553
- "model_dtype": f"torch.{precision}" if precision else None,
554
  "model_sha": revision or None, # None means "main"
555
  "model_args": None,
556
  }
 
544
  raise ValueError("Model name is required")
545
 
546
  # Extract model_key from model_name (simple conversion)
547
+ model_key = model_name.lower().replace("-", "_").replace("/", "-")
548
 
549
  # Build config
550
  config = {
551
  "model_name": model_name,
552
  "model_key": model_key,
553
+ "model_dtype": precision or None,
554
  "model_sha": revision or None, # None means "main"
555
  "model_args": None,
556
  }
src/backend/schemas.py CHANGED
@@ -1,7 +1,7 @@
1
  from datetime import datetime
2
  from typing import Annotated, Any, Generic, Literal, TypeVar
3
 
4
- from pydantic import BaseModel, ConfigDict, Field, computed_field
5
  from pydantic_core import PydanticCustomError
6
 
7
  T = TypeVar("T", bound=BaseModel)
@@ -102,12 +102,22 @@ class CommunitySubmit_Params(BaseModel):
102
 
103
  # Model info
104
  model_id: Annotated[str, Field(description='The model id. e.g. "Qwen/Qwen2.5-3B"')]
 
 
 
 
 
 
 
 
 
 
105
  base_model: Annotated[
106
- str, Field(description='The base model name (for delta or adapter weights). e.g. "Qwen/Qwen2.5-3B"')
107
- ] = ""
108
  model_sha: Annotated[
109
- str, Field(description='The model sha or "main". e.g. "3aab1f1954e9cc14eb9509a215f9e5ca08227a9b"')
110
- ] = "main"
111
  model_dtype: Literal[
112
  # float types
113
  "bfloat16",
@@ -135,8 +145,11 @@ class CommunitySubmit_Params(BaseModel):
135
  "float8_e8m0fnu",
136
  "float4_e2m1fn_x2",
137
  ]
138
- weight_type: Literal["Original"] | str = "Original"
139
- model_type: Annotated[str, Field(description='The model type. e.g. "pretrained", "fine-tuned"')] = ""
 
 
 
140
 
141
  # Submission info
142
  content: Annotated[str, Field(description='The content of the file in JSON format to upload.')]
@@ -149,13 +162,14 @@ class CommunitySubmit_Params(BaseModel):
149
  @property
150
  def filename(self) -> str:
151
  """Filename of the file to upload."""
152
- model_name = self.model_id.split("/")[-1]
153
- if not model_name:
154
  raise PydanticCustomError(
155
- "model_id_invalid", "Model id {model_id!r} is invalid.", {"model_id": self.model_id}
156
  )
157
  if not self.username:
158
  raise PydanticCustomError(
159
  "username_invalid", "Username {username!r} is invalid.", {"username": self.username}
160
  )
161
- return f"{model_name}_eval_request_False_{self.model_dtype}_{self.weight_type}_{self.username}.json"
 
 
 
1
  from datetime import datetime
2
  from typing import Annotated, Any, Generic, Literal, TypeVar
3
 
4
+ from pydantic import AfterValidator, BaseModel, ConfigDict, Field, computed_field
5
  from pydantic_core import PydanticCustomError
6
 
7
  T = TypeVar("T", bound=BaseModel)
 
102
 
103
  # Model info
104
  model_id: Annotated[str, Field(description='The model id. e.g. "Qwen/Qwen2.5-3B"')]
105
+
106
+ @computed_field
107
+ @property
108
+ def model_key(self) -> str:
109
+ if not self.model_id:
110
+ raise PydanticCustomError(
111
+ "model_id_invalid", "Model id {model_id!r} is invalid.", {"model_id": self.model_id}
112
+ )
113
+ return self.model_id.lower().replace("-", "_").replace("/", "-")
114
+
115
  base_model: Annotated[
116
+ str | None, Field(description='The base model name (for delta or adapter weights). e.g. "Qwen/Qwen2.5-3B"')
117
+ ] = None
118
  model_sha: Annotated[
119
+ str | None, Field(description='The model sha or "main". e.g. "3aab1f1954e9cc14eb9509a215f9e5ca08227a9b"')
120
+ ] = None
121
  model_dtype: Literal[
122
  # float types
123
  "bfloat16",
 
145
  "float8_e8m0fnu",
146
  "float4_e2m1fn_x2",
147
  ]
148
+ weight_type: Literal["Original", "Delta", "Adapter"] = "Original"
149
+ model_type: Annotated[
150
+ str,
151
+ Field(description='The model type. e.g. "pretrained", "fine-tuned", "instruction-tuned", "RL-tuned"'),
152
+ ]
153
 
154
  # Submission info
155
  content: Annotated[str, Field(description='The content of the file in JSON format to upload.')]
 
162
  @property
163
  def filename(self) -> str:
164
  """Filename of the file to upload."""
165
+ if not self.model_key:
 
166
  raise PydanticCustomError(
167
+ "model_key_invalid", "Model key {model_key!r} is invalid.", {"model_key": self.model_key}
168
  )
169
  if not self.username:
170
  raise PydanticCustomError(
171
  "username_invalid", "Username {username!r} is invalid.", {"username": self.username}
172
  )
173
+ # "2025-01-15T10:30:00Z" -> "20250115T103000"
174
+ submit_time = self.submit_time.replace(":", "").replace("-", "").rstrip("Z")
175
+ return f"{submit_time}_{self.model_key}_{self.model_dtype}_{self.model_sha}_{self.username}.json"
src/submission/submit.py CHANGED
@@ -25,7 +25,7 @@ REQUESTED_MODELS: set[str] | None = None
25
  def add_new_submit(
26
  model: str,
27
  base_model: str,
28
- revision: str,
29
  precision: str,
30
  weight_type: str,
31
  model_type: str,
@@ -52,24 +52,29 @@ def add_new_submit(
52
  REQUESTED_MODELS, _ = already_submitted_models(settings.EVAL_REQUESTS_PATH.as_posix())
53
 
54
  # Use provided user_id, or extract from model name as fallback
55
- user_name = user_id
56
 
57
- precision = precision.split(" ")[0] if precision else "float16"
 
58
  # Does the model actually exist?
59
- if not revision or revision == "":
60
- revision = "main"
61
 
62
  # Is the model on the hub?
63
  if weight_type in ["Delta", "Adapter"]:
64
  base_model_on_hub, error, _ = is_model_on_hub(
65
- model_name=base_model, revision=revision, token=settings.HF_TOKEN.get_secret_value(), test_tokenizer=True
 
 
 
66
  )
67
  if not base_model_on_hub:
68
  return styled_error(f'Base model "{base_model}" {error}')
69
 
70
  if not weight_type == "Adapter":
71
  model_on_hub, error, _ = is_model_on_hub(
72
- model_name=model, revision=revision, token=settings.HF_TOKEN.get_secret_value(), test_tokenizer=True
 
 
 
73
  )
74
  if not model_on_hub:
75
  return styled_error(f'Model "{model}" {error}')
@@ -89,7 +94,7 @@ def add_new_submit(
89
  # Validate required fields
90
  if not model or not model.strip():
91
  return styled_error("Model name is required.")
92
- if not user_name or not user_name.strip():
93
  return styled_error("User ID/username is required. Please make sure you are logged in.")
94
 
95
  # Get current UTC time for submit_time
@@ -103,8 +108,9 @@ def add_new_submit(
103
 
104
  # Organize all fields into a comprehensive JSON structure for the content field
105
  # This will be the complete JSON that gets uploaded as a file
 
106
  complete_submission_content = {
107
- "username": user_name,
108
  "model_id": model,
109
  "base_model": base_model or "",
110
  "model_sha": revision,
@@ -123,7 +129,7 @@ def add_new_submit(
123
 
124
  # Request JSON for the API call - includes all fields separately
125
  request_json = {
126
- "username": user_name,
127
  "model_id": model,
128
  "base_model": base_model or "",
129
  "model_sha": revision,
 
25
  def add_new_submit(
26
  model: str,
27
  base_model: str,
28
+ revision: str | None,
29
  precision: str,
30
  weight_type: str,
31
  model_type: str,
 
52
  REQUESTED_MODELS, _ = already_submitted_models(settings.EVAL_REQUESTS_PATH.as_posix())
53
 
54
  # Use provided user_id, or extract from model name as fallback
 
55
 
56
+ if " " in precision:
57
+ precision = precision.split(" ")[0]
58
  # Does the model actually exist?
59
+ revision = revision or None
 
60
 
61
  # Is the model on the hub?
62
  if weight_type in ["Delta", "Adapter"]:
63
  base_model_on_hub, error, _ = is_model_on_hub(
64
+ model_name=base_model,
65
+ revision=revision or "main",
66
+ token=settings.HF_TOKEN.get_secret_value(),
67
+ test_tokenizer=True,
68
  )
69
  if not base_model_on_hub:
70
  return styled_error(f'Base model "{base_model}" {error}')
71
 
72
  if not weight_type == "Adapter":
73
  model_on_hub, error, _ = is_model_on_hub(
74
+ model_name=model,
75
+ revision=revision or "main",
76
+ token=settings.HF_TOKEN.get_secret_value(),
77
+ test_tokenizer=True,
78
  )
79
  if not model_on_hub:
80
  return styled_error(f'Model "{model}" {error}')
 
94
  # Validate required fields
95
  if not model or not model.strip():
96
  return styled_error("Model name is required.")
97
+ if not user_id or not user_id.strip():
98
  return styled_error("User ID/username is required. Please make sure you are logged in.")
99
 
100
  # Get current UTC time for submit_time
 
108
 
109
  # Organize all fields into a comprehensive JSON structure for the content field
110
  # This will be the complete JSON that gets uploaded as a file
111
+ model_type = model_type.rpartition(":")[2].strip() # "⭕ : instruction-tuned" -> "instruction-tuned"
112
  complete_submission_content = {
113
+ "user_id": user_id,
114
  "model_id": model,
115
  "base_model": base_model or "",
116
  "model_sha": revision,
 
129
 
130
  # Request JSON for the API call - includes all fields separately
131
  request_json = {
132
+ "username": user_id,
133
  "model_id": model,
134
  "base_model": base_model or "",
135
  "model_sha": revision,