FaisalGh commited on
Commit
ad96bdb
·
verified ·
1 Parent(s): 8ca6168

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -8
README.md CHANGED
@@ -18,16 +18,49 @@ A fine-tuned DistilGPT2 model that generates short, clean, and (sometimes) funny
18
  - **Training epochs:** 5
19
  - **Max joke length:** 80 tokens
20
 
21
- ## Usage 🚀
22
 
23
  ### Direct Inference
 
24
  ```python
25
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- joke_gen = pipeline(
28
- "text-generation",
29
- model="FaisalGh/jokes-model"
30
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Generate a joke
33
- print(joke_gen("Tell me a clean joke:")[0]['generated_text'])
 
18
  - **Training epochs:** 5
19
  - **Max joke length:** 80 tokens
20
 
21
+ ## Usage
22
 
23
  ### Direct Inference
24
+
25
  ```python
26
+ from transformers import pipeline, AutoTokenizer
27
+ import torch
28
+
29
+ #Please add the BLOCKLIST for clean jokes
30
+ BLOCKLIST = [
31
+ "sex", "naked", "porn", "fuck", "dick", "penis", "ass",
32
+ "blowjob", "orgasm", "rape", "kill", "die", "shit",
33
+ "crap", "bastard", "hell", "damn", "bitch", "underage",
34
+ "pedo", "hit", "shot", "gun", "drug", "drunk", "fag", "cunt"
35
+ ]
36
+
37
+ def is_safe(text):
38
+ text_lower = text.lower()
39
+ return not any(bad_word in text_lower for bad_word in BLOCKLIST)
40
 
41
+ def generate_joke(prompt="Tell me a clean joke:"):
42
+ joke_gen = pipeline(
43
+ "text-generation",
44
+ model="FaisalGh/jokes-model",
45
+ device=0 if torch.cuda.is_available() else -1
46
+ )
47
+
48
+ output = joke_gen(
49
+ prompt,
50
+ max_length=80,
51
+ temperature=0.7,
52
+ top_k=50,
53
+ top_p=0.9,
54
+ repetition_penalty=1.5,
55
+ no_repeat_ngram_size=2,
56
+ do_sample=True,
57
+ pad_token_id=50256,
58
+ eos_token_id=50256
59
+ )
60
+
61
+ generated_text = output[0]['generated_text']
62
+ first_sentence = generated_text.split(".")[0] + "."
63
+
64
+ return "[Content filtered] Please try again." if not is_safe(first_sentence) else first_sentence.strip()
65
 
66
+ print(generate_joke("Tell me a clean joke:"))