hannahcyberey's picture
update
1a99b8c
raw
history blame
13.3 kB
import os
import logging, json
from pathlib import Path
import asyncio
import aiohttp
import pandas as pd
import gradio as gr
from gradio_toggle import Toggle
from scheduler import load_scheduler
from schemas import UserRequest, SteeringOutput, CONFIG
# --- Inference-endpoint retry policy (used by get_endpoint_state) ---
MAX_RETRIES = 10          # health-check attempts before giving up
MAX_RETRY_WAIT_TIME = 75  # initial wait (seconds) between attempts; decays toward the minimum
MIN_RETRY_WAIT_TIME = 5   # lower bound for the decaying wait time
# Whether the remote endpoint has answered a health check; reset to False
# when the last session disconnects (see cleanup_instance).
ENDPOINT_ALIVE = False
HF_TOKEN = os.getenv('HF_TOKEN')  # auth token for the private HF inference endpoint
API_URL = "https://a6k5m81qw14hkvhz.us-east-1.aws.endpoints.huggingface.cloud"
# Headers sent with every endpoint call.
headers = {
    "Accept" : "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json"
}
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)
model_name = "DeepSeek-R1-Distill-Qwen-7B"  # display name shown in the status line
examples = pd.read_csv("assets/examples.csv")  # example prompts; has "type" and "prompt" columns
# Per-connection in-memory store: session_hash -> list of SteeringOutput records.
instances = {}
scheduler = load_scheduler()  # sink that persists outputs on session cleanup
# Extra <head> markup: Font Awesome stylesheet for the banner link icons.
HEAD = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.7.2/css/all.min.css" integrity="sha512-Evv84Mr4kqVGRNSgIGL/F/aIDqQb7xQ2vcrdIwxfjThSH8CSR7PBEakCr51Ck+w+/U6swU2Im1vVX0SVk9ABhg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
"""
# Static banner shown at the top of the app: title, paper/blog/code links,
# and a cover image (images are served from the assets/ static path).
HTML = f"""
<div id="banner">
<h1><img src="/gradio_api/file=assets/rudder_3094973.png">&nbsp;LLM Censorship Steering</h1>
<div id="links" class="row" style="margin-bottom: .8em;">
<i class="fa-solid fa-file-pdf fa-lg"></i><a href="https://arxiv.org/abs/2504.17130"> Paper</a> &nbsp;
<i class="fa-solid fa-blog fa-lg"></i><a href="https://hannahxchen.github.io/blog/2025/censorship-steering"> Blog Post</a> &nbsp;
<i class="fa-brands fa-github fa-lg"></i><a href="https://github.com/hannahxchen/llm-censorship-steering"> Code</a> &nbsp;
</div>
<div id="cover">
<img src="/gradio_api/file=assets/demo-cover.png">
</div>
</div>
"""
# App-wide stylesheet passed to gr.Blocks(css=...). Uses nested CSS selectors;
# element ids (banner, steering-toggle, coeff-slider, model-state) match the
# elem_id values given to the components below. Includes a narrow-screen
# media query for the banner/link font sizes.
CSS = """
div.gradio-container .app {
max-width: 1600px !important;
}
div#banner {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
h1 {
font-size: 32px;
line-height: 1.35em;
margin-bottom: 0em;
display: flex;
img {
display: inline;
height: 1.35em;
}
}
div#cover img {
max-height: 130px;
padding-top: 0.5em;
}
}
@media (max-width: 500px) {
div#banner {
h1 {
font-size: 22px;
}
div#links {
font-size: 14px;
}
}
div#model-state p {
font-size: 14px;
}
}
div#main-components {
align-items: flex-end;
}
div#steering-toggle {
padding-top: 8px;
padding-bottom: 8px;
.toggle-label {
color: var(--body-text-color);
}
span p {
font-size: var(--block-info-text-size);
line-height: var(--line-sm);
color: var(--block-label-text-color);
}
}
div#coeff-slider {
padding-bottom: 5px;
.slider_input_container span {color: var(--body-text-color);}
.slider_input_container {
display: flex;
flex-wrap: wrap;
input {appearance: auto;}
}
}
div#coeff-slider .wrap .head {
justify-content: unset;
label {margin-right: var(--size-2);}
label span {
color: var(--body-text-color);
margin-bottom: 0;
}
}
"""
# HTML fragments injected around the coefficient slider by the JS snippet
# below. Every line inside the strings ends with "\" (a line continuation
# inside the literal) so each fragment renders as a single line of HTML.
# NOTE: the closing quote of slider_info must NOT be followed by a backslash —
# a trailing "\" after the closing quote joins this assignment and the next
# one into a single logical line and the module fails to parse (SyntaxError).
slider_info = """\
<div style='display: flex; justify-content: space-between; line-height: normal;'>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>Less censorship</span>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>More censorship</span>\
</div>\
"""
# Tick marks (-2..2) rendered under the slider via a <datalist>.
slider_ticks = """\
<datalist id='values' style='display: flex; justify-content: space-between; width: 100%; padding: 0 6px;'>\
<option value='-2' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>-2</option>\
<option value='-1' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>-1</option>\
<option value='0' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>0</option>\
<option value='1' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>1</option>\
<option value='2' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>2</option>\
</datalist>\
"""
# Client-side setup run on app load (gr.Blocks(js=...)): inserts the labels
# and tick marks around Gradio's slider input and removes the default
# min/max value spans. %-formatting is used (not str.format) because the JS
# body contains braces that str.format would try to interpret.
JS = """
async() => {
    const node = document.querySelector("div.slider_input_container");
    node.insertAdjacentHTML('beforebegin', "%s");
    const sliderNode = document.querySelector("input#range_id_0");
    sliderNode.insertAdjacentHTML('afterend', "%s");
    sliderNode.setAttribute("list", "values");
    document.querySelector('span.min_value').remove();
    document.querySelector('span.max_value').remove();
}
""" % (slider_info, slider_ticks)
def initialize_instance(request: gr.Request):
    """Register a fresh output list for a newly connected session.

    Runs on demo.load; the returned session hash seeds the session_id state.
    """
    session_id = request.session_hash
    instances[session_id] = []
    logger.info("Number of connections: %d", len(instances))
    return session_id
def cleanup_instance(request: gr.Request):
    """Persist a departing session's outputs and drop its in-memory state.

    Runs on demo.unload. Each stored record is appended both to the scheduler
    and to a local outputs.jsonl file. When the last connection goes away the
    endpoint-alive flag is reset so the next visitor re-checks the endpoint.
    """
    global ENDPOINT_ALIVE
    sid = request.session_hash
    if sid in instances:
        with open("outputs.jsonl", "a") as f:
            for record in instances[sid]:
                scheduler.append(record.model_dump())
                f.write(json.dumps(record.model_dump()))
                f.write("\n")
        del instances[sid]
    if not instances:
        ENDPOINT_ALIVE = False
    logger.info("Number of connections: %d", len(instances))
async def initialize_endpoint():
    """Ping the inference endpoint's health route.

    Returns True when the endpoint answers 200; otherwise logs the status
    and body and returns False.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{API_URL}/health", headers=headers) as resp:
            if resp.status == 200:
                return True
            body = await resp.text()
            logger.error("API Error Code: %d, Message: %s", resp.status, body)
            return False
async def get_endpoint_state():
    """Poll the inference endpoint until it is ready, yielding UI states.

    Yields "Initializing" while waiting, then "Ready" on success, or
    "Server Error" after MAX_RETRIES failed attempts. The wait between
    attempts starts at MAX_RETRY_WAIT_TIME and decays by 20% per attempt,
    floored at MIN_RETRY_WAIT_TIME.

    Fix: if ENDPOINT_ALIVE is already True (another session brought the
    endpoint up), yield "Ready" immediately — the original skipped the
    ready branch in that case and spun through all retries, ending in a
    spurious "Server Error".
    """
    global ENDPOINT_ALIVE
    sleep_time = MAX_RETRY_WAIT_TIME
    for _attempt in range(MAX_RETRIES):
        if not ENDPOINT_ALIVE:
            logger.info("Initializing inference endpoint")
            yield "Initializing"
            ENDPOINT_ALIVE = await initialize_endpoint()
        if ENDPOINT_ALIVE:
            logger.info("Inference endpoint is ready")
            gr.Info("Inference endpoint is ready")
            yield "Ready"
            return
        gr.Warning("Initializing inference endpoint\n(This may take 2~3 minutes)", duration=sleep_time)
        await asyncio.sleep(sleep_time)
        sleep_time = max(sleep_time * 0.8, MIN_RETRY_WAIT_TIME)
    # All attempts exhausted without a healthy endpoint.
    yield "Server Error"
async def save_output(req: UserRequest, output: str):
    """Split generated text into reasoning/answer and store it for the session.

    The model's chain of thought ends at the "</think>" marker: everything up
    to and including the marker is the reasoning; whatever follows is the
    answer (None when generation stopped before producing one). Without the
    marker, the whole output is treated as (unfinished) reasoning.

    Fix: the original rebuilt the split from a filtered partition list, so an
    output ending exactly at "</think>" set answer to the literal marker and
    dropped it from the reasoning; direct partition unpacking avoids that.
    """
    reasoning, sep, answer = output.partition("</think>")
    if sep:
        reasoning += sep          # keep the marker with the reasoning text
        answer = answer or None   # empty tail means no final answer yet
    else:
        reasoning, answer = output, None
    steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
    instances[req.session_id].append(steering_output)
async def generate(
    session_id: str, prompt: str, steering: bool, coeff: float,
    max_new_tokens: int, top_p: float, temperature: float
):
    """Stream generated text from the inference endpoint into the output box.

    Builds a UserRequest from the UI values, POSTs it to the endpoint, and
    yields the accumulated text chunk by chunk (Gradio renders each yield).
    On a non-200 response the body is logged and gr.Error is raised. When any
    text was produced, the result is recorded via save_output.

    Fixes: the session is now managed with `async with`, so it is closed even
    when gr.Error is raised (the original leaked it on the error path);
    resp.text() is awaited before logging (the original logged a coroutine
    object); "Ccode" typo corrected in the log message.
    """
    req = UserRequest(
        session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
    )
    data = req.get_api_format()
    logger.info("User Request: %s", data)
    generated_text = ""
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{API_URL}/generate", headers=headers, json=data) as resp:
            if resp.status != 200:
                resp_text = await resp.text()
                logger.error("API Error Code: %d, Error Message: %s", resp.status, resp_text)
                raise gr.Error("API Server Error")
            # The stream does not include the opening tag — presumably the
            # endpoint strips it; restore it so the </think> split works.
            generated_text += "<think>"
            async for chunk, _ in resp.content.iter_chunks():
                generated_text += chunk.decode()
                yield generated_text
    if generated_text != "":
        await save_output(req, generated_text)
async def post_process(session_id):
    """After a successful generation: expose the latest request id and
    re-enable both feedback buttons."""
    latest = instances[session_id][-1]
    return latest.request_id, gr.update(interactive=True), gr.update(interactive=True)
async def output_feedback(session_id, request_id, feedback):
    """Record an up/down vote on the session's most recent output.

    `feedback` is the label of the button that fired ("👍 Upvote" /
    "👎 Downvote"). The vote is only applied when the stored record still
    matches `request_id` (i.e. the output was not cleared/replaced meanwhile).

    Fixes: mutate the record in place instead of the pop/append round-trip
    (which could drop the record if an exception hit in between), and replace
    the bare `except:` with a narrowed handler that logs the traceback.
    """
    logger.info("Feedback received for request %s: %s", str(request_id), feedback)
    try:
        data = instances[session_id][-1]
        if data.request_id == request_id:
            if "Upvote" in feedback:
                data.upvote = True
            elif "Downvote" in feedback:
                data.upvote = False
        gr.Info("Thank you for your feedback!")
    except Exception:
        # Best-effort: feedback must never crash the UI, but keep the trace.
        logger.exception("Feedback submission error")
# --- Gradio UI wiring ------------------------------------------------------
# Allow the banner/cover images under assets/ to be served by Gradio.
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
    # Per-browser-session state values.
    session_id = gr.State()
    request_id = gr.State()
    # Async generator state: streams "Initializing"/"Ready"/"Server Error".
    endpoint_state = gr.State(get_endpoint_state)
    gr.HTML(HTML)
    with gr.Row(elem_id="main-components"):
        with gr.Column(scale=1):
            # Re-render the status line whenever the endpoint state changes.
            @gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
            def render_state(endpoint_state):
                if endpoint_state == "Ready":
                    color = "green"
                elif endpoint_state == "Server Error":
                    color = "red"
                else:
                    color = "orange"
                if endpoint_state != None:
                    gr.Markdown(f'🤖 {model_name} | Inference Endpoint State: <span style="color:{color}; font-weight: bold;">{endpoint_state}</span>', elem_id="model-state")
            with gr.Row():
                steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
                coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
            # Relabel the toggle and enable/disable the coefficient slider
            # depending on whether steering is on.
            @gr.on(inputs=[steer_toggle], outputs=[steer_toggle, coeff], triggers=[steer_toggle.change])
            def update_toggle(toggle_value):
                if toggle_value is True:
                    return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
                else:
                    return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=2)
                    top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=2)
                    max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
            input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
            with gr.Row():
                clear_btn = gr.ClearButton()
                generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
            with gr.Row():
                # Feedback buttons stay disabled until a generation succeeds.
                upvote_btn = gr.Button("👍 Upvote", interactive=False)
                downvote_btn = gr.Button("👎 Downvote", interactive=False)
            gr.HTML("<p>‼️ For research purposes, we log user inputs and generated outputs. Please avoid submitting any confidential or personal information.</p>")
    gr.Markdown("#### Examples")
    gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
    gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
    # Clearing also resets the pending request id and disables feedback.
    @gr.on(triggers=[clear_btn.click], outputs=[request_id, upvote_btn, downvote_btn])
    def clear():
        return None, gr.update(interactive=False), gr.update(interactive=False)
    clear_btn.add([input_text, output])
    # Stream generation into the output box; on success surface the request
    # id and enable the feedback buttons.
    generate_btn.click(
        generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature], outputs=output
    ).success(
        post_process, inputs=session_id, outputs=[request_id, upvote_btn, downvote_btn]
    )
    # The button component itself is passed as the last input, so the handler
    # receives its label ("👍 Upvote" / "👎 Downvote") as the feedback value.
    upvote_btn.click(output_feedback, inputs=[session_id, request_id, upvote_btn])
    downvote_btn.click(output_feedback, inputs=[session_id, request_id, downvote_btn])
    demo.load(initialize_instance, outputs=session_id)
    demo.unload(cleanup_instance)
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5)
    demo.launch(debug=True)