# ImageGEN / app.py
# Author: Rx Codex AI
# Commit: "Update app.py" (0251672, verified)
# app.py
import base64
import io
import os
import threading
from contextlib import asynccontextmanager

import torch
from diffusers import AutoPipelineForText2Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# --- Pydantic Models ---
class ImageRequest(BaseModel):
    """Request body for ``POST /generate-image``.

    Attributes:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what the image should avoid;
            empty string by default.
        steps: Value passed to the pipeline as ``num_inference_steps``.
            Must be >= 1. Lower values finish faster on CPU.
    """

    prompt: str
    negative_prompt: str = ""
    # Lowered for faster CPU inference. ge=1 rejects zero/negative values with
    # a 422 validation error instead of handing them to the pipeline, where
    # they would only surface as a 500.
    # NOTE(review): request cost grows with `steps`; consider an upper bound
    # (le=...) if this endpoint is publicly reachable.
    steps: int = Field(20, ge=1)
class ImageResponse(BaseModel):
    """Response body for ``POST /generate-image``: one base64-encoded image."""

    image_base64: str  # PNG bytes, base64-encoded into ASCII text


# --- App State and Lifespan ---
# Process-wide registry: the lifespan handler stores the loaded pipeline here
# under "pipe" at startup and empties it again at shutdown.
app_state = dict()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the text-to-image pipeline on startup and release it on shutdown.

    Reads ``HF_TOKEN`` (required) and ``MODEL_ID`` (optional; defaults to
    ``runwayml/stable-diffusion-v1-5``) from the environment. It then loads the
    pipeline without any GPU placement and stores it in ``app_state["pipe"]``
    for the request handlers.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).

    Raises:
        RuntimeError: If ``HF_TOKEN`` is not set.
    """
    hf_token = os.getenv("HF_TOKEN")  # Still needed to accept terms
    if not hf_token:
        raise RuntimeError("HF_TOKEN environment variable not set! Please add it in the Space settings.")

    # --- *** THESE ARE THE CHANGES FOR CPU *** ---
    # 1. Use the smaller Stable Diffusion v1.5 model by default. MODEL_ID lets a
    #    deployment point at another checkpoint without a code change; an
    #    unset or empty value falls back to the default.
    model_id = os.getenv("MODEL_ID") or "runwayml/stable-diffusion-v1-5"
    print(f"Loading model: {model_id} for CPU...")
    # 2. Load the pipeline without GPU-specific settings: no dtype override and
    #    no .to("cuda").
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id,
        token=hf_token,
    )
    app_state["pipe"] = pipe
    print("Model loaded successfully onto CPU.")

    try:
        yield
    finally:
        # Runs on a normal shutdown and also when an exception is thrown into
        # the lifespan, so the pipeline reference is always dropped.
        app_state.clear()
        print("Resources cleaned up.")
# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)


@app.get("/")
def root():
    """Report that the API process is up.

    Returns a fixed status payload and does not inspect ``app_state``.
    """
    status_text = "Text-to-Image CPU API is running"
    return dict(status=status_text)
# FastAPI runs plain `def` endpoints in a thread pool, so several requests can
# reach the shared pipeline at once. Diffusers schedulers are not thread-safe:
# each pipeline call resets and advances the scheduler's timestep state. One
# lock therefore serialises inference, which also avoids CPU oversubscription
# from parallel generations.
_pipe_lock = threading.Lock()


@app.post("/generate-image", response_model=ImageResponse)
def generate_image(request: ImageRequest):
    """Generate one image for ``request`` and return it as base64-encoded PNG.

    Args:
        request: Prompt, negative prompt and number of inference steps.

    Returns:
        ImageResponse whose ``image_base64`` field holds the PNG bytes
        encoded as base64 text.

    Raises:
        HTTPException: 503 if the pipeline is not in ``app_state`` yet;
            500 if generation or PNG encoding fails.
    """
    pipe = app_state.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not ready.")

    print(f"Generating image for prompt: '{request.prompt}'")
    try:
        # Hold the lock only around the pipeline call; encoding is per-request.
        with _pipe_lock:
            image = pipe(
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                num_inference_steps=request.steps,
            ).images[0]
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return ImageResponse(image_base64=img_str)