import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import spaces

class LlamaGuardModeration:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "meta-llama/Llama-Guard-3-8B"  # Change the model ID if needed
        self.dtype = torch.bfloat16
        
        # HuggingFace token retrieval
        self.huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
        if not self.huggingface_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
        
        # Initialize the model and tokenizer
        self.initialize_model()

    def initialize_model(self):
        """Initialize model and tokenizer."""
        if self.model is None:
            # Initialize tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                token=self.huggingface_token
            )
            
            # Initialize model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                device_map="auto",
                token=self.huggingface_token,
                low_cpu_mem_usage=True
            )

    @staticmethod
    def parse_llama_guard_output(result):
        """Parse Llama Guard output."""
        safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
        lines = [line.strip().lower() for line in safety_assessment.split('\n') if line.strip()]
        
        if not lines:
            return "Error", "No valid output", safety_assessment

        safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
        
        if safety_status == 'safe':
            return "Safe", "None", safety_assessment
        elif safety_status == 'unsafe':
            violated_categories = next(
                (lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), 
                "Unspecified"
            )
            return "Unsafe", violated_categories, safety_assessment
        else:
            return "Error", f"Invalid output: {safety_status}", safety_assessment

    @spaces.GPU()
    def moderate(self, user_input, assistant_response):
        """Run moderation check."""
        chat = [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": assistant_response},
        ]
        
        # Build the Llama Guard prompt with the tokenizer's chat template
        # and move the input tensor to the model's device
        input_ids = self.tokenizer.apply_chat_template(
            chat,
            return_tensors="pt"
        ).to(self.device)
        
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=200,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=False
            )
        
        result = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return self.parse_llama_guard_output(result)

# Create an instance of the moderator
moderator = LlamaGuardModeration()
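
# Hypothetical direct call, bypassing the Gradio UI. The inputs and the
# expected outputs shown here are illustrative assumptions, not results
# from an actual run:
#   status, categories, raw = moderator.moderate(
#       "Can you help me plan a surprise party?",
#       "Sure! Start by picking a date and a guest list.",
#   )
#   status      -> "Safe" or "Unsafe"
#   categories  -> "None", or a lowercased category code such as "s1"
#   raw         -> the full decoded Llama Guard response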

# Set up Gradio interface
iface = gr.Interface(
    fn=moderator.moderate,
    inputs=[
        gr.Textbox(lines=3, label="User Input"),
        gr.Textbox(lines=3, label="Assistant Response")
    ],
    outputs=[
        gr.Textbox(label="Safety Status"),
        gr.Textbox(label="Violated Categories"),
        gr.Textbox(label="Raw Output")
    ],
    title="Llama Guard Moderation",
    description="Enter a user input and an assistant response to check for content moderation."
)

if __name__ == "__main__":
    iface.launch()