from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2

app = FastAPI()

tokenizer: Optional[AutoTokenizer] = None
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"

INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI. Given the current game state as a JSON object, choose the best full-turn sequence of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"","tavern_index":,"hand_index":,"board_index":,"card_name":}, ...]}

Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required, otherwise null.

Now here is the game state JSON:
"""


class GenerateRequest(BaseModel):
    phase: Optional[str] = None
    turn: Optional[int] = None
    state: Dict[str, Any]
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    temperature: float = DEFAULT_TEMPERATURE


def build_prompt(example: Dict[str, Any]) -> str:
    """Build a JSON-mode prompt (the only mode supported by this Space)."""
    state = example.get("state", {}) or {}
    gs = state.get("game_state", {}) or {}
    # Fall back to the embedded game_state when the request leaves phase/turn unset.
    # The keys are present but None, so dict.get defaults alone would never apply.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")
    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)
    obj = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the "actions" list from a model completion, or return None on failure."""
    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]
    try:
        obj = json.loads(json_str)
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None

    # Accept either "actions" or "action", and either a list or a single object.
    seq = None
    if "actions" in obj:
        if isinstance(obj["actions"], list):
            seq = obj["actions"]
        elif isinstance(obj["actions"], dict):
            seq = [obj["actions"]]
    elif "action" in obj:
        if isinstance(obj["action"], list):
            seq = obj["action"]
        elif isinstance(obj["action"], dict):
            seq = [obj["action"]]
    if seq is None:
        return None

    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions


def load_model() -> None:
    """Lazily load the tokenizer, base model, and LoRA adapter (idempotent)."""
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    if torch.cuda.is_available():
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    peft_model = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not torch.cuda.is_available():
        peft_model.to(device)
    peft_model.eval()

    tokenizer = tok
    model = peft_model


@app.on_event("startup")
async def _startup_event() -> None:
    load_model()


@app.get("/")
def root():
    return {
        "status": "ok",
        "message": "DeepBattler Battlegrounds Space is running",
        "base_model": BASE_MODEL_ID,
        "adapter_model": ADAPTER_MODEL_ID,
    }


@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    load_model()
    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)

    actions = parse_actions_from_completion(completion)
    return {
        "actions": actions,
        "raw_completion": completion,
    }
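

# Optional local entry point: a minimal sketch for running the app outside the
# Space (assumes uvicorn is installed; the host/port below are assumptions, with
# 7860 being the usual Hugging Face Spaces port). The curl payload is an
# illustrative game-state stub, not a schema defined in this file.
#   curl -X POST http://localhost:7860/generate_actions \
#     -H "Content-Type: application/json" \
#     -d '{"state": {"game_state": {"phase": "PlayerTurn", "turn_number": 3}}}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)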