hannahcyberey committed on
Commit
7a22267
·
1 Parent(s): a137c8f
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -249,11 +249,11 @@ async def post_process(session_id, output):
249
 
250
  async def generate(
251
  session_id: str, prompt: str, steering: bool, coeff: float,
252
- max_new_tokens: int, top_p: float, temperature: float, vec_scaling: float
253
  ):
254
  req = UserRequest(
255
  session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
256
- max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, k=vec_scaling
257
  )
258
 
259
  instances[session_id].append(req)
@@ -374,7 +374,7 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
374
 
375
 
376
  submission = generate_btn.click(
377
- generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, vec_scaling], outputs=output
378
  )
379
 
380
  clear_btn.add([input_text, output])
@@ -383,6 +383,8 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
383
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
384
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
385
 
 
 
386
  demo.load(initialize_instance, outputs=session_id)
387
  demo.unload(cleanup_instance)
388
 
 
249
 
250
  async def generate(
251
  session_id: str, prompt: str, steering: bool, coeff: float,
252
+ max_new_tokens: int, top_p: float, temperature: float, layer: int, vec_scaling: float
253
  ):
254
  req = UserRequest(
255
  session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
256
+ max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, layer=layer, k=vec_scaling
257
  )
258
 
259
  instances[session_id].append(req)
 
374
 
375
 
376
  submission = generate_btn.click(
377
+ generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, layer, vec_scaling], outputs=output
378
  )
379
 
380
  clear_btn.add([input_text, output])
 
383
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
384
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
385
 
386
+ layer.change(fn=lambda x: 1, inputs=vec_scaling, outputs=vec_scaling)
387
+
388
  demo.load(initialize_instance, outputs=session_id)
389
  demo.unload(cleanup_instance)
390