Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import logging, json | |
| from pathlib import Path | |
| import asyncio | |
| import aiohttp | |
| import pandas as pd | |
| import gradio as gr | |
| from gradio_toggle import Toggle | |
| from scheduler import load_scheduler | |
| from schemas import UserRequest, SteeringOutput, CONFIG | |
# Health-check retry policy used by get_endpoint_state().
MAX_RETRIES = 10  # number of /health probes before reporting "Server Error"
MAX_RETRY_WAIT_TIME = 75  # seconds to wait after the first failed probe
MIN_RETRY_WAIT_TIME = 5  # floor for the wait, which shrinks by 0.8x per attempt
# Shared by all sessions; cleanup_instance() resets it once the last session leaves.
ENDPOINT_ALIVE = False

# Configure logging first so that the token check below is formatted consistently.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)

HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    # Without a token the header below literally becomes "Bearer None".
    logger.warning("HF_TOKEN is not set; inference endpoint requests will send an invalid bearer token")
API_URL = "https://a6k5m81qw14hkvhz.us-east-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept" : "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json"
}

model_name = "DeepSeek-R1-Distill-Qwen-7B"
# Example prompts; the UI filters rows on the "type" column and shows the "prompt" column.
examples = pd.read_csv("assets/examples.csv")
# Per-session output buffers: Gradio session hash -> list of SteeringOutput records.
instances = {}
scheduler = load_scheduler()
# Extra <head> markup: Font Awesome stylesheet (pinned by an SRI hash) for the banner link icons.
HEAD = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.7.2/css/all.min.css" integrity="sha512-Evv84Mr4kqVGRNSgIGL/F/aIDqQb7xQ2vcrdIwxfjThSH8CSR7PBEakCr51Ck+w+/U6swU2Im1vVX0SVk9ABhg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
"""
# Page banner: title, paper/blog/code links and cover image.
# Image files are served from the assets directory registered with gr.set_static_paths().
# Plain string: nothing is interpolated, so the f-prefix is not needed.
# The rudder icon is decorative (empty alt); the cover image gets descriptive alt text.
HTML = """
<div id="banner">
<h1><img src="/gradio_api/file=assets/rudder_3094973.png" alt=""> LLM Censorship Steering</h1>
<div id="links" class="row" style="margin-bottom: .8em;">
<i class="fa-solid fa-file-pdf fa-lg"></i><a href="https://arxiv.org/abs/2504.17130"> Paper</a>
<i class="fa-solid fa-blog fa-lg"></i><a href="https://hannahxchen.github.io/blog/2025/censorship-steering"> Blog Post</a>
<i class="fa-brands fa-github fa-lg"></i><a href="https://github.com/hannahxchen/llm-censorship-steering"> Code</a>
</div>
<div id="cover">
<img src="/gradio_api/file=assets/demo-cover.png" alt="LLM Censorship Steering demo overview">
</div>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=...). It uses native CSS nesting and
# targets the elem_ids assigned in the layout: #banner, #links, #cover,
# #model-state, #main-components, #steering-toggle and #coeff-slider.
# The @media rule shrinks the banner and status text on narrow (<=500px) screens.
CSS = """
div.gradio-container .app {
max-width: 1600px !important;
}
div#banner {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
h1 {
font-size: 32px;
line-height: 1.35em;
margin-bottom: 0em;
display: flex;
img {
display: inline;
height: 1.35em;
}
}
div#cover img {
max-height: 130px;
padding-top: 0.5em;
}
}
@media (max-width: 500px) {
div#banner {
h1 {
font-size: 22px;
}
div#links {
font-size: 14px;
}
}
div#model-state p {
font-size: 14px;
}
}
div#main-components {
align-items: flex-end;
}
div#steering-toggle {
padding-top: 8px;
padding-bottom: 8px;
.toggle-label {
color: var(--body-text-color);
}
span p {
font-size: var(--block-info-text-size);
line-height: var(--line-sm);
color: var(--block-label-text-color);
}
}
div#coeff-slider {
padding-bottom: 5px;
.slider_input_container span {color: var(--body-text-color);}
.slider_input_container {
display: flex;
flex-wrap: wrap;
input {appearance: auto;}
}
}
div#coeff-slider .wrap .head {
justify-content: unset;
label {margin-right: var(--size-2);}
label span {
color: var(--body-text-color);
margin-bottom: 0;
}
}
"""
# Inline style shared by the two hint labels drawn above the coefficient slider.
_HINT_STYLE = "font-size: var(--block-info-text-size); color: var(--block-label-text-color);"
# "Less censorship" / "More censorship" labels at the two ends of the slider.
slider_info = (
    "<div style='display: flex; justify-content: space-between; line-height: normal;'>"
    f"<span style='{_HINT_STYLE}'>Less censorship</span>"
    f"<span style='{_HINT_STYLE}'>More censorship</span>"
    "</div>"
)

# Tick labels rendered under the slider; keep in sync with the coeff slider's
# minimum/maximum (-2..2) in the layout.
_TICK_VALUES = (-2, -1, 0, 1, 2)
_TICK_STYLE = "font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;"
# A <datalist> whose id ("values") the JS below links to the range input via its list attribute.
slider_ticks = (
    "<datalist id='values' style='display: flex; justify-content: space-between; width: 100%; padding: 0 6px;'>"
    + "".join(f"<option value='{v}' style='{_TICK_STYLE}'>{v}</option>" for v in _TICK_VALUES)
    + "</datalist>"
)

# Page-load script passed to gr.Blocks(js=...). It decorates the coefficient slider:
# it inserts the hint labels and tick datalist and removes the numeric min/max labels.
# Lookups are scoped to #coeff-slider, so the other sliders are never touched and no
# Gradio-internal element id is assumed. Missing nodes are skipped instead of throwing.
# json.dumps emits properly escaped JS string literals for the injected HTML.
JS = """
async() => {
    const root = document.querySelector("div#coeff-slider") ?? document;
    const node = root.querySelector("div.slider_input_container");
    node?.insertAdjacentHTML('beforebegin', %s);
    const sliderNode = root.querySelector("input[type='range']");
    if (sliderNode) {
        sliderNode.insertAdjacentHTML('afterend', %s);
        sliderNode.setAttribute("list", "values");
    }
    root.querySelector('span.min_value')?.remove();
    root.querySelector('span.max_value')?.remove();
}
""" % (json.dumps(slider_info), json.dumps(slider_ticks))
def initialize_instance(request: gr.Request):
    """Open an empty output buffer for a newly loaded browser session.

    Args:
        request: Gradio request of the page load; its session_hash keys the buffer.

    Returns:
        The session hash, which the caller stores in the session_id State.
    """
    sid = request.session_hash
    instances[sid] = []
    logger.info("Number of connections: %d", len(instances))
    return sid
def cleanup_instance(request: gr.Request):
    """Persist and drop a session's buffered outputs when its page unloads.

    Each buffered record's model_dump() is handed to the scheduler and appended
    as one JSON line to outputs.jsonl. When no sessions remain, ENDPOINT_ALIVE is
    cleared so the next visitor re-checks the endpoint health.

    Args:
        request: Gradio request of the unloading page; session_hash identifies the buffer.
    """
    global ENDPOINT_ALIVE
    session_id = request.session_hash
    # Detach the buffer before doing any I/O so a write failure can neither leak
    # it in memory nor skip the connection bookkeeping below.
    records = instances.pop(session_id, None)
    if records is not None:
        try:
            with open("outputs.jsonl", "a", encoding="utf-8") as f:
                for data in records:
                    # Separate dumps so the scheduler never shares a dict with the file writer.
                    scheduler.append(data.model_dump())
                    json.dump(data.model_dump(), f)
                    f.write("\n")
        except OSError:
            logger.exception("Failed to write outputs for session %s", session_id)
    if not instances:
        ENDPOINT_ALIVE = False
    logger.info("Number of connections: %d", len(instances))
async def initialize_endpoint(timeout: float = 30.0):
    """Probe the inference endpoint's /health route once.

    Args:
        timeout: Total seconds allowed for the request (aiohttp's default is 300 s).

    Returns:
        True if the endpoint answered HTTP 200, otherwise False. Non-200 responses,
        connection errors and timeouts are logged rather than raised, so the caller's
        retry loop keeps running.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(f"{API_URL}/health", headers=headers) as resp:
                if resp.status == 200:
                    return True
                resp_text = await resp.text()
                logger.error("API Error Code: %d, Message: %s", resp.status, resp_text)
                return False
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        logger.error("Health check request failed: %r", err)
        return False
async def get_endpoint_state():
    """Stream the inference endpoint's status for the UI indicator.

    Yields "Initializing" before each health probe while the endpoint is not
    known to be alive, then exactly one terminal value:
    "Ready" once a probe succeeds, or "Server Error" after MAX_RETRIES failed probes.
    Between failed probes a toast is shown and the coroutine sleeps. The wait starts
    at MAX_RETRY_WAIT_TIME and shrinks by 0.8x per attempt, floored at MIN_RETRY_WAIT_TIME.
    """
    global ENDPOINT_ALIVE
    sleep_time = MAX_RETRY_WAIT_TIME
    for attempt in range(1, MAX_RETRIES + 1):
        if not ENDPOINT_ALIVE:
            logger.info("Initializing inference endpoint")
            yield "Initializing"
            ENDPOINT_ALIVE = await initialize_endpoint()
        if ENDPOINT_ALIVE:
            logger.info("Inference endpoint is ready")
            gr.Info("Inference endpoint is ready")
            yield "Ready"
            # Return here so a success on the final attempt is not followed by "Server Error".
            return
        if attempt == MAX_RETRIES:
            # No point warning or sleeping when there will be no further probe.
            break
        gr.Warning("Initializing inference endpoint\n(This may take 2~3 minutes)", duration=sleep_time)
        await asyncio.sleep(sleep_time)
        sleep_time = max(sleep_time * 0.8, MIN_RETRY_WAIT_TIME)
    yield "Server Error"
| async def save_output(req: UserRequest, output: str): | |
| if "</think>" in output: | |
| p = [p for p in output.partition("</think>") if p != ""] | |
| reasoning = "".join(p[:-1]) | |
| if len(p) == 1: | |
| answer = None | |
| else: | |
| answer = p[-1] | |
| else: | |
| answer = None | |
| reasoning = output | |
| steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer) | |
| instances[req.session_id].append(steering_output) | |
async def generate(
    session_id: str, prompt: str, steering: bool, coeff: float,
    max_new_tokens: int, top_p: float, temperature: float
):
    """Stream a generation from the inference endpoint into the Output box.

    Yields the accumulated text after every received chunk, prefixed with "<think>".
    When the stream finishes, the full text is buffered via save_output().

    Raises:
        gr.Error: If the endpoint answers with a non-200 status.
    """
    import codecs

    req = UserRequest(
        session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
    )
    data = req.get_api_format()
    logger.info("User Request: %s", data)

    generated_text = ""
    # Network chunks can split a multi-byte UTF-8 character, so decode incrementally
    # instead of calling bytes.decode() on each chunk.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    # Context managers close the session even on errors or client disconnects.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{API_URL}/generate", headers=headers, json=data) as resp:
            if resp.status != 200:
                resp_text = await resp.text()
                logger.error("API Error Code: %d, Error Message: %s", resp.status, resp_text)
                raise gr.Error("API Server Error")
            generated_text += "<think>"
            async for chunk, _ in resp.content.iter_chunks():
                generated_text += decoder.decode(chunk)
                yield generated_text
            tail = decoder.decode(b"", final=True)
            if tail:
                generated_text += tail
                yield generated_text

    if generated_text != "":
        await save_output(req, generated_text)
async def post_process(session_id):
    """After a successful generation, expose its request id and unlock voting.

    Returns:
        (request_id of the session's newest output, enable-upvote update, enable-downvote update).
    """
    newest = instances[session_id][-1]
    return (
        newest.request_id,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
async def output_feedback(session_id, request_id, feedback):
    """Record an up/down vote on one of the session's buffered outputs.

    Args:
        session_id: Gradio session hash owning the output buffer.
        request_id: Id of the output being rated (set by post_process).
        feedback: Label of the clicked button; "Upvote" sets upvote=True and
            "Downvote" sets upvote=False.
    """
    logger.info("Feedback received for request %s: %s", str(request_id), feedback)
    records = instances.get(session_id)
    if not records:
        logger.warning("Feedback for unknown or empty session %s", session_id)
        return
    # Search newest-first without removing anything, so a mismatched id can
    # never drop a buffered record.
    target = next((r for r in reversed(records) if r.request_id == request_id), None)
    if target is None:
        logger.warning("No output with request id %s in session %s", request_id, session_id)
        return
    label = feedback or ""
    if "Upvote" in label:
        target.upvote = True
    elif "Downvote" in label:
        target.upvote = False
    gr.Info("Thank you for your feedback!")
# Serve ./assets so the banner images referenced in HTML resolve.
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()

with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
    session_id = gr.State()
    request_id = gr.State()
    # NOTE(review): relies on Gradio invoking a callable State value on page load, which
    # streams each status string yielded by get_endpoint_state into this State.
    endpoint_state = gr.State(get_endpoint_state)
    gr.HTML(HTML)

    with gr.Row(elem_id="main-components"):
        with gr.Column(scale=1):
            # Re-rendered whenever endpoint_state changes; nothing is shown before the first status.
            @gr.render(inputs=endpoint_state)
            def render_state(state):
                """Show the model name with a colour-coded endpoint status."""
                if state is None:
                    return
                color = {"Ready": "green", "Server Error": "red"}.get(state, "orange")
                gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{state}</span>', elem_id="model-state")

            with gr.Row():
                steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
                coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")

            def update_toggle(toggle_value):
                """Relabel the toggle; the coefficient slider is editable only while steering is on."""
                if toggle_value is True:
                    return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
                return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)

            steer_toggle.change(update_toggle, inputs=steer_toggle, outputs=[steer_toggle, coeff])

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=2)
                    top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=2)
                    max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
            input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
            with gr.Row():
                clear_btn = gr.ClearButton()
                generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
            with gr.Row():
                upvote_btn = gr.Button("👍 Upvote", interactive=False)
                downvote_btn = gr.Button("👎 Downvote", interactive=False)

    gr.HTML("<p>‼️ For research purposes, we log user inputs and generated outputs. Please avoid submitting any confidential or personal information.</p>")
    gr.Markdown("#### Examples")
    gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
    gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")

    def clear():
        """Forget the tracked request id and disable both vote buttons."""
        return None, gr.update(interactive=False), gr.update(interactive=False)

    # The ClearButton empties the input/output boxes itself; clear() resets the feedback state.
    clear_btn.add([input_text, output])
    clear_btn.click(clear, outputs=[request_id, upvote_btn, downvote_btn])

    generate_btn.click(
        generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature], outputs=output
    ).success(
        post_process, inputs=session_id, outputs=[request_id, upvote_btn, downvote_btn]
    )
    # The button itself is passed as input so output_feedback receives its label.
    upvote_btn.click(output_feedback, inputs=[session_id, request_id, upvote_btn])
    downvote_btn.click(output_feedback, inputs=[session_id, request_id, downvote_btn])

    demo.load(initialize_instance, outputs=session_id)
    demo.unload(cleanup_instance)
if __name__ == "__main__":
    # Enable the request queue (at most 5 concurrent workers per event by default),
    # then start the server. With debug=True, launch() blocks the main thread.
    demo.queue(default_concurrency_limit=5).launch(debug=True)