import os
import logging
from pathlib import Path

import asyncio
import aiohttp
import pandas as pd
import gradio as gr
from gradio_toggle import Toggle

from scheduler import load_scheduler
from schemas import UserRequest, SteeringOutput, CONFIG

MAX_RETRIES = 10
MAX_RETRY_WAIT_TIME = 75
MIN_RETRY_WAIT_TIME = 5
ENDPOINT_ALIVE = False

HF_TOKEN = os.getenv('HF_TOKEN')
API_URL = "https://a6k5m81qw14hkvhz.us-east-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json"
}

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)

model_name = "DeepSeek-R1-Distill-Qwen-7B"
examples = pd.read_csv("assets/examples.csv")
instances = {}
scheduler = load_scheduler()

# Static page fragments. Their markup is not preserved in this copy of the
# file; HEAD is expected to load Bootstrap's CSS/JS (the tooltip code in JS
# below calls `bootstrap.Tooltip`), and HTML held the page banner.
HEAD = """
"""

HTML = f"""
"""

CSS = """
div.gradio-container .app {
    max-width: 1600px !important;
}
div#banner {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    h1 {
        font-size: 32px;
        line-height: 1.35em;
        margin-bottom: 0em;
        display: flex;
        img {
            display: inline;
            height: 1.35em;
        }
    }
    div#cover img {
        max-height: 130px;
        padding-top: 0.5em;
    }
}
@media (max-width: 500px) {
    div#banner {
        h1 { font-size: 22px; }
        div#links { font-size: 14px; }
    }
    div#model-state p { font-size: 14px; }
}
div#steering-toggle {
    padding-top: 8px;
    padding-bottom: 8px;
    .toggle-label { color: var(--body-text-color); }
    span p {
        font-size: var(--block-info-text-size);
        line-height: var(--line-sm);
        color: var(--block-label-text-color);
    }
}
div#coeff-slider {
    padding-bottom: 5px;
    .slider_input_container span { color: var(--body-text-color); }
    .slider_input_container {
        display: flex;
        flex-wrap: wrap;
        input { appearance: auto; }
    }
}
div#coeff-slider .wrap .head {
    justify-content: unset;
    label { margin-right: var(--size-2); }
    label span {
        color: var(--body-text-color);
        margin-bottom: 0;
    }
}
.tooltip {
    word-wrap: break-word;
    width: 12rem;
}
.tooltip-inner {
    filter: alpha(opacity=100);
    font-size: var(--block-info-text-size);
    text-align: center;
    padding: .4rem .2rem;
    background-color: var(--neutral-500);
    border-width: 1px;
    border-radius: var(--block-radius);
}
"""
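# ---------------------------------------------------------------------------
# HTML snippets injected around the coefficient slider by the JS hook below.
# Only the visible label text survives in this copy of the file; the markup
# in the three strings that follow is a minimal reconstruction preserving the
# hooks the JS relies on — a <datalist id='values'> for tick marks and a
# Bootstrap tooltip trigger. The original markup may have differed.
# ---------------------------------------------------------------------------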
slider_info = """\
<div style='display: flex; justify-content: space-between;'>\
<span>Less censorship</span>\
<span>More censorship</span>\
</div>\
"""

# Tick marks for the slider; values assume the slider's -2..2 range.
slider_ticks = """\
<datalist id='values'>\
<option value='-2'></option>\
<option value='-1'></option>\
<option value='0'></option>\
<option value='1'></option>\
<option value='2'></option>\
</datalist>\
"""

# Tooltip trigger appended after the "Coefficient" label (placeholder text).
coeff_info = """\
<span data-bs-toggle='tooltip' title='Steering coefficient'>&#9432;</span>\
"""

JS = """
async() => {
    const node = document.querySelector("div.slider_input_container");
    node.insertAdjacentHTML('beforebegin', "%s");
    const sliderNode = document.querySelector("input#range_id_0");
    sliderNode.insertAdjacentHTML('afterend', "%s");
    sliderNode.setAttribute("list", "values");
    const coeffBox = document.querySelector("div#coeff-slider label span");
    coeffBox.insertAdjacentHTML('afterend', "%s");
    var tooltipTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="tooltip"]'))
    var tooltipList = tooltipTriggerList.map(function (tooltipTriggerEl) {
        return new bootstrap.Tooltip(tooltipTriggerEl)
    })
    document.querySelector('span.min_value').remove();
    document.querySelector('span.max_value').remove();
}
""" % (slider_info, slider_ticks, coeff_info)


def initialize_instance(request: gr.Request):
    instances[request.session_hash] = []
    logger.info("Number of connections: %d", len(instances))
    return request.session_hash


def cleanup_instance(request: gr.Request):
    global ENDPOINT_ALIVE
    session_id = request.session_hash
    if session_id in instances:
        # Persist every completed generation for this session before dropping it.
        for data in instances[session_id]:
            if isinstance(data, SteeringOutput):
                scheduler.append(data.model_dump())
        del instances[session_id]
    if len(instances) == 0:
        ENDPOINT_ALIVE = False
    logger.info("Number of connections: %d", len(instances))


async def initialize_endpoint():
    alive = False
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{API_URL}/health", headers=headers) as resp:
            resp_text = await resp.text()
            if resp.status == 200:
                alive = True
            else:
                logger.error("API Error Code: %d, Message: %s", resp.status, resp_text)
    return alive


async def get_endpoint_state():
    global ENDPOINT_ALIVE
    n = 0
    sleep_time = MAX_RETRY_WAIT_TIME
    while n < MAX_RETRIES:
        n += 1
        if not ENDPOINT_ALIVE:
            logger.info("Initializing inference endpoint")
            yield "Initializing"
            ENDPOINT_ALIVE = await initialize_endpoint()
        if ENDPOINT_ALIVE:
            logger.info("Inference endpoint is ready")
            gr.Info("Inference endpoint is ready")
            yield "Ready"
            break
        gr.Warning("Initializing inference endpoint\n(This may take 2~3 minutes)", duration=sleep_time)
        await asyncio.sleep(sleep_time)
        sleep_time = max(sleep_time * 0.8, MIN_RETRY_WAIT_TIME)
    else:
        # Exhausted all retries without the endpoint coming up.
        yield "Server Error"


async def post_process(session_id, output):
    req = instances[session_id].pop()
    # Split the generated text into reasoning and answer at the model's
    # </think> delimiter (the literal tags do not survive in this copy of the
    # file and are restored here).
    if "</think>" in output:
        p = [p for p in output.partition("</think>") if p != ""]
        reasoning = "".join(p[:-1])
        if len(p) == 1:
            answer = None
        else:
            answer = p[-1]
    else:
        answer = None
        reasoning = output
    steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
    instances[session_id].append(steering_output)


async def generate(
    session_id: str,
    prompt: str,
    steering: bool,
    coeff: float,
    max_new_tokens: int,
    top_p: float,
    temperature: float,
    layer: int,
    vec_scaling: float
):
    req = UserRequest(
        session_id=session_id,
        prompt=prompt,
        steering=steering,
        coeff=coeff,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature,
        layer=layer,
        k=vec_scaling
    )
    instances[session_id].append(req)
    data = req.get_api_format()
    logger.info("User Request: %s", data)

    generated_text = ""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{API_URL}/generate", headers=headers, json=data) as resp:
                if resp.status == 200:
                    # The stream is assumed to begin inside the reasoning
                    # block, so prepend the opening tag (restored; stripped
                    # as markup in this copy of the file).
                    generated_text += "<think>"
                    async for chunk, _ in resp.content.iter_chunks():
                        generated_text += chunk.decode()
                        yield generated_text
                else:
                    logger.error("API Error Code: %d, Error Message: %s", resp.status,
                                 await resp.text())
                    raise gr.Error("API Server Error")
        await post_process(session_id, generated_text)
    except gr.Error:
        # Let Gradio surface API errors to the user instead of swallowing them.
        raise
    except Exception:
        logger.info("Client session error")
async def output_feedback(session_id, feedback):
    logger.info("Feedback received: %s", feedback)
    try:
        data = instances[session_id].pop()
        if "Upvote" in feedback:
            data.upvote = True
        elif "Downvote" in feedback:
            data.upvote = False
        instances[session_id].append(data)
        gr.Info("Thank you for your feedback!")
    except Exception:
        logger.debug("Feedback submission error")


async def show_feedback_buttons(upvote_btn, downvote_btn):
    return gr.update(interactive=True), gr.update(interactive=True)


gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()

with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
    session_id = gr.State()
    endpoint_state = gr.State(get_endpoint_state)
    gr.HTML(HTML)

    @gr.render(inputs=endpoint_state, triggers=[endpoint_state.change])
    def render_state(endpoint_state):
        if endpoint_state == "Ready":
            color = "green"
        elif endpoint_state == "Server Error":
            color = "red"
        else:
            color = "orange"
        if endpoint_state is not None:
            # Color-code the state label; the inline span is restored so the
            # computed `color` is actually used (its markup does not survive
            # in this copy of the file).
            gr.Markdown(
                f'🤖 {model_name} | Inference Endpoint State: '
                f'<span style="color: {color}">{endpoint_state}</span>',
                elem_id="model-state"
            )

    with gr.Row(elem_id="main-components"):
        with gr.Column(scale=1):
            with gr.Row():
                steer_toggle = Toggle(
                    label="Steering",
                    info="Turn off to generate original outputs",
                    value=True,
                    interactive=True,
                    scale=2,
                    elem_id="steering-toggle"
                )
                coeff = gr.Slider(
                    label="Coefficient",
                    value=-1.0,
                    minimum=-2,
                    maximum=2,
                    step=0.1,
                    scale=8,
                    show_reset_button=False,
                    elem_id="coeff-slider"
                )

            @gr.on(inputs=[steer_toggle], outputs=[steer_toggle, coeff], triggers=[steer_toggle.change])
            def update_toggle(toggle_value):
                if toggle_value is True:
                    return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
                else:
                    return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)

            input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
            with gr.Row():
                clear_btn = gr.ClearButton()
                stop_btn = gr.Button("Stop")
                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=1)
                    top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=1)
                with gr.Row():
                    layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
                    max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=3048, interactive=True, label="Max new tokens", scale=1)
                    vec_scaling = gr.Number(1.5, minimum=0, interactive=True, label="Vector scaling", scale=1)

        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
            with gr.Row():
                upvote_btn = gr.Button("👍 Upvote", interactive=False)
                downvote_btn = gr.Button("👎 Downvote", interactive=False)
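            # Logging disclaimer shown under the output box (minimal
            # placeholder markup; only the visible text is original).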
            gr.HTML(
                "<p>‼️ For research purposes, we log user inputs and generated "
                "outputs. Please avoid submitting any confidential or personal "
                "information.</p>"
            )

    gr.Markdown("#### Examples")
    gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
    gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")

    @gr.on(triggers=[clear_btn.click, stop_btn.click], outputs=[upvote_btn, downvote_btn])
    def clear_feedback_buttons():
        return gr.update(interactive=False), gr.update(interactive=False)

    @gr.on(triggers=[generate_btn.click], outputs=[upvote_btn, downvote_btn])
    def show_feedback_buttons():
        return gr.update(interactive=True), gr.update(interactive=True)

    submission = generate_btn.click(
        generate,
        inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, layer, vec_scaling],
        outputs=output
    )
    clear_btn.add([input_text, output])
    stop_btn.click(None, None, None, cancels=[submission], queue=False)
    upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
    downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
    # Reset the vector scaling to 1 whenever the steering layer changes.
    layer.change(fn=lambda x: 1, inputs=vec_scaling, outputs=vec_scaling)

    demo.load(initialize_instance, outputs=session_id)
    demo.unload(cleanup_instance)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5)
    demo.launch(debug=True)