| """Database class for the evaluation.""" | |
| import json | |
| import os | |
| import subprocess | |
| import threading | |
| from contextlib import contextmanager | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
| import psycopg2 | |
| import psycopg2.extras | |
| from huggingface_hub import CommitOperationAdd, HfApi | |
| REQUIRED_TABLES = ('tasks', 'assignments', 'task_config', 'answers', 'allocation_history') | |
| def _get_time_now() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |


class DB:
    """Database for the evaluation."""

    def __init__(self, db_path: str, schema_path: Optional[str] = None, verify_only: bool = True):
        """Initialize the database.

        Note: db_path/schema_path kept for compatibility; Postgres uses env vars.
        Required env vars: PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE
        """
        self.db_path = db_path  # kept for compatibility
        self._lock = threading.Lock()
        # Validate environment variables early
        self._pg_dsn = dict(
            host=os.environ.get('PGHOST'),
            port=int(os.environ.get('PGPORT')) if os.environ.get('PGPORT') else None,
            user=os.environ.get('PGUSER'),
            password=os.environ.get('PGPASSWORD'),
            database=os.environ.get('PGDATABASE'),
        )
        if not all([self._pg_dsn['host'], self._pg_dsn['user'], self._pg_dsn['password'], self._pg_dsn['database']]):
            raise OSError('Missing PostgreSQL env vars: PGHOST, PGUSER, PGPASSWORD, PGDATABASE')
        # Verify required tables exist
        if verify_only:
            with self._connect() as conn:
                missing = []
                for t in REQUIRED_TABLES:
                    cur = conn.execute('SELECT to_regclass(%s)', (f'public.{t}',))
                    exists = cur.fetchone()[0]
                    if exists is None:
                        missing.append(t)
                if missing:
                    raise FileNotFoundError(f'Missing required tables: {", ".join(missing)}')
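
    # Usage sketch (assumptions: the PG* variables point at a reachable Postgres
    # instance whose schema is already loaded; db_path is only carried along for
    # the sqlite-era call sites and the HF push below):
    #
    #   db = DB(db_path='data/eval.db')  # raises if env vars or required tables are missing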

    @contextmanager
    def _connect(self):
        """Context manager yielding a lightweight wrapper with conn.execute(...).

        - Uses psycopg2 with autocommit by default (emulates sqlite autocommit usage).
        - Supports BEGIN/COMMIT/ROLLBACK issued via execute for backward compatibility.
        """

        class _Result:
            """Materialized result mimicking the parts of a cursor that callers use."""

            def __init__(self, rows=None, rowcount: int = 0):
                self._rows = list(rows) if rows is not None else []
                self.rowcount = rowcount  # plain attribute so `cur.rowcount or 0` works

            def fetchone(self):
                return self._rows[0] if self._rows else None

            def fetchall(self):
                return list(self._rows)

        class _ConnWrapper:
            def __init__(self, real_conn):
                self._conn = real_conn

            def execute(self, sql: str, params: tuple | list | None = None):
                sql_strip = (sql or '').strip().upper()
                # Map transaction control statements onto psycopg2's connection API.
                if sql_strip.startswith('BEGIN'):
                    # Leaving autocommit mode is enough: psycopg2 opens a transaction
                    # implicitly on the next statement, so no explicit BEGIN is sent.
                    self._conn.autocommit = False
                    return _Result()
                if sql_strip == 'COMMIT':
                    try:
                        self._conn.commit()
                    finally:
                        self._conn.autocommit = True
                    return _Result()
                if sql_strip == 'ROLLBACK':
                    try:
                        self._conn.rollback()
                    finally:
                        self._conn.autocommit = True
                    return _Result()
                # Regular statement: materialize any returned rows so they can still
                # be fetched after the underlying cursor is closed.
                with self._conn.cursor() as c:
                    c.execute(sql, params or ())
                    if c.description is not None:  # statement returned rows (e.g. SELECT)
                        return _Result(rows=c.fetchall(), rowcount=c.rowcount)
                    return _Result(rowcount=c.rowcount)

        # Open a new psycopg2 connection per context.
        conn = psycopg2.connect(**{k: v for k, v in self._pg_dsn.items() if v is not None})
        conn.autocommit = True
        wrapper = _ConnWrapper(conn)
        try:
            yield wrapper
            # If a BEGIN was issued without an explicit COMMIT, commit on clean exit.
            if not conn.closed and not conn.autocommit:
                conn.commit()
        except Exception:
            # Roll back any open transaction before re-raising.
            try:
                if not conn.closed and not conn.autocommit:
                    conn.rollback()
            except Exception:
                pass
            raise
        finally:
            try:
                if not conn.closed:
                    conn.close()
            except Exception:
                pass
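
    # Sketch of the execute() contract the wrapper emulates, for a hypothetical
    # multi-statement caller (BEGIN/COMMIT are intercepted and mapped onto
    # psycopg2's transaction handling rather than being sent verbatim):
    #
    #   with self._connect() as conn:
    #       conn.execute('BEGIN')
    #       conn.execute("UPDATE assignments SET status = 'done' WHERE id = %s", (assignment_id,))
    #       conn.execute('COMMIT')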

    # ========== Answer Recording ==========
    def record_answer(
        self,
        round_id: str,
        user_id: str,
        item_id: str,
        label: str,
        image_path: str,
        score: int,
        words_not_present: list[str],
    ) -> None:
        """Record an answer into the database."""
        now = _get_time_now()
        with self._connect() as conn, self._lock:
            conn.execute(
                (
                    'INSERT INTO answers (round_id, user_id, item_id, label, image_path, '
                    'score, words_not_present, answered_at) '
                    'VALUES (%s, %s, %s, %s, %s, %s, %s, %s) '
                    'ON CONFLICT (round_id, user_id, item_id) DO UPDATE SET '
                    'score = EXCLUDED.score, '
                    'words_not_present = EXCLUDED.words_not_present, '
                    'answered_at = EXCLUDED.answered_at'
                ),
                (
                    round_id,
                    user_id,
                    item_id,
                    label,
                    image_path,
                    score,
                    json.dumps(words_not_present, ensure_ascii=False),
                    now,
                ),
            )
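
    # Hypothetical call (identifiers and values are illustrative only):
    #
    #   db.record_answer('round_1', 'user_42', 'item_007', 'cat playing violin',
    #                    'images/item_007.png', score=4, words_not_present=['violin'])
    #
    # Re-submitting the same (round_id, user_id, item_id) overwrites the earlier
    # answer via the ON CONFLICT clause above.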

    def get_answered_item_ids(self, round_id: str, user_id: str) -> set[str]:
        """Get all the answered item ids of a user in a round."""
        with self._connect() as conn:
            cur = conn.execute(
                'SELECT item_id FROM answers WHERE round_id=%s AND user_id=%s',
                (round_id, user_id),
            )
            return {r[0] for r in cur.fetchall()}  # fetchall returns a list of tuples

    def get_answer(self, round_id: str, user_id: str, item_id: str) -> Optional[tuple[int, list[str]]]:
        """Get an answer of a user in a round for a specific item.

        Returns:
            tuple[int, list[str]]: score (int) and list of words not present in the image.
            None if the answer does not exist.
        """
        with self._connect() as conn:
            cur = conn.execute(
                'SELECT score, words_not_present FROM answers WHERE round_id=%s AND user_id=%s AND item_id=%s',
                (round_id, user_id, item_id),
            )
            row = cur.fetchone()
            if not row:
                return None
            score, words_not_present = row
            return int(score), json.loads(words_not_present)
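
    # Example of the return shape (values are illustrative only):
    #
    #   db.get_answer('round_1', 'user_42', 'item_007')   # -> (4, ['violin'])
    #   db.get_answer('round_1', 'user_42', 'unseen_id')  # -> None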

    # ========== Lease Management (for in_progress tasks) ==========
    def cleanup_expired_leases(self, round_id: str) -> int:
        """Auto-recover expired in_progress tasks to pending status.

        Returns the number of tasks recovered.
        """
        now = _get_time_now()
        with self._connect() as conn, self._lock:
            cur = conn.execute(
                """
                UPDATE assignments
                SET status = 'pending', lease_until = NULL, started_at = NULL
                WHERE round_id = %s
                  AND status = 'in_progress'
                  AND lease_until < %s
                """,
                (round_id, now),
            )
            return cur.rowcount or 0
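
    # A plausible call site (an assumption based on the lease columns): run this
    # before allocating the next task so stalled assignments re-enter the pool.
    #
    #   recovered = db.cleanup_expired_leases(round_id)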

    # ========== Auto Commit (For HF Spaces) ==========
    def commit_and_push_db(self) -> None:
        """Commit and push the database to the repository."""
        space_id = os.getenv('SPACE_ID', None)
        if not space_id:
            return
        try:
            repo_root = subprocess.run(
                ['git', 'rev-parse', '--show-toplevel'],
                check=True,
                capture_output=True,
                text=True,
            ).stdout.strip()
            ts = _get_time_now()
            db_rel = os.path.relpath(self.db_path, repo_root)
            subprocess.run(['git', 'add', db_rel], check=True, cwd=repo_root)
            api = HfApi(token=os.getenv('HF_TOKEN', None))
            api.create_commit(
                repo_id='JaneDing2025/IconEval_Pilot',
                repo_type='dataset',
                operations=[CommitOperationAdd(path_in_repo=db_rel, path_or_fileobj=self.db_path)],
                commit_message=f'update db {ts}',
            )
        except Exception as e:
            raise RuntimeError(f'Failed to commit and push the database: {e}') from e
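
    # On a Space, a plausible pattern (not taken from the original call sites) is
    # to push after each write so the dataset mirror stays current; this assumes
    # HF_TOKEN has write access to the JaneDing2025/IconEval_Pilot dataset repo:
    #
    #   db.record_answer(...)
    #   db.commit_and_push_db()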