# app.py
import gradio as gr
import asyncio
import os
import shutil
from pathlib import Path
import time
import json
import traceback
from typing import List, Dict, Tuple, Optional
from pragent.backend.text_pipeline import pipeline as run_text_extraction
from pragent.backend.figure_table_pipeline import run_figure_extraction
from pragent.backend.blog_pipeline import generate_text_blog, generate_final_post
from pragent.backend.agents import setup_client, call_text_llm_api
import base64
import mimetypes
import re
# --- Added module: formats a Markdown post into structured JSON ---
# Outer prompt used by format_post_for_display when asking the LLM to convert
# a Markdown post into JSON. It is filled with str.format() and expects three
# fields: {platform}, {markdown_text} and {platform_instructions}.
FORMAT_PROMPT_TEMPLATE = '''
You are an expert in structuring social media content. Your task is to convert a post written in Markdown format into a structured JSON format. The JSON structure depends on the target platform.
**Platform:** {platform}
**Markdown Content:**
---
{markdown_text}
---
**Instructions:**
{platform_instructions}
'''
# Platform-specific instructions for Twitter, asking for a JSON array of tweet
# objects. Filled with str.format(); the only field is {asset_list}. The
# doubled braces ({{ and }}) are str.format escapes, so the example JSON is
# rendered with single braces.
# NOTE(review): the inline example in the "image_index" bullet is an empty
# code span — presumably a Markdown image tag was intended there; confirm.
TWITTER_INSTRUCTIONS = '''
Convert the content into a JSON array representing a Twitter thread. Each element in the array is a tweet object.
- Each tweet object must have a "text" key. The text should be plain text, without any Markdown formatting (e.g., no `*`, `#`, `[]()`)
- If a tweet is associated with an image, add an "image_index" key with the corresponding zero-based index from the provided asset list. For example, if the first image in the Markdown `` is used, its index is 0.
- Ensure the thread flows logically. Split the text into multiple tweets if necessary.
**Asset List (for reference):**
{asset_list}
**JSON Output Format:**
[
{{ "text": "Text of the first tweet.", "image_index": 0 }},
{{ "text": "Text of the second tweet." }},
{{ "text": "Text of the third tweet.", "image_index": 1 }}
]
'''
# Platform-specific instructions for Xiaohongshu, asking for a single JSON
# object with "title", "body" and "image_indices" keys. Filled with
# str.format(); the only field is {asset_list}. The doubled braces ({{ and }})
# are str.format escapes, so the example JSON is rendered with single braces.
XIAOHONGSHU_INSTRUCTIONS = '''
Convert the content into a single JSON object for a Xiaohongshu post.
- The JSON object must have a "title" key. Extract the main title from the Markdown (usually the first H1/H2 heading). The title should be plain text.
- The JSON object must have a "body" key containing the main text content, with emojis. The body text should be plain text, without any Markdown formatting (e.g., no `*`, `#`, `[]()`)
- The JSON object must have an "image_indices" key, which is an array of all image indexes used in the post, in the order they appear.
**Asset List (for reference):**
{asset_list}
**JSON Output Format:**
{{
"title": "Your Catchy Title Here",
"body": "The full body text of the post...",
"image_indices": [0, 1, 2, 3]
}}
'''
def image_to_base64(path: str) -> str:
    """Read an image file and return it as a Base64 ``data:`` URL.

    The MIME type is guessed from the file name. Only when it cannot be
    guessed does the function fall back to ``image/jpeg``; a successfully
    guessed type (PNG, GIF, WebP, ...) is never overridden.

    Args:
        path: Location of the image file.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``, or an empty
        string when the file cannot be read (the error is printed).
    """
    try:
        mime_type, _ = mimetypes.guess_type(path)
        if mime_type is None:
            # Unknown extension: JPEG is the default label for the data URL.
            mime_type = "image/jpeg"
        with open(path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
        return f"data:{mime_type};base64,{encoded_string}"
    except (OSError, TypeError, ValueError) as e:
        # Best-effort helper: callers treat "" as "no image available".
        print(f"[!] Error converting image to base64: {e}")
        return ""
# Relative path of the logo image.
LOGO_PATH = "pragent/logo/logo.png"
# Data URL of the logo; left empty (with a warning) when the file is absent.
if not os.path.exists(LOGO_PATH):
    LOGO_BASE64 = ""
    print(f"[!] Warning: Logo file not found at {LOGO_PATH}")
else:
    LOGO_BASE64 = image_to_base64(LOGO_PATH)
async def format_post_for_display(
    markdown_text: str,
    assets: Optional[List[Dict]],
    platform: str,
    client,
    model: str
) -> Optional[Dict]:
    """Ask the LLM to turn a Markdown post into structured JSON for the UI.

    Args:
        markdown_text: The post body in Markdown.
        assets: Optional list of asset dicts. Each one must carry a
            ``dest_name`` key, which is listed in the prompt next to its
            zero-based index.
        platform: ``'twitter'`` or ``'xiaohongshu'``; any other value makes
            the function return None without calling the LLM.
        client: LLM client, passed through to ``call_text_llm_api``.
        model: Model name, passed through to ``call_text_llm_api``.

    Returns:
        The parsed JSON value, or None when the platform is unsupported or
        when the LLM call or JSON parsing fails (the error and traceback
        are printed).
    """
    if platform not in ('twitter', 'xiaohongshu'):
        return None
    instructions = TWITTER_INSTRUCTIONS if platform == 'twitter' else XIAOHONGSHU_INSTRUCTIONS

    if assets:
        asset_lines = []
        for index, asset in enumerate(assets):
            asset_lines.append(f"- Index {index}: {asset['dest_name']}")
        asset_str = "\n".join(asset_lines)
    else:
        asset_str = "No assets provided."

    platform_label = platform.capitalize()
    platform_instructions = instructions.format(asset_list=asset_str)
    prompt = FORMAT_PROMPT_TEMPLATE.format(
        platform=platform_label,
        markdown_text=markdown_text,
        platform_instructions=platform_instructions,
    )
    system_prompt = "You are a content formatting expert. Output only valid JSON."

    try:
        response_str = await call_text_llm_api(client, system_prompt, prompt, model)
        # Prefer the contents of a fenced code block; otherwise parse the raw reply.
        fenced = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", response_str)
        json_str = fenced.group(1) if fenced else response_str
        return json.loads(json_str.strip())
    except Exception as e:
        print(f"[!] Error formatting post for display: {e}")
        traceback.print_exc()
        return None
# --- Gradio UI rendering helper functions ---
def render_twitter_thread(thread_data: List[Dict], assets: List[str]) -> str:
html_parts = []
for i, tweet in enumerate(thread_data):
text_html = tweet.get("text", "").replace("\n", "
")
image_html = ""
if "image_index" in tweet and tweet["image_index"] < len(assets):
img_idx = tweet["image_index"]
img_path = assets[img_idx]
base64_string = image_to_base64(img_path)
image_html = f'
{traceback.format_exc()}"
yield gr.update(value=f"❌ An error occurred: {e}"), gr.update(value=error_html, visible=True), gr.update(visible=False)
finally:
# Cleanup is disabled to prevent race conditions with Gradio's reloader
# and to allow inspection of generated files.
pass
# if work_dir and work_dir.exists():
# shutil.rmtree(work_dir)
# --- Gradio application UI definition ---
# Custom CSS
# Stylesheet passed to gr.Blocks(css=...). It styles the tweet-like rows,
# the Xiaohongshu title/body, the dashed #output_container box and the image
# carousel (slides, prev/next buttons, slide counter, fade animation).
CUSTOM_CSS = '''
/* --- Twitter Style --- */
.tweet-row {
display: flex;
align-items: flex-start;
padding: 16px;
border: 1px solid #e1e8ed;
border-radius: 15px;
margin-bottom: 12px;
background-color: #ffffff;
}
.avatar-container {
flex-shrink: 0;
margin-right: 12px;
}
.avatar {
width: 48px;
height: 48px;
border-radius: 50%;
object-fit: cover;
}
.tweet-content {
width: 100%;
}
.user-info {
font-size: 15px;
font-weight: bold;
}
.user-info span {
color: #536471;
font-weight: normal;
}
.tweet-text {
font-size: 15px;
line-height: 1.5;
color: #0f1419;
margin-top: 4px;
word-wrap: break-word;
}
.tweet-image-container {
margin-top: 12px;
}
.tweet-image {
width: 100%;
border-radius: 15px;
border: 1px solid #ddd;
display: block;
}
/* --- Xiaohongshu Style --- */
.xhs-title { font-size: 20px; font-weight: bold; color: #333; margin-bottom: 10px; }
.xhs-body { font-size: 16px; line-height: 1.8; color: #555; word-wrap: break-word; }
#output_container {
border: 2px dashed #ccc;
padding: 20px;
min-height: 100px;
border-radius: 15px;
}
.carousel-container { position: relative; max-width: 100%; margin: auto; overflow: hidden; border-radius: 10px; }
.carousel-slide { display: none; animation: fade 0.5s ease-in-out; }
.carousel-slide:first-child { display: block; }
.carousel-slide img { width: 100%; display: block; }
.prev, .next { cursor: pointer; position: absolute; top: 50%; width: auto; padding: 16px; margin-top: -22px; color: white; font-weight: bold; font-size: 20px; transition: 0.3s ease; border-radius: 0 3px 3px 0; user-select: none; background-color: rgba(0,0,0,0.3); }
.next { right: 0; border-radius: 3px 0 0 3px; }
.prev:hover, .next:hover { background-color: rgba(0,0,0,0.6); }
.carousel-numbertext { color: #f2f2f2; font-size: 12px; padding: 8px 12px; position: absolute; top: 0; background-color: rgba(0,0,0,0.5); border-radius: 0 0 5px 0; }
@keyframes fade { from {opacity: .4} to {opacity: 1}}
'''
# JavaScript arrow function (as a string) handed to Gradio via `js=` on the
# event chained after the generate click. After a 100 ms delay it looks for a
# .carousel-container inside #output_container and, unless it is already
# marked with data-initialized, wires the .prev/.next buttons to cycle through
# the .carousel-slide elements (wrapping around at both ends).
# NOTE(review): the constant name is misspelled ("CAROUSEEL"); it is referenced
# under the same spelling where the click event is chained, so a rename must
# change both places together.
ACTIVATE_CAROUSEEL_JS = '''
() => {
// We use a small 100ms delay to ensure Gradio has finished updating the HTML DOM
setTimeout(() => {
const container = document.getElementById('output_container');
if (container) {
const carousel = container.querySelector('.carousel-container');
// Check if a carousel exists and hasn't been initialized yet
if (carousel && !carousel.dataset.initialized) {
console.log("PRAgent Carousel Script: JS listener has found and is activating the carousel ->", carousel.id);
let slideIndex = 1;
const slides = carousel.getElementsByClassName("carousel-slide");
const prevButton = carousel.querySelector(".prev");
const nextButton = carousel.querySelector(".next");
if (slides.length === 0) return;
const showSlides = () => {
if (slideIndex > slides.length) { slideIndex = 1; }
if (slideIndex < 1) { slideIndex = slides.length; }
for (let i = 0; i < slides.length; i++) {
slides[i].style.display = "none";
}
slides[slideIndex - 1].style.display = "block";
};
if (prevButton) {
prevButton.addEventListener('click', () => { slideIndex--; showSlides(); });
}
if (nextButton) {
nextButton.addEventListener('click', () => { slideIndex++; showSlides(); });
}
showSlides(); // Show the first slide
carousel.dataset.initialized = 'true'; // Mark as initialized to prevent re-activation
}
}
}, 100);
}
'''
# Build the Gradio interface and wire the generate button to process_pdf.
# The Blocks object is bound to `demo`, which the __main__ guard launches.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    demo.queue()
    gr.Markdown("# 🚀 PRAgent: Paper to Social Media Post")
    gr.Markdown("Upload a research paper PDF, and I will generate a social media post for Twitter or Xiaohongshu, complete with images and platform-specific styling.")
    with gr.Row():
        # First column: PDF upload, API/model settings and generation options.
        with gr.Column(scale=1):
            pdf_upload = gr.File(label="Upload PDF Paper", file_types=[".pdf"])
            with gr.Accordion("Advanced Settings", open=True):
                text_api_key_input = gr.Textbox(label="Text API Key", type="password", placeholder="Required: sk-...")
                vision_api_key_input = gr.Textbox(label="Vision API Key (Optional)", type="password", placeholder="Optional: If not provided, Text API Key will be used")
                base_url_input = gr.Textbox(label="API Base URL")
                text_model_input = gr.Textbox(label="Text Model")
                vision_model_input = gr.Textbox(label="Vision Model")
            platform_select = gr.Radio(["twitter", "xiaohongshu"], label="Target Platform", value="twitter")
            language_select = gr.Radio([("English", "en"), ("Chinese", "zh")], label="Language", value="en")
            generate_btn = gr.Button("✨ Generate Post", variant="primary")
        # Second column: status line, rendered post (HTML) and the download file.
        with gr.Column(scale=2):
            status_text = gr.Markdown("Idle. Please upload a file and click generate.", visible=True)
            output_container = gr.HTML(elem_id="output_container")
            download_button = gr.File(label="Download Post & Images", visible=False)
    # Bind the button click event; the input order here is the order in which
    # the values are passed to process_pdf.
    click_event = generate_btn.click(
        fn=process_pdf,
        inputs=[
            pdf_upload,
            text_api_key_input,
            vision_api_key_input,
            base_url_input,
            text_model_input,
            vision_model_input,
            platform_select,
            language_select
        ],
        outputs=[status_text, output_container, download_button]
    )
    # Chain a .then() event that runs the carousel JavaScript once the
    # previous event has finished.
    click_event.then(
        fn=None, # No Python function needs to run here
        inputs=None,
        outputs=None,
        js=ACTIVATE_CAROUSEEL_JS # Keep the JS in its own, separate event
    )
if __name__ == "__main__":
    # Make sure the hidden scratch directory exists before starting the app.
    os.makedirs(".temp_output", exist_ok=True)
    demo.launch()