abdull4h commited on
Commit
d43130f
·
verified ·
1 Parent(s): ae02789

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -676
app.py CHANGED
@@ -1,136 +1,104 @@
1
  import os
2
  import re
3
- import gradio as gr
4
- from huggingface_hub import login
5
- import spaces
6
-
7
- # CRITICAL: Disable PyTorch compiler BEFORE importing torch
8
- os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
9
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
10
- os.environ["TORCH_INDUCTOR_DISABLE"] = "1"
11
- os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"
12
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
13
- os.environ["TORCH_USE_CUDA_DSA"] = "0"
14
-
15
- # Now import torch and disable its compiler features
16
  import torch
17
- if hasattr(torch, "_dynamo"):
18
- if hasattr(torch._dynamo, "config"):
19
- torch._dynamo.config.suppress_errors = True
20
- if hasattr(torch._dynamo, "disable"):
21
- torch._dynamo.disable()
22
- print("Disabled torch._dynamo")
23
 
24
  # Model ID
25
  model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
26
 
27
- # Get token from environment and login
28
- hf_token = os.environ.get("HF_TOKEN")
29
- if hf_token:
30
- login(token=hf_token)
31
- print("Logged in with HF_TOKEN")
32
- else:
33
- print("No HF_TOKEN found. Please set the HF_TOKEN environment variable.")
34
-
35
- # Import transformers
36
- from transformers import AutoTokenizer, AutoModelForCausalLM
37
-
38
- # Simpler clean_response function
39
- def clean_response(text):
40
- # Remove website references
41
- text = re.sub(r'- موقع .*?\n', '', text)
42
-
43
- # Remove dates
44
- text = re.sub(r'\d+ [فبراير|مارس|أبريل|مايو|يونيو|يوليو|أغسطس|سبتمبر|أكتوبر|نوفمبر|ديسمبر]+ \d+ - \d+:\d+ [صباحا|مساء|ص|م]', '', text)
45
-
46
- # Remove repeated questions
47
- text = re.sub(r'(\?[^?]*){2,}', '?', text)
 
 
 
 
 
 
 
 
 
 
48
 
49
- # Remove excessive repetition (sentences that repeat)
50
- lines = text.split('،')
51
- unique_lines = []
52
- for line in lines:
53
- if line.strip() and line.strip() not in unique_lines:
54
- unique_lines.append(line.strip())
55
 
56
- return '.join(unique_lines)
57
 
58
- # Generate text with the Arabic model
59
- @spaces.GPU
60
- def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
61
- if not prompt.strip():
62
- return "Please enter a prompt."
63
-
64
- try:
65
- # Load tokenizer and model
66
- print("Loading tokenizer...")
67
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
68
-
69
- print("Loading model with compiler disabled...")
70
- model = AutoModelForCausalLM.from_pretrained(
71
- model_id,
72
- token=hf_token,
73
- torch_dtype=torch.float16,
74
- device_map="auto",
75
- use_cache=True,
76
- use_flash_attention_2=False,
77
- _attn_implementation="eager"
78
- )
79
-
80
- print(f"Model loaded successfully on {next(model.parameters()).device}")
81
 
82
- """
83
- # For Arabic-focused prompting, add a language instruction if needed
84
- if force_arabic and not any(arabic_word in prompt for arabic_word in ["العربية", "بالعربي", "باللغة العربية"]):
85
- # Add Arabic instruction only if prompt is already in Arabic
86
- if any('\u0600' <= c <= '\u06FF' for c in prompt):
87
- enhanced_prompt = prompt + " (أجب باللغة العربية)"
88
- print(f"Added Arabic language hint: {enhanced_prompt}")
89
- else:
90
- enhanced_prompt = prompt
91
- else:
92
- enhanced_prompt = prompt
93
- """
94
-
95
- # Replace with this line:
96
- enhanced_prompt = prompt
97
-
98
- # Create input for the model using proper tokenization with attention mask
99
- print(f"Generating response for: {enhanced_prompt[:50]}...")
100
-
101
- # Try to use a more direct approach
102
- encoding = tokenizer(prompt, return_tensors="pt")
103
- input_ids = encoding.input_ids.to(model.device)
104
- attention_mask = encoding.attention_mask.to(model.device)
105
-
106
- print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")
107
-
108
- # Add repetition penalty
109
  with torch.inference_mode():
110
- output = model.generate(
111
- input_ids=input_ids,
112
- attention_mask=attention_mask,
113
- max_new_tokens=int(max_length),
114
- do_sample=True,
115
- temperature=0.7,
116
- repetition_penalty=1.2, # Add this parameter
117
- no_repeat_ngram_size=3, # And this one
 
 
 
118
  pad_token_id=tokenizer.eos_token_id
119
  )
120
 
121
- # Get only the generated part (exclude the prompt)
122
- input_length = input_ids.shape[1]
123
- generated_tokens = output[0][input_length:]
124
 
125
- # Decode just the generated part
126
- generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
127
- print(f"Generated text (after input): {generated_text[:100]}...")
128
 
129
- # Clean any remaining special tokens
130
- cleaned_response = clean_response(generated_text)
131
- print(f"Final cleaned response: {cleaned_response[:100]}...")
132
 
133
- return cleaned_response
134
 
135
  except Exception as e:
136
  import traceback
@@ -138,564 +106,24 @@ def generate_text(prompt, max_length=100, temperature=0.7, force_arabic=True):
138
  print(f"Error generating text: {str(e)}\n{tb}")
139
  return f"Error generating text: {str(e)}"
140
 
141
- # Keep the existing code, but replace the custom_css with the following:
142
-
143
- custom_css = """
144
- /* Enhanced Color Scheme and UI for Arabic Language Model */
145
- :root {
146
- --primary-color: #1F4287; /* Deep Royal Blue */
147
- --secondary-color: #278EA5; /* Teal Blue */
148
- --background-color: #F9FAFC; /* Soft Light Gray */
149
- --text-color: #333942; /* Dark Slate */
150
- --accent-color: #21BF73; /* Vibrant Green */
151
- --highlight-color: #FF6B6B; /* Coral Red */
152
- --header-gradient: linear-gradient(135deg, #1F4287 0%, #278EA5 100%);
153
- --card-shadow: 0 10px 30px rgba(31, 66, 135, 0.12);
154
- --input-bg: #F2F5F9; /* Light Blue-Gray for inputs */
155
- --border-radius: 16px; /* Consistent border radius */
156
- }
157
-
158
- /* Base Styles */
159
- .gradio-container {
160
- background: var(--background-color);
161
- color: var(--text-color);
162
- font-family: 'Cairo', 'Noto Sans Arabic', 'Helvetica Neue', 'Arial', sans-serif;
163
- max-width: 1200px;
164
- margin: 0 auto;
165
- padding: 20px;
166
- }
167
-
168
- /* Typography */
169
- .gradio-container h1 {
170
- color: var(--primary-color);
171
- font-size: 2.5rem;
172
- text-align: center;
173
- margin-bottom: 0.5rem;
174
- font-weight: 800;
175
- }
176
-
177
- .gradio-container h2 {
178
- color: var(--secondary-color);
179
- font-size: 1.5rem;
180
- text-align: center;
181
- margin-bottom: 2rem;
182
- font-weight: 600;
183
- }
184
-
185
- .gradio-container h3 {
186
- color: var(--secondary-color);
187
- font-size: 1.25rem;
188
- margin-top: 1.5rem;
189
- margin-bottom: 1rem;
190
- font-weight: 600;
191
- }
192
-
193
- /* Card-style Blocks */
194
- .gradio-container .block {
195
- background-color: white;
196
- border-radius: var(--border-radius);
197
- box-shadow: var(--card-shadow);
198
- border: none;
199
- padding: 30px;
200
- margin: 24px 0;
201
- transition: all 0.3s ease;
202
- }
203
-
204
- .gradio-container .block:hover {
205
- box-shadow: 0 15px 40px rgba(31, 66, 135, 0.18);
206
- transform: translateY(-5px);
207
- }
208
-
209
- /* Header Style */
210
- .gradio-container .header {
211
- background: var(--header-gradient);
212
- color: white;
213
- padding: 30px;
214
- border-radius: var(--border-radius);
215
- margin-bottom: 30px;
216
- text-align: center;
217
- position: relative;
218
- overflow: hidden;
219
- }
220
-
221
- .gradio-container .header::before {
222
- content: '';
223
- position: absolute;
224
- top: 0;
225
- left: 0;
226
- right: 0;
227
- bottom: 0;
228
- background: url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100"><text x="50%" y="50%" font-size="80" text-anchor="middle" dominant-baseline="middle" font-family="Arial" fill="rgba(255,255,255,0.05)">ذ</text></svg>') repeat;
229
- opacity: 0.1;
230
- }
231
-
232
- .gradio-container .header h1,
233
- .gradio-container .header h2 {
234
- color: white;
235
- text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
236
- }
237
-
238
- /* Input and Output Containers */
239
- .gradio-container .input-container,
240
- .gradio-container .output-container {
241
- background-color: white;
242
- border-radius: var(--border-radius);
243
- box-shadow: var(--card-shadow);
244
- padding: 25px;
245
- margin-bottom: 25px;
246
- transition: all 0.3s ease;
247
- }
248
-
249
- .gradio-container .input-container:hover,
250
- .gradio-container .output-container:hover {
251
- box-shadow: 0 15px 40px rgba(31, 66, 135, 0.15);
252
- }
253
-
254
- .gradio-container .block-title {
255
- color: var(--primary-color);
256
- font-weight: bold;
257
- text-align: center;
258
- margin-bottom: 20px;
259
- font-size: 1.5rem;
260
- position: relative;
261
- padding-bottom: 10px;
262
- }
263
-
264
- .gradio-container .block-title::after {
265
- content: '';
266
- position: absolute;
267
- bottom: 0;
268
- left: 50%;
269
- transform: translateX(-50%);
270
- width: 60px;
271
- height: 3px;
272
- background: var(--accent-color);
273
- border-radius: 3px;
274
- }
275
-
276
- /* Textareas and Inputs */
277
- .gradio-container textarea,
278
- .gradio-container input[type="text"] {
279
- background-color: var(--input-bg);
280
- border: 2px solid transparent;
281
- border-radius: calc(var(--border-radius) - 4px);
282
- color: var(--text-color);
283
- direction: rtl;
284
- padding: 15px;
285
- transition: all 0.3s ease;
286
- font-size: 1.05rem;
287
- line-height: 1.6;
288
- resize: vertical;
289
- }
290
-
291
- .gradio-container textarea::placeholder,
292
- .gradio-container input[type="text"]::placeholder {
293
- color: #9EA7B3;
294
- }
295
-
296
- .gradio-container textarea:focus,
297
- .gradio-container input[type="text"]:focus {
298
- border-color: var(--accent-color);
299
- box-shadow: 0 0 0 3px rgba(33, 191, 115, 0.2);
300
- outline: none;
301
- }
302
-
303
- /* Labels */
304
- .gradio-container label {
305
- color: var(--primary-color);
306
- font-weight: 600;
307
- margin-bottom: 8px;
308
- display: block;
309
- font-size: 1.05rem;
310
- }
311
-
312
- /* Buttons */
313
- .gradio-container .primary {
314
- background: linear-gradient(135deg, var(--secondary-color) 0%, var(--accent-color) 100%) !important;
315
- color: white !important;
316
- border-radius: calc(var(--border-radius) - 4px);
317
- transition: all 0.3s ease;
318
- font-weight: bold;
319
- padding: 12px 24px !important;
320
- border: none !important;
321
- font-size: 1.1rem;
322
- box-shadow: 0 4px 15px rgba(33, 191, 115, 0.3);
323
- text-align: center;
324
- }
325
-
326
- .gradio-container .primary:hover {
327
- transform: translateY(-3px);
328
- box-shadow: 0 8px 20px rgba(33, 191, 115, 0.4);
329
- }
330
-
331
- .gradio-container .primary:active {
332
- transform: translateY(-1px);
333
- }
334
-
335
- .gradio-container .secondary {
336
- background-color: #EDF2F7;
337
- color: var(--primary-color);
338
- border-radius: calc(var(--border-radius) - 4px);
339
- transition: all 0.3s ease;
340
- font-weight: 600;
341
- padding: 12px 24px !important;
342
- border: 1px solid #D9E2EC !important;
343
- font-size: 1.1rem;
344
- }
345
-
346
- .gradio-container .secondary:hover {
347
- background-color: #E2E8F0;
348
- transform: translateY(-2px);
349
- box-shadow: 0 4px 10px rgba(31, 66, 135, 0.1);
350
- }
351
-
352
- /* Example Buttons Styling */
353
- .gradio-container button:not(.primary):not(.secondary) {
354
- background-color: white;
355
- color: var(--secondary-color);
356
- border: 1px solid var(--secondary-color);
357
- border-radius: 30px;
358
- padding: 8px 16px;
359
- margin: 5px;
360
- transition: all 0.3s ease;
361
- font-size: 0.95rem;
362
- }
363
-
364
- .gradio-container button:not(.primary):not(.secondary):hover {
365
- background-color: var(--secondary-color);
366
- color: white;
367
- transform: scale(1.05);
368
- box-shadow: 0 4px 12px rgba(39, 142, 165, 0.25);
369
- }
370
-
371
- /* Accordion Styling */
372
- .gradio-container .accordion {
373
- border: 1px solid #E2E8F0;
374
- border-radius: var(--border-radius);
375
- overflow: hidden;
376
- margin: 20px 0;
377
- }
378
-
379
- .gradio-container .accordion-title {
380
- background-color: #EDF2F7;
381
- color: var(--primary-color);
382
- padding: 12px 20px;
383
- font-weight: bold;
384
- cursor: pointer;
385
- border-radius: calc(var(--border-radius) - 4px);
386
- transition: all 0.3s ease;
387
- display: flex;
388
- align-items: center;
389
- justify-content: space-between;
390
- }
391
-
392
- .gradio-container .accordion-title:hover {
393
- background-color: #E2E8F0;
394
- }
395
-
396
- .gradio-container .accordion-title::after {
397
- content: '▼';
398
- font-size: 12px;
399
- margin-left: 10px;
400
- transition: transform 0.3s ease;
401
- }
402
-
403
- .gradio-container .accordion-title.open::after {
404
- transform: rotate(180deg);
405
- }
406
-
407
- .gradio-container .accordion-content {
408
- padding: 15px 20px;
409
- background-color: white;
410
- }
411
-
412
- /* Sliders */
413
- .gradio-container input[type="range"] {
414
- -webkit-appearance: none;
415
- width: 100%;
416
- height: 8px;
417
- border-radius: 5px;
418
- background: #E2E8F0;
419
- outline: none;
420
- margin: 15px 0;
421
- }
422
-
423
- .gradio-container input[type="range"]::-webkit-slider-thumb {
424
- -webkit-appearance: none;
425
- appearance: none;
426
- width: 20px;
427
- height: 20px;
428
- border-radius: 50%;
429
- background: var(--accent-color);
430
- cursor: pointer;
431
- box-shadow: 0 2px 8px rgba(33, 191, 115, 0.4);
432
- }
433
-
434
- .gradio-container input[type="range"]::-moz-range-thumb {
435
- width: 20px;
436
- height: 20px;
437
- border-radius: 50%;
438
- background: var(--accent-color);
439
- cursor: pointer;
440
- box-shadow: 0 2px 8px rgba(33, 191, 115, 0.4);
441
- }
442
-
443
- /* Checkboxes */
444
- .gradio-container input[type="checkbox"] {
445
- -webkit-appearance: none;
446
- appearance: none;
447
- width: 20px;
448
- height: 20px;
449
- border: 2px solid var(--secondary-color);
450
- border-radius: 5px;
451
- outline: none;
452
- cursor: pointer;
453
- margin-right: 10px;
454
- vertical-align: middle;
455
- position: relative;
456
- }
457
-
458
- .gradio-container input[type="checkbox"]:checked {
459
- background-color: var(--accent-color);
460
- border-color: var(--accent-color);
461
- }
462
-
463
- .gradio-container input[type="checkbox"]:checked::after {
464
- content: '✓';
465
- color: white;
466
- position: absolute;
467
- top: 50%;
468
- left: 50%;
469
- transform: translate(-50%, -50%);
470
- font-size: 14px;
471
- font-weight: bold;
472
- }
473
-
474
- /* Status and Processing Indicators */
475
- .gradio-container .status-message {
476
- color: var(--highlight-color);
477
- font-weight: bold;
478
- text-align: center;
479
- margin: 15px 0;
480
- padding: 10px;
481
- border-radius: calc(var(--border-radius) - 8px);
482
- background-color: rgba(255, 107, 107, 0.1);
483
- border-left: 3px solid var(--highlight-color);
484
- }
485
-
486
- /* Loading Animation */
487
- @keyframes pulse {
488
- 0% { opacity: 0.6; }
489
- 50% { opacity: 1; }
490
- 100% { opacity: 0.6; }
491
- }
492
-
493
- .gradio-container .loading {
494
- animation: pulse 1.5s infinite;
495
- display: inline-block;
496
- padding-left: 8px;
497
- }
498
-
499
- /* Responsive Design */
500
- @media (max-width: 768px) {
501
- .gradio-container {
502
- padding: 10px;
503
- }
504
-
505
- .gradio-container .block {
506
- padding: 20px;
507
- }
508
-
509
- .gradio-container h1 {
510
- font-size: 2rem;
511
- }
512
-
513
- .gradio-container h2 {
514
- font-size: 1.25rem;
515
- }
516
-
517
- .gradio-container .primary,
518
- .gradio-container .secondary {
519
- padding: 10px 18px !important;
520
- font-size: 1rem;
521
- }
522
- }
523
-
524
- @media (max-width: 480px) {
525
- .gradio-container h1 {
526
- font-size: 1.75rem;
527
- }
528
-
529
- .gradio-container h2 {
530
- font-size: 1.1rem;
531
- }
532
-
533
- .gradio-container .block {
534
- padding: 15px;
535
- }
536
- }
537
-
538
- /* RTL Support - Important for Arabic */
539
- [dir="rtl"] .gradio-container,
540
- .rtl {
541
- text-align: right;
542
- }
543
-
544
- [dir="rtl"] .gradio-container .accordion-title::after,
545
- .rtl .gradio-container .accordion-title::after {
546
- margin-left: 0;
547
- margin-right: 10px;
548
- }
549
-
550
- /* Dark Mode Support (Optional) */
551
- @media (prefers-color-scheme: dark) {
552
- :root {
553
- --primary-color: #4D96FF;
554
- --secondary-color: #38B6FF;
555
- --background-color: #1A1A2E;
556
- --text-color: #E6E6E6;
557
- --accent-color: #38E54D;
558
- --highlight-color: #FF6B6B;
559
- --input-bg: #242442;
560
- --header-gradient: linear-gradient(135deg, #4D96FF 0%, #38B6FF 100%);
561
- --card-shadow: 0 10px 30px rgba(0, 0, 0, 0.3);
562
- }
563
-
564
- .gradio-container {
565
- background: var(--background-color);
566
- }
567
-
568
- .gradio-container .block,
569
- .gradio-container .input-container,
570
- .gradio-container .output-container {
571
- background-color: #242442;
572
- }
573
-
574
- .gradio-container .secondary {
575
- background-color: #333355;
576
- border-color: #444466 !important;
577
- }
578
-
579
- .gradio-container .secondary:hover {
580
- background-color: #3D3D60;
581
- }
582
-
583
- .gradio-container textarea,
584
- .gradio-container input[type="text"] {
585
- background-color: #333355;
586
- color: var(--text-color);
587
- }
588
-
589
- .gradio-container textarea::placeholder,
590
- .gradio-container input[type="text"]::placeholder {
591
- color: #8D8DAA;
592
- }
593
-
594
- .gradio-container .accordion-title {
595
- background-color: #333355;
596
- }
597
-
598
- .gradio-container .accordion-title:hover {
599
- background-color: #3D3D60;
600
- }
601
-
602
- .gradio-container input[type="range"] {
603
- background: #333355;
604
- }
605
- }
606
- """
607
-
608
- # Updated Gradio interface with enhanced design
609
- with gr.Blocks(title="Cohere Arabic Model Demo", css=custom_css) as demo:
610
- # Main title and description
611
- gr.Markdown("""
612
- # 🌟 نموذج Cohere للغة العربية
613
- ## Command R7B Arabic Language Model
614
-
615
- نموذج ذكاء اصطناعي متقدم للتوليد النصي باللغة العربية
616
- """)
617
-
618
- # Main interface container
619
- with gr.Row():
620
- # Input Column
621
- with gr.Column(scale=1):
622
- # Prompt Input
623
- prompt = gr.Textbox(
624
- label="النص الإدخال | Input Prompt",
625
- placeholder="أدخل نصك باللغة العربية هنا...",
626
- lines=5
627
- )
628
-
629
- # Example Prompts Section
630
- gr.Markdown("### أمثلة سريعة | Quick Examples")
631
- with gr.Row():
632
- example_prompts = [
633
- "مرحبا، كيف حالك؟",
634
- "اكتب قصة قصيرة عن قطة",
635
- "اشرح مفهوم الذكاء الاصطناعي",
636
- "قانون الجاذبية للأطفال",
637
- ]
638
- for example in example_prompts:
639
- example_btn = gr.Button(example)
640
- example_btn.click(fn=lambda x=example: x, inputs=[], outputs=[prompt])
641
-
642
- # Advanced Settings Accordion
643
- with gr.Accordion("الإعدادات المتقدمة | Advanced Settings", open=False):
644
- max_tokens = gr.Slider(
645
- minimum=10, maximum=500, value=100,
646
- step=10, label="الحد الأقصى للرموز | Max Tokens"
647
- )
648
- temperature = gr.Slider(
649
- minimum=0.1, maximum=1.0, value=0.7,
650
- step=0.1, label="درجة الحرارة | Temperature"
651
- )
652
- force_arabic = gr.Checkbox(
653
- label="تشجيع الاستجابات بالعربية | Encourage Arabic Responses",
654
- value=True
655
- )
656
-
657
- # Generate and Clear Buttons
658
- with gr.Row():
659
- generate_btn = gr.Button("توليد النص | Generate", variant="primary")
660
- clear_btn = gr.Button("مسح | Clear", variant="secondary")
661
-
662
- # Output Column
663
- with gr.Column(scale=1):
664
- output = gr.Textbox(
665
- label="النص المولد | Generated Text",
666
- lines=10,
667
- interactive=False
668
- )
669
-
670
- # Status Markdown for additional information
671
- status = gr.Markdown("جاهز للتوليد | Ready to generate")
672
-
673
- # Event Handlers
674
- def on_generate(prompt, max_tokens, temperature, force_arabic):
675
- # Update status to indicate generation is in progress
676
- status_update = "جارٍ التوليد... قد يستغرق حتى دقيقتين | Generating... This may take up to 2 minutes."
677
-
678
- # Call the generation function
679
- result = generate_text(prompt, max_tokens, temperature, force_arabic)
680
-
681
- return result, "اكتمل التوليد | Generation complete!"
682
-
683
- # Connect buttons to their functions
684
- generate_btn.click(
685
- fn=on_generate,
686
- inputs=[prompt, max_tokens, temperature, force_arabic],
687
- outputs=[output, status]
688
- )
689
-
690
- # Clear button functionality
691
- clear_btn.click(
692
- fn=lambda: ("", "تم المسح | Cleared"),
693
- inputs=[],
694
- outputs=[prompt, output, status]
695
- )
696
-
697
- # Launch the Gradio app
698
- demo.launch(
699
- share=True, # Enable sharing if needed
700
- debug=True # Enable debug mode
701
- )
 
1
  import os
2
  import re
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ # Global variables for model and tokenizer to prevent reloading
7
+ global_model = None
8
+ global_tokenizer = None
 
9
 
10
  # Model ID
11
  model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
12
 
13
def load_models():
    """Return the (tokenizer, model) pair, loading and caching them on first use.

    Subsequent calls reuse the module-level cache so the weights are only
    pulled from the Hub once per process.

    Raises:
        ValueError: if the HF_TOKEN environment variable is not set.
    """
    global global_model, global_tokenizer

    # Fast path: a previous call already populated the cache.
    if global_tokenizer is not None and global_model is not None:
        return global_tokenizer, global_model

    # Authentication token comes from the environment.
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("No HF_TOKEN found. Please set the HF_TOKEN environment variable.")

    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(model_id, token=token)

    print("Loading model...")
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        # Cap usage on device 0 -- assumes a single-GPU host; TODO confirm.
        max_memory={0: "14GB"},
        use_cache=True,
        # Eager attention implementation chosen for stability.
        _attn_implementation="eager",
    )

    # Populate the cache for the next caller.
    global_tokenizer, global_model = tok, mdl
    return tok, mdl
50
 
51
def format_prompt(prompt):
    """Wrap *prompt* in the instruction template the Command model answers best with.

    The template is: an Arabic instruction to answer precisely and directly,
    a blank line, the question, a blank line, then an answer cue.
    """
    header = "الإجابة على الأسئلة بدقة ومباشرة ودون التطرق للمواضيع الأخرى غير المتعلقة بالسؤال."
    return f"{header}\n\nالسؤال: {prompt}\n\nالإجابة:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
def generate_text(prompt, max_new_tokens=500):
    """Generate an Arabic answer for *prompt* with the cached Command model.

    Args:
        prompt: the user's question, ideally in Arabic.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The cleaned generated text, or an error message string if anything
        in the pipeline fails (this function never raises to the caller).
    """
    # Guard against empty input instead of spending GPU time on it.
    if not prompt.strip():
        return "Please enter a prompt."

    try:
        # Get or lazily load the cached model and tokenizer.
        tokenizer, model = load_models()

        # Wrap the raw question in the instruction template.
        formatted_prompt = format_prompt(prompt)
        print(f"Formatted prompt: {formatted_prompt[:100]}...")

        # Tokenize with an explicit attention mask; no truncation so the
        # full context window is available.
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=False
        ).to(model.device)

        # Sampling parameters tuned for factual, low-repetition answers.
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.3,        # low temperature -> more deterministic
                top_p=0.9,
                repetition_penalty=1.2, # penalize verbatim repetition
                no_repeat_ngram_size=3, # forbid repeated 3-grams
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id
            )

        # Keep only the tokens generated after the prompt.
        prompt_length = inputs.input_ids.shape[1]
        generated_ids = outputs[0][prompt_length:]

        # Decode and post-process the continuation.
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return clean_response(generated_text)

    except Exception as e:
        import traceback
        # Bug fix: `tb` was printed below but never assigned, which turned
        # every generation error into a NameError inside the handler.
        tb = traceback.format_exc()
        print(f"Error generating text: {str(e)}\n{tb}")
        return f"Error generating text: {str(e)}"
108
 
109
def clean_response(text):
    """Normalize generated text.

    Strips tag-like tokens, collapses whitespace, and removes immediately
    repeated four-word phrases.
    """
    # Drop anything that looks like an HTML/special-token tag, e.g. "<pad>".
    without_tags = re.sub(r'<.*?>', '', text)

    # Collapse every run of whitespace (including newlines) to one space.
    squeezed = re.sub(r'\s+', ' ', without_tags).strip()

    # Collapse a 4-word phrase that immediately repeats down to a single
    # occurrence (group 1 is the phrase, group 2 its repetitions).
    return re.sub(r'(\b\w+\b\s+\b\w+\b\s+\b\w+\b\s+\b\w+\b\s+)(\1)+', r'\1', squeezed)
122
+
123
# Example usage
if __name__ == "__main__":
    # Smoke test: ask who wrote the famous "on the measure of the resolute" poem.
    poem_question = 'من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟'
    answer = generate_text(poem_question)
    print("\nQuestion:", poem_question)
    print("\nResponse:", answer)