# app.py
"""CPU-only text-to-image FastAPI service backed by Stable Diffusion v1.5.

Exposes:
    GET  /               -- liveness probe
    POST /generate-image -- run the diffusion pipeline, return a base64 PNG

The pipeline is loaded once at startup (lifespan) and kept in module-level
state; inference runs entirely on CPU.
"""

import base64
import io
import os
from contextlib import asynccontextmanager

import torch
from diffusers import AutoPipelineForText2Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field


# --- Pydantic Models ---

class ImageRequest(BaseModel):
    """Request payload for POST /generate-image."""
    prompt: str
    negative_prompt: str = ""
    # Default kept low for faster CPU inference; bounded so a client cannot
    # request a pathological number of denoising steps.
    steps: int = Field(default=20, ge=1, le=150)


class ImageResponse(BaseModel):
    """Response payload: the generated image as a base64-encoded PNG."""
    image_base64: str


# --- App State and Lifespan ---

# Shared mutable state; holds the loaded pipeline under the "pipe" key once
# startup has completed.
app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the diffusion pipeline at startup and release it at shutdown.

    Raises:
        RuntimeError: if the HF_TOKEN environment variable is missing.
    """
    hf_token = os.getenv("HF_TOKEN")  # Still needed to accept terms
    if not hf_token:
        raise RuntimeError("HF_TOKEN environment variable not set! Please add it in the Space settings.")

    # --- *** THESE ARE THE CHANGES FOR CPU *** ---
    # 1. Use the smaller Stable Diffusion v1.5 model.
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" repo appears to have
    # been removed from the Hugging Face Hub; if loading fails, switch to the
    # community mirror "stable-diffusion-v1-5/stable-diffusion-v1-5" -- confirm.
    model_id = "runwayml/stable-diffusion-v1-5"
    print(f"Loading model: {model_id} for CPU...")

    # 2. Load the pipeline without GPU-specific settings.
    # float32 is explicit because float16 inference is poorly supported on CPU.
    # Note: we do not call .to("cuda") -- the pipeline stays on CPU.
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=torch.float32,
    )
    # --- ************************************** ---

    app_state["pipe"] = pipe
    print("Model loaded successfully onto CPU.")

    yield

    # Clean up on shutdown.
    app_state.clear()
    print("Resources cleaned up.")


# --- FastAPI App ---

app = FastAPI(lifespan=lifespan)


@app.get("/")
def root():
    """Liveness probe: confirms the API process is up."""
    return {"status": "Text-to-Image CPU API is running"}


@app.post("/generate-image", response_model=ImageResponse)
def generate_image(request: ImageRequest):
    """Generate one image from ``request.prompt`` and return it as base64 PNG.

    Raises:
        HTTPException: 503 if the model has not finished loading,
            500 if the pipeline itself fails.
    """
    if "pipe" not in app_state:
        raise HTTPException(status_code=503, detail="Model is not ready.")

    pipe = app_state["pipe"]
    print(f"Generating image for prompt: '{request.prompt}'")

    # Keep the try body narrow: only the pipeline call is expected to fail.
    try:
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
        ).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Chain the original cause so tracebacks show the underlying error.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Encode the PIL image as a base64 PNG for the JSON response.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return ImageResponse(image_base64=img_str)