fahrendrakhoirul committed on
Commit
01a1e74
·
verified ·
1 Parent(s): a63901c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +697 -0
  2. model_class.py +48 -0
  3. pipeline.py +36 -0
app.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ from datetime import datetime
7
+ import json
8
+ import torch
9
+
10
+ # ====================== Utility Functions ======================
11
+
12
def simulate_topic_prediction(text):
    """Score *text* against each review topic using the aspect model.

    The text is tokenized with ``pipeline.tokenizer`` and fed to
    ``pipeline.aspect_model`` under ``torch.no_grad()``.  The model output is
    used directly as per-topic scores; no activation is applied here.

    Args:
        text: Review text to classify.

    Returns:
        dict mapping ``'product'``, ``'customer_service'`` and ``'shipping'``
        to a float score.
    """
    topics = ['product', 'customer_service', 'shipping']

    # Truncate explicitly so a very long review cannot exceed the encoder's
    # position limit and fail inside the model.
    # NOTE(review): 512 assumes the usual BERT-base limit — confirm against
    # the checkpoint's config.
    encoded = pipeline.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        outputs = pipeline.aspect_model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )

    # Accept either a HF-style output object or a bare tensor.
    raw = outputs.logits if hasattr(outputs, 'logits') else outputs
    probs = raw.squeeze().numpy()

    # A 0-d result (single output unit) is applied to every topic, matching
    # the previous behaviour.
    if probs.ndim == 0:
        score = float(probs)
        return {topic: score for topic in topics}

    return {topic: float(probs[i]) for i, topic in enumerate(topics)}
38
+
39
def simulate_sentiment_prediction(text):
    """Predict the sentiment of *text* using the sentiment model.

    The text is tokenized with ``pipeline.tokenizer`` and fed to
    ``pipeline.sentiment_model`` under ``torch.no_grad()``.  The model output
    is used directly as per-class scores; no activation is applied here.

    Args:
        text: Review text to classify.

    Returns:
        dict with keys ``'sentiment'`` (label with the highest score),
        ``'confidence'`` (that score as a float) and ``'all_probs'``
        (label -> float score for every label).
    """
    # NOTE(review): the label order is assumed to match the model's output
    # order — verify against the training setup.
    sentiments = ['positive', 'neutral', 'negative']

    # Truncate explicitly so a very long review cannot exceed the encoder's
    # position limit and fail inside the model.
    # NOTE(review): 512 assumes the usual BERT-base limit — confirm against
    # the checkpoint's config.
    encoded = pipeline.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        outputs = pipeline.sentiment_model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )

    # Accept either a HF-style output object or a bare tensor.
    raw = outputs.logits if hasattr(outputs, 'logits') else outputs
    probs = raw.squeeze().numpy()

    best = int(np.argmax(probs))

    return {
        'sentiment': sentiments[best],
        'confidence': float(probs[best]),
        'all_probs': {label: float(probs[i]) for i, label in enumerate(sentiments)},
    }
70
+
71
def display_predictions(text, topic_predictions, sentiment_prediction):
    """Render one classification result and record it in the session history.

    Args:
        text: The classified input text.
        topic_predictions: Mapping of topic name -> score.
        sentiment_prediction: Dict with at least ``'sentiment'`` and
            ``'confidence'``; passed on to ``create_sentiment_chart``.
    """
    st.markdown("---")
    st.subheader("🎯 Classification Results")

    # Echo the input back to the user.
    st.markdown("**Input Text:**")
    st.info(text)

    topic_col, sentiment_col = st.columns(2)

    with topic_col:
        st.markdown("**🏷️ Topic Classification (Multi-label):**")

        for name, score in topic_predictions.items():
            # Only topics at or above the fixed 0.5 threshold get a card.
            if not score >= 0.5:
                continue

            strong = score > 0.7
            css_class = "topic-positive" if strong else "topic-neutral"
            marker = "✅" if strong else "⚠️"
            label = name.replace('_', ' ').title()

            card = f"""
            <div class="prediction-result {css_class}">
            {marker} <strong>{label}</strong>
            <br>Confidence: {score:.2%}
            </div>
            """
            st.markdown(card, unsafe_allow_html=True)

        st.plotly_chart(create_topic_chart(topic_predictions), use_container_width=True)

    with sentiment_col:
        st.markdown("**😊 Sentiment Analysis:**")

        label = sentiment_prediction['sentiment']
        score = sentiment_prediction['confidence']

        emoji_for = {"positive": "😊", "neutral": "😐", "negative": "😞"}
        css_class = f"topic-{label}"

        card = f"""
        <div class="prediction-result {css_class}">
        {emoji_for[label]} <strong>{label.title()}</strong>
        <br>Confidence: {score:.2%}
        </div>
        """
        st.markdown(card, unsafe_allow_html=True)

        st.plotly_chart(create_sentiment_chart(sentiment_prediction), use_container_width=True)

    # Keep a per-session log that the sidebar statistics read from.
    if 'classification_history' not in st.session_state:
        st.session_state.classification_history = []

    all_scores = list(topic_predictions.values()) + [sentiment_prediction['confidence']]
    entry = {
        'text': text,
        'topics': topic_predictions,
        'sentiment': sentiment_prediction,
        'confidence': np.mean(all_scores),
        'timestamp': datetime.now(),
    }
    st.session_state.classification_history.append(entry)
134
+
135
def create_topic_chart(predictions):
    """Build a bar chart of topic scores with the 0.5 threshold marked.

    Args:
        predictions: Mapping of topic name -> score.

    Returns:
        A plotly ``Figure``.
    """
    labels = [name.replace('_', ' ').title() for name in predictions.keys()]
    scores = list(predictions.values())
    # Green for topics at/above the threshold, grey otherwise.
    bar_colors = ['#28a745' if score >= 0.5 else '#6c757d' for score in scores]

    bars = go.Bar(x=labels, y=scores, marker_color=bar_colors)
    fig = go.Figure(data=[bars])

    fig.update_layout(
        title="Topic Classification Probabilities",
        xaxis_title="Topics",
        yaxis_title="Probability",
        height=300,
        showlegend=False,
    )

    fig.add_hline(
        y=0.5,
        line_dash="dash",
        line_color="red",
        annotation_text="Threshold (0.5)",
    )

    return fig
160
+
161
def create_sentiment_chart(prediction):
    """Build a donut chart of the sentiment scores for one prediction.

    Args:
        prediction: Dict with ``'all_probs'`` (label -> score), or, as a
            fallback, ``'sentiment'`` and ``'confidence'``.

    Returns:
        A plotly ``Figure``.
    """
    sentiments = ['positive', 'neutral', 'negative']

    if 'all_probs' in prediction:
        probs = [prediction['all_probs'][label] for label in sentiments]
    else:
        # Fallback: give the predicted label its confidence and split the
        # remainder evenly between the other two labels.
        predicted = prediction['sentiment']
        confidence = prediction['confidence']
        winner = sentiments.index(predicted)
        leftover = (1.0 - confidence) / 2
        probs = [confidence if pos == winner else leftover for pos in range(len(sentiments))]

    colors = ['#28a745', '#ffc107', '#dc3545']

    donut = go.Pie(
        labels=[label.title() for label in sentiments],
        values=probs,
        hole=0.3,  # Creates donut effect
        marker_colors=colors,
        textinfo='label+percent',
        textposition='auto',
    )
    fig = go.Figure(data=[donut])

    fig.update_layout(
        title="Sentiment Analysis Probabilities",
        height=300,
        showlegend=False,
        margin=dict(t=50, b=20, l=20, r=20),
    )

    return fig
200
+
201
def process_batch_classification(df, text_column, max_results=10):
    """Classify up to *max_results* rows of *df* and render the results.

    Non-string and blank cells are skipped.  A row whose classification
    raises is reported with ``st.error`` and skipped; the rest of the batch
    continues.

    Args:
        df: DataFrame holding the texts.
        text_column: Name of the column in *df* that contains the text.
        max_results: Maximum number of leading rows to process.
    """
    st.subheader("🔄 Batch Processing Results")

    progress_bar = st.progress(0)
    results = []
    all_topic_predictions = []
    all_sentiment_predictions = []

    texts = df[text_column].values[:max_results]
    # Use the length of the slice actually iterated so the progress fraction
    # always stays within (0, 1].
    total = len(texts)

    for i, text in enumerate(texts):
        if isinstance(text, str) and text.strip():
            try:
                topic_pred = simulate_topic_prediction(text)
                sentiment_pred = simulate_sentiment_prediction(text)
            except Exception as e:
                # Best effort: report the failing row and keep going.
                st.error(f"Error processing text {i+1}: {str(e)}")
            else:
                # Store for visualization
                all_topic_predictions.append(topic_pred)
                all_sentiment_predictions.append(sentiment_pred)

                results.append({
                    'text': text[:100] + '...' if len(text) > 100 else text,
                    'topics': ', '.join([t for t, p in topic_pred.items() if p >= 0.5]),
                    'sentiment': sentiment_pred['sentiment'],
                    'sentiment_confidence': sentiment_pred['confidence']
                })

        # Advance for every row, including skipped or failed ones, so the bar
        # always reaches completion.
        progress_bar.progress((i + 1) / total)

    if not results:
        # Tell the user why nothing is shown instead of ending silently.
        st.warning("No rows could be classified. Check that the selected column contains text.")
        return

    # Display results table
    results_df = pd.DataFrame(results)
    st.dataframe(results_df, use_container_width=True)

    # Create visualization section
    st.markdown("---")
    st.subheader("📊 Batch Analysis Visualization")

    col_topic_viz, col_sentiment_viz = st.columns(2)

    with col_topic_viz:
        st.markdown("**Topic Distribution**")
        create_batch_topic_chart(all_topic_predictions)

    with col_sentiment_viz:
        st.markdown("**Sentiment Distribution**")
        create_batch_sentiment_chart(all_sentiment_predictions)

    # Download results
    csv = results_df.to_csv(index=False)
    st.download_button(
        label="📥 Download Results",
        data=csv,
        file_name=f"classification_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
        mime="text/csv"
    )
259
+
260
def create_batch_topic_chart(all_predictions):
    """Render a bar chart of how often each topic was assigned in a batch.

    A topic counts as assigned to a text when its score is >= 0.5.  The chart
    shows, per topic, the percentage of texts it was assigned to.

    Args:
        all_predictions: List of topic -> score mappings, one per text.
    """
    topics = ['product', 'customer_service', 'shipping']
    total_texts = len(all_predictions)

    # Nothing to chart; also avoids dividing by zero below.
    if total_texts == 0:
        st.info("No topic predictions to display.")
        return

    topic_counts = {topic: 0 for topic in topics}

    # Count how many texts were classified for each topic (above threshold)
    for pred in all_predictions:
        for topic, prob in pred.items():
            if prob >= 0.5:
                topic_counts[topic] += 1

    # Convert to percentages
    topic_percentages = {
        topic: (count / total_texts) * 100 for topic, count in topic_counts.items()
    }

    # Create bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=[t.replace('_', ' ').title() for t in topics],
            y=list(topic_percentages.values()),
            marker_color=['#28a745', '#17a2b8', '#ffc107'],
            text=[f'{v:.1f}%' for v in topic_percentages.values()],
            textposition='auto'
        )
    ])

    fig.update_layout(
        title=f"Topic Distribution Across {total_texts} Texts",
        xaxis_title="Topics",
        yaxis_title="Percentage of Texts (%)",
        height=400,
        showlegend=False
    )

    st.plotly_chart(fig, use_container_width=True)
295
+
296
def create_batch_sentiment_chart(all_predictions):
    """Render a donut chart of the predicted sentiments in a batch.

    Args:
        all_predictions: List of sentiment prediction dicts; each must have a
            ``'sentiment'`` key holding ``'positive'``, ``'neutral'`` or
            ``'negative'``.
    """
    sentiments = ['positive', 'neutral', 'negative']
    total_texts = len(all_predictions)

    # Nothing to chart; also avoids dividing by zero below.
    if total_texts == 0:
        st.info("No sentiment predictions to display.")
        return

    sentiment_counts = {sentiment: 0 for sentiment in sentiments}

    # Count sentiment predictions
    for pred in all_predictions:
        sentiment_counts[pred['sentiment']] += 1

    # Convert to percentages (dict order follows `sentiments`)
    sentiment_percentages = [
        (count / total_texts) * 100 for count in sentiment_counts.values()
    ]

    # Create donut chart
    colors = ['#28a745', '#ffc107', '#dc3545']

    fig = go.Figure(data=[go.Pie(
        labels=[s.title() for s in sentiments],
        values=sentiment_percentages,
        hole=0.4,  # Creates donut effect
        marker_colors=colors,
        textinfo='label+percent',
        textposition='auto'
    )])

    fig.update_layout(
        title=f"Sentiment Distribution Across {total_texts} Texts",
        height=400,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.1,
            xanchor="center",
            x=0.5
        )
    )

    st.plotly_chart(fig, use_container_width=True)
336
+
337
+ # ====================== Main Application ======================
338
+
339
# Page configuration (this is the first Streamlit call in the script)
st.set_page_config(
    page_title="Text Classification System",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Stylesheet for the header, model cards, prediction cards and metric boxes.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
text-align: center;
margin-bottom: 2rem;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}

.model-card {
padding: 1rem;
border-radius: 10px;
border: 1px solid #e0e0e0;
margin: 1rem 0;
background-color: #f8f9fa;
}

.prediction-result {
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
}

.topic-positive {
background-color: #d4edda;
border-left: 4px solid #28a745;
color: #155724 !important;
}
.topic-neutral {
background-color: #fff3cd;
border-left: 4px solid #ffc107;
color: #856404 !important;
}
.topic-negative {
background-color: #f8d7da;
border-left: 4px solid #dc3545;
color: #721c24 !important;
}
.metrics-container {
display: flex;
justify-content: space-around;
margin: 1rem 0;
}

.metric-box {
text-align: center;
padding: 1rem;
border-radius: 8px;
background-color: #ffffff;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
min-width: 120px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Page title followed by a horizontal rule
st.markdown(
    '<h1 >🔍 Ecommerce Product Review Analysis - Indonesian Language </h1>',
    unsafe_allow_html=True,
)
st.markdown("---")
409
+
410
# Sidebar: static model descriptions plus per-session statistics
with st.sidebar:

    st.header("📊 Model Information")

    # (expander title, markdown body) for each model description
    model_notes = [
        (
            "🏷️ Topic Classification Model",
            """
            **Model Type:** Multi-label Classification
            **Categories:**
            - 📦 Product
            - 🎧 Customer Service
            - 🚚 Shipping

            **Note:** Text can belong to multiple categories
            """,
        ),
        (
            "😊 Sentiment Analysis Model",
            """
            **Model Type:** Single-label Classification
            **Categories:**
            - 😊 Positive
            - 😐 Neutral
            - 😞 Negative
            """,
        ),
    ]
    for expander_title, notes_md in model_notes:
        with st.expander(expander_title, expanded=True):
            st.markdown(notes_md)

    # Statistics appear once at least one result has been recorded in
    # st.session_state.classification_history.
    if 'classification_history' in st.session_state:
        st.header("📈 Session Statistics")
        history = st.session_state.classification_history

        count_col, confidence_col = st.columns(2)
        with count_col:
            st.metric("Texts Classified", len(history))
        with confidence_col:
            mean_confidence = np.mean([entry['confidence'] for entry in history])
            st.metric("Avg Confidence", f"{mean_confidence:.2f}")
448
+
449
# Import the pipeline module; on any failure show the error and halt the script.
try:
    import pipeline
except ImportError as import_error:
    st.error(f"❌ Error importing pipeline module: {import_error}")
    st.info("Please make sure your pipeline.py file is in the same directory and contains the required models.")
    st.stop()
except Exception as load_error:
    st.error(f"❌ Error loading models: {load_error}")
    st.stop()
460
+
461
# Main content: choose an input method, classify, and render the results.
st.header("📝 Text Input")

# Input methods
input_method = st.radio("Choose input method:", ["Single Text", "Batch Upload", "Example Texts"])

if input_method == "Single Text":
    user_text = st.text_area(
        "Enter text to classify:",
        placeholder="Type or paste your text here...",
        height=150
    )

    if st.button("🚀 Classify Text", type="primary"):
        if user_text.strip():
            try:
                topic_predictions = simulate_topic_prediction(user_text)
                sentiment_prediction = simulate_sentiment_prediction(user_text)

                display_predictions(user_text, topic_predictions, sentiment_prediction)
            except Exception as e:
                st.error(f"Error during classification: {str(e)}")
        else:
            st.warning("Please enter some text to classify!")

elif input_method == "Batch Upload":
    uploaded_file = st.file_uploader("Upload CSV file", type=['csv'])

    if uploaded_file is not None:
        # Delimiter and encoding options
        delimiter_labels = {
            ",": "Comma (,)",
            ";": "Semicolon (;)",
            "\t": "Tab",
            "|": "Pipe (|)",
            " ": "Space",
        }
        col_delim, col_encoding = st.columns(2)
        with col_delim:
            delimiter = st.selectbox(
                "Select delimiter:",
                options=[",", ";", "\t", "|", " "],
                format_func=lambda x: delimiter_labels[x],
                index=0
            )

        with col_encoding:
            encoding = st.selectbox(
                "Select encoding:",
                options=["utf-8", "latin-1", "cp1252", "ascii"],
                index=0
            )

        # Only the read itself is reported as a CSV error; failures further
        # down get their own message.
        df = None
        try:
            # Rewind in case the upload buffer was consumed by an earlier read.
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file, delimiter=delimiter, encoding=encoding)
        except Exception as e:
            st.error(f"Error reading CSV file: {str(e)}")
            st.info("Try different delimiter or encoding options if the file doesn't load correctly.")

        if df is not None:
            st.write("Preview of uploaded data:")
            st.dataframe(df.head())

            if len(df) == 0:
                st.warning("The uploaded file contains no data rows.")
            else:
                text_column = st.selectbox("Select text column:", df.columns)

                # maximum number of rows to process; a slider needs a real
                # range, so it is skipped for a single-row file.
                if len(df) > 1:
                    max_rows = st.slider(
                        "Maximum rows to process:",
                        min_value=1,
                        max_value=len(df),
                        value=min(100, len(df)),
                        step=1
                    )
                else:
                    max_rows = 1

                if st.button("🔄 Process Batch", type="primary"):
                    try:
                        process_batch_classification(df, text_column, max_results=max_rows)
                    except Exception as e:
                        st.error(f"Error during batch processing: {str(e)}")

else:  # Example Texts
    st.subheader("Try these example texts:")

    # Example type selection
    example_type = st.radio(
        "Choose example type:",
        ["Single Examples", "CSV Examples"],
        horizontal=True
    )

    if example_type == "Single Examples":
        examples = [
            "Pengiriman terlambat 3 hari dan paketnya rusak.",
            "Pelayanan pelanggan sangat baik! Tim support sangat membantu dan responsif.",
            "Kualitas produknya sangat bagus, sesuai dengan yang saya harapkan.",
            "Saya kesulitan dengan proses pengembalian barang, sangat membingungkan.",
            "Pengiriman cepat dan barang sampai dalam kondisi sempurna!"
        ]

        # Initialize session state for tracking which example to show results for
        if 'selected_example' not in st.session_state:
            st.session_state.selected_example = None
            st.session_state.example_results = None

        for i, example in enumerate(examples):
            col_ex1, col_ex2 = st.columns([4, 1])
            with col_ex1:
                st.text(f"{i+1}. {example}")
            with col_ex2:
                if st.button("Classify", key=f"example_{i}"):
                    try:
                        topic_predictions = simulate_topic_prediction(example)
                        sentiment_prediction = simulate_sentiment_prediction(example)

                        # Store results in session state so they can be
                        # rendered below the example list after the rerun.
                        st.session_state.selected_example = i
                        st.session_state.example_results = {
                            'text': example,
                            'topic_predictions': topic_predictions,
                            'sentiment_prediction': sentiment_prediction
                        }
                        st.rerun()
                    except Exception as e:
                        st.error(f"Error during classification: {str(e)}")

        # Display results below all examples if any example was classified.
        # NOTE(review): display_predictions appends to the session history on
        # every call, so each rerun in this view logs the same example again —
        # confirm whether that is intended.
        if st.session_state.selected_example is not None and st.session_state.example_results:
            results = st.session_state.example_results
            display_predictions(
                results['text'],
                results['topic_predictions'],
                results['sentiment_prediction']
            )

    else:  # CSV Examples
        st.markdown("**Pre-prepared CSV datasets for testing:**")

        # Predefined CSV options
        csv_options = {
            "Sample E-commerce Reviews": {
                "data": {
                    "review_text": [
                        "Produk bagus tapi pengiriman lama",
                        "Customer service tidak responsif",
                        "Barang sesuai deskripsi, packing aman",
                        "Pengiriman cepat tapi produk cacat",
                        "Pelayanan memuaskan, akan order lagi",
                        "Kualitas produk mengecewakan",
                        "Pengiriman sangat cepat dan aman",
                        "Tim support sangat membantu menyelesaikan masalah",
                        "Produk original dan sesuai gambar",
                        "Proses refund sangat lambat dan rumit"
                    ],
                    "rating": [4, 2, 5, 3, 5, 1, 5, 5, 4, 2],
                    "category": ["Electronics", "Fashion", "Books", "Electronics", "Fashion", "Electronics", "Books", "Fashion", "Electronics", "Fashion"]
                },
                "description": "Indonesian e-commerce reviews with mixed sentiments and topics"
            },
            "Product Reviews Dataset": {
                "data": {
                    "review_text": [
                        "Laptop ini performanya sangat bagus untuk gaming",
                        "Baju ini bahannya halus dan nyaman dipakai",
                        "Buku ini sangat informatif dan mudah dipahami",
                        "Handphone rusak setelah 2 minggu pemakaian",
                        "Sepatu ini sangat nyaman untuk jogging",
                        "Kamera foto hasil jelek, tidak sesuai harga",
                        "Pelayanan toko online ini sangat memuaskan",
                        "Pengiriman terlambat tapi barang aman",
                        "Produk tidak sesuai dengan deskripsi",
                        "Kualitas packaging sangat baik dan rapi"
                    ],
                    "product_type": ["Laptop", "Clothing", "Book", "Phone", "Shoes", "Camera", "Service", "Shipping", "General", "Packaging"],
                    "sentiment_label": ["positive", "positive", "positive", "negative", "positive", "negative", "positive", "neutral", "negative", "positive"]
                },
                "description": "Product-focused reviews with pre-labeled sentiments"
            },
            "Customer Service Reviews": {
                "data": {
                    "review_text": [
                        "CS sangat ramah dan membantu menyelesaikan komplain",
                        "Susah menghubungi customer service via telepon",
                        "Live chat responsive tapi solusi kurang tepat",
                        "Tim support email sangat profesional",
                        "Customer service tidak memberikan solusi yang jelas",
                        "Pelayanan 24/7 sangat membantu customer",
                        "CS galak dan tidak sabar melayani customer",
                        "Support ticket dijawab dengan cepat dan tepat"
                    ],
                    "channel": ["Phone", "Phone", "Chat", "Email", "Phone", "24/7", "Phone", "Ticket"],
                    "resolution": ["Resolved", "Unresolved", "Partial", "Resolved", "Unresolved", "Resolved", "Unresolved", "Resolved"]
                },
                "description": "Customer service specific reviews and interactions"
            }
        }

        # CSV selection
        selected_csv = st.selectbox(
            "Choose a pre-prepared dataset:",
            options=list(csv_options.keys()),
            help="Select from curated datasets for testing different scenarios"
        )

        if selected_csv:
            csv_info = csv_options[selected_csv]
            sample_df = pd.DataFrame(csv_info["data"])

            # Display info and preview
            st.info(f"📋 **{selected_csv}**: {csv_info['description']}")

            col_preview, col_actions = st.columns([3, 1])

            with col_preview:
                st.dataframe(sample_df, use_container_width=True)

            with col_actions:
                # Download button
                csv_data = sample_df.to_csv(index=False)
                st.download_button(
                    label="📥 Download CSV",
                    data=csv_data,
                    file_name=f"{selected_csv.lower().replace(' ', '_')}.csv",
                    mime="text/csv",
                    help="Download this dataset to test batch processing"
                )

                # Quick test button: stash the dataset and rerun so the
                # results render full-width below the columns.
                if st.button("🚀 Quick Test", help="Automatically process this dataset"):
                    st.session_state['quick_test_df'] = sample_df
                    st.session_state['quick_test_column'] = 'review_text'
                    st.rerun()

        # Handle quick test
        if 'quick_test_df' in st.session_state:
            st.markdown("---")
            st.subheader("🔄 Quick Test Results")
            try:
                process_batch_classification(
                    st.session_state['quick_test_df'],
                    st.session_state['quick_test_column'],
                    len(st.session_state['quick_test_df'])
                )
            finally:
                # Always clear the request so a failure is not replayed on
                # every subsequent rerun.
                del st.session_state['quick_test_df']
                del st.session_state['quick_test_column']

        st.info("💡 **Tip:** Download any dataset above and upload it in the 'Batch Upload' section, or use 'Quick Test' for immediate processing!")

# Footer
st.markdown("---")
model_class.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+
6
class CustomClassifierAspect(nn.Module, PyTorchModelHubMixin):
    """Encoder plus a linear head (hidden -> 38 -> 8 -> 3 -> num_labels) ending in a sigmoid.

    Layer attribute names are part of the checkpoint's state dict and must
    not be renamed.
    """

    def __init__(self, bert, num_labels):
        """Build the head on top of *bert*.

        Args:
            bert: Encoder module; must expose ``config.hidden_size`` and
                return an object with ``pooler_output`` when called.
            num_labels: Size of the final output layer.
        """
        super().__init__()
        encoder_width = bert.config.hidden_size
        self.bert = bert
        self.linear38 = nn.Linear(encoder_width, 38)
        self.dropout38 = nn.Dropout(0.2)
        self.linear8 = nn.Linear(38, 8)
        self.linear3 = nn.Linear(8, 3)
        self.linearOutput = nn.Linear(3, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return the element-wise sigmoid of the head's output for the given tokens."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout38(self.linear38(encoded.pooler_output))
        # The linear layers are applied back to back with no activation between them.
        logits = self.linearOutput(self.linear3(self.linear8(features)))
        return self.sigmoid(logits)
27
+
28
class CustomClassifierSentiment(nn.Module, PyTorchModelHubMixin):
    """Encoder plus a linear head (hidden -> 38 -> 8 -> 3 -> num_labels) ending in a softmax.

    Layer attribute names are part of the checkpoint's state dict and must
    not be renamed.
    """

    def __init__(self, bert, num_labels):
        """Build the head on top of *bert*.

        Args:
            bert: Encoder module; must expose ``config.hidden_size`` and
                return an object with ``pooler_output`` when called.
            num_labels: Size of the final output layer.
        """
        super().__init__()
        self.bert = bert
        self.linear38 = nn.Linear(bert.config.hidden_size, 38)
        self.dropout38 = nn.Dropout(0.2)
        self.linear8 = nn.Linear(38, 8)
        self.linear3 = nn.Linear(8, 3)
        self.linearOutput = nn.Linear(3, num_labels)
        # Name the axis explicitly: the implicit-dim form is deprecated.
        # For 2-D head output this selects the same (last) axis as before.
        # Softmax holds no parameters, so the state dict is unaffected.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_ids, attention_mask):
        """Return the softmax over the last axis of the head's output."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits38 = self.linear38(pooled_output)
        logits38 = self.dropout38(logits38)
        logits8 = self.linear8(logits38)
        logits3 = self.linear3(logits8)
        logits = self.linearOutput(logits3)
        probabilities = self.softmax(logits)
        return probabilities
pipeline.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import BertModel, AutoTokenizer
4
+ from model_class import CustomClassifierAspect, CustomClassifierSentiment
5
+ import streamlit as st
6
+
7
# Module-level handles read by the app; populated by the loading block below.
ready_status = False
bert = None
tokenizer = None
aspect_model = None
sentiment_model = None


# Runs at import time: loads the tokenizer and both classifiers inside a
# Streamlit status container.
with st.status("Loading models...", expanded=True, state='running') as status:
    # Load the base model and tokenizer
    bertAspect = BertModel.from_pretrained("indobenchmark/indobert-base-p1",
                                           num_labels=3,
                                           problem_type="multi_label_classification")
    bertSentiment = BertModel.from_pretrained("indobenchmark/indobert-base-p1")

    tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")

    # Load custom models
    aspect_model = CustomClassifierAspect.from_pretrained("fahrendrakhoirul/indobert-finetuned-ecommerce-product-reviews-aspect-multilabel", bert=bertAspect)
    sentiment_model = CustomClassifierSentiment.from_pretrained("fahrendrakhoirul/indobert-finetuned-ecommerce-product-reviews-sentiment", bert=bertSentiment)

    # The classifiers contain a Dropout layer; switch to inference mode so
    # predictions are deterministic. Harmless if they are already in eval mode.
    for loaded_model in (aspect_model, sentiment_model):
        if loaded_model is not None:
            loaded_model.eval()

    st.write("Model loaded")

    # Update status to indicate whether both models are available
    ready_status = aspect_model is not None and sentiment_model is not None
    if ready_status:
        status.update(label="Models loaded successfully", state="complete", expanded=False)
        status.success("Models loaded successfully", icon="✅")
    else:
        status.update(label="Failed to load models", state="error")
        status.error("Failed to load models")