import uuid
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema

# Upper bound accepted for ``max_new_tokens`` on a request.
MAX_NEW_TOKENS_LIMIT = 3048

# Default generation settings applied when a request omits them.
CONFIG = {
    "max_new_tokens": MAX_NEW_TOKENS_LIMIT,
    "top_p": 0.95,
    "temperature": 0.6,
}


def _utc_now_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class UserRequest(BaseModel):
    """A user's generation request.

    Sampling parameters default to the values in ``CONFIG`` and are
    range-checked when the model is validated.
    """

    session_id: str
    # default_factory gives every request its own id. A plain default such as
    # ``Field(uuid.uuid4())`` is evaluated once at import time and would be
    # shared by every request.
    request_id: uuid.UUID = Field(default_factory=uuid.uuid4)
    prompt: Optional[str] = None
    steering: bool = True
    coeff: float = -1.0
    max_new_tokens: int = Field(
        CONFIG["max_new_tokens"], ge=1, le=MAX_NEW_TOKENS_LIMIT
    )
    top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
    temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)

    def get_api_format(self) -> dict:
        """Return the request as a nested dict.

        Returns:
            A dict with ``prompt``, ``steering`` and ``coeff`` at the top
            level and ``max_new_tokens``, ``top_p`` and ``temperature``
            grouped under ``generation_config``.
        """
        return {
            "prompt": self.prompt,
            "steering": self.steering,
            "coeff": self.coeff,
            "generation_config": {
                "max_new_tokens": self.max_new_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature,
            },
        }


class SteeringOutput(UserRequest):
    """A ``UserRequest`` extended with reasoning/answer text, an upvote flag and a timestamp.

    ``request_id`` and ``max_new_tokens`` are hidden from the JSON schema and
    excluded from serialization.
    """

    # Defaults must be restated here. Overriding with a bare
    # ``Field(exclude=True)`` drops the parent's default and makes these
    # schema-hidden fields required.
    request_id: SkipJsonSchema[uuid.UUID] = Field(
        default_factory=uuid.uuid4, exclude=True
    )
    max_new_tokens: SkipJsonSchema[int] = Field(
        CONFIG["max_new_tokens"], ge=1, le=MAX_NEW_TOKENS_LIMIT, exclude=True
    )
    reasoning: Optional[str] = None
    answer: Optional[str] = None
    upvote: Optional[bool] = None
    # Computed per instance. A plain default would freeze the import time.
    timestamp: str = Field(default_factory=_utc_now_iso)