proyecto-meis committed on
Commit
27b1d64
verified
1 Parent(s): 1c5761b

Modelo y app.py

Browse files
Files changed (3) hide show
  1. app.py +137 -0
  2. model.h5 +3 -0
  3. requirements.txt +43 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.layers import Layer
2
+ import keras.backend as K
3
+ from transformers import TFAutoModel, AutoTokenizer
4
+ from tensorflow.keras.layers import (
5
+ Softmax, GlobalAveragePooling1D, GlobalMaxPooling1D, Activation, Concatenate,
6
+ Conv1D, MultiHeadAttention, LayerNormalization, Input, LSTM, Embedding,
7
+ Lambda, Dense, Dropout, concatenate, SpatialDropout1D, Bidirectional
8
+ )
9
+ from keras.models import Model
10
+ from tcn import TCN
11
+ import keras.ops as ops
12
+ from keras import initializers
13
+ import tensorflow as tf
14
+ import re
15
+ import os
16
+ import gradio as gr
17
+
18
# --- Runtime configuration ---------------------------------------------------
# Hugging Face checkpoint used for both the tokenizer and the encoder below.
bert_model_name = "dccuchile/bert-base-spanish-wwm-uncased"

# Sequence length used when building the model and as the default
# padding/truncation length at prediction time.
MAX_LEN = 274

# Both settings can be overridden through environment variables.
WEIGHTS_PATH = os.getenv("WEIGHTS_PATH", "model.h5")
THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))

# --- Pretrained components ---------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

# Only the last hidden state is consumed downstream, so hidden-state and
# attention outputs are switched off.
bert_model = TFAutoModel.from_pretrained(
    bert_model_name, output_hidden_states=False, output_attentions=False
)

# Mark the encoder as non-trainable.
bert_model.trainable = False
31
+
32
def tcn_model_with_bert(bert_model_name="google-bert/bert-base-multilingual-uncased", max_length=512):
    """Build and compile the BERT -> LSTM -> bidirectional TCN classifier.

    The network takes two int32 inputs named ``input_ids`` and
    ``attention_mask`` (each of length ``max_length``), runs them through the
    module-level ``bert_model`` inside a ``Lambda`` layer, and feeds the
    resulting sequence through spatial dropout, an LSTM, layer normalisation
    and a bidirectional TCN. Average- and max-pooled features are
    concatenated and passed to a small dense head ending in a single sigmoid
    unit.

    Args:
        bert_model_name: Kept in the signature for callers, but NOTE that the
            body never reads it; the encoder actually used is the
            module-level ``bert_model``.
        max_length: Token length of both inputs and of the declared Lambda
            output shape ``(max_length, 768)``.

    Returns:
        A compiled Keras ``Model`` (AdamW, binary cross-entropy, accuracy).
    """
    # Layers are created in exactly this order; keep it stable so that
    # auto-generated layer names line up with previously saved weights.
    token_shape = (max_length,)
    input_ids = Input(shape=token_shape, dtype=tf.int32, name='input_ids')
    attention_mask = Input(shape=token_shape, dtype=tf.int32, name='attention_mask')

    def extract_bert_embeddings(inputs):
        # inputs is [input_ids, attention_mask]; return float32 token embeddings.
        encoder_out = bert_model(
            {'input_ids': inputs[0], 'attention_mask': inputs[1]}
        )
        return tf.cast(encoder_out.last_hidden_state, tf.float32)

    embedding_layer = Lambda(
        extract_bert_embeddings, output_shape=(max_length, 768)
    )
    bert_output = embedding_layer([input_ids, attention_mask])

    # Sequence encoder stack.
    x = SpatialDropout1D(0.15)(bert_output)
    x = LSTM(
        128,
        activation='tanh',
        stateful=False,
        return_sequences=True,
        dropout=0.1,
    )(x)
    x = LayerNormalization()(x)
    tcn_layer = TCN(
        128,
        dilations=[1, 2, 4, 8],
        kernel_size=5,
        return_sequences=True,
        activation='gelu',
        name='tcn1',
    )
    x = Bidirectional(tcn_layer)(x)

    # Pool over the time axis in two ways and merge the summaries.
    gap = GlobalAveragePooling1D()(x)
    gmp = GlobalMaxPooling1D()(x)
    head = Concatenate()([gap, gmp])

    # Classification head.
    head = Dense(64, activation="gelu")(head)
    head = Dropout(0.2)(head)
    outp = Dense(1, activation="sigmoid")(head)

    model = Model(inputs=[input_ids, attention_mask], outputs=outp)
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=1e-4, weight_decay=0.01, clipnorm=1.0
    )
    model.compile(
        optimizer=optimizer,
        loss="binary_crossentropy",
        metrics=['accuracy'],
    )
    return model
70
+
71
def preprocessing(text):
    """Normalise a raw message before tokenisation.

    Steps, in order: replace URLs, tokens containing ``@`` and hashtags with
    a space; strip trailing sentence punctuation (``.``, ``?``, ``!`` and the
    Spanish inverted marks); lowercase; trim surrounding whitespace.

    Args:
        text: Raw input. Anything that is not a non-empty ``str`` yields ``""``.

    Returns:
        The cleaned, lowercased string (possibly empty).
    """
    if not isinstance(text, str) or not text:
        return ""
    text = re.sub(r'\s*https?://\S+(\s+|$)', ' ', text).strip()
    text = re.sub(r'\S*@\S*\s?', ' ', text).strip()
    text = re.sub(r'#\S*\s?', ' ', text).strip()
    # The inverted marks are spelled as escapes (U+00A1, U+00BF) so the
    # pattern cannot be corrupted by a wrong file encoding again; the previous
    # literal had degraded into unrelated CJK characters and never matched.
    text = re.sub(r'[.?!\u00a1\u00bf]+$', '', text)
    text = text.lower()
    text = text.strip()
    return text
81
+
82
# Build the inference model using the module-level configuration.
model = tcn_model_with_bert(
    bert_model_name=bert_model_name, max_length=MAX_LEN)

# Best-effort restore of trained parameters from WEIGHTS_PATH:
#   1. treat the file as a weights checkpoint for the model built above;
#   2. if that fails, treat it as a full saved model (TCN is a custom layer).
# Failures are logged rather than swallowed, so a deployment that ends up
# serving an untrained model is visible in the logs instead of silent.
_loaded = False
_log = tf.get_logger()
if os.path.exists(WEIGHTS_PATH):
    try:
        model.load_weights(WEIGHTS_PATH)
        _loaded = True
    except Exception as weights_err:
        _log.warning(
            "load_weights(%r) failed (%s); falling back to load_model.",
            WEIGHTS_PATH, weights_err,
        )
        try:
            from tensorflow.keras.models import load_model
            model = load_model(WEIGHTS_PATH, custom_objects={"TCN": TCN})
            _loaded = True
        except Exception as model_err:
            _log.warning(
                "load_model(%r) failed as well (%s).",
                WEIGHTS_PATH, model_err,
            )
else:
    _log.warning("Weights file %r does not exist.", WEIGHTS_PATH)

if not _loaded:
    # Keep the app running (original best-effort behaviour) but say so loudly.
    _log.warning(
        "No trained weights were loaded; predictions come from an "
        "uninitialised model and are not meaningful."
    )
97
+
98
def predict_text(text: str, max_len: int = MAX_LEN, threshold: float = THRESHOLD):
    """Score a single text with the module-level model.

    Args:
        text: Raw input text; it is cleaned with ``preprocessing`` first.
        max_len: Length the tokenizer pads/truncates to.
        threshold: Minimum score (inclusive) for the positive label.

    Returns:
        A dict with the original ``text``, the ``preprocessed`` string, the
        float ``score`` taken from the first model output, and the integer
        ``label`` (1 when ``score >= threshold``, otherwise 0).
    """
    cleaned = preprocessing(text)

    # Tokenise to fixed-length TensorFlow tensors.
    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='tf',
    )
    model_inputs = {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
    }
    predictions = model.predict(model_inputs, verbose=0)

    # A single example was submitted, so read the first row's only value.
    score = float(predictions[0][0])
    return {
        "text": text,
        "preprocessed": cleaned,
        "score": score,
        "label": int(score >= threshold),
    }
120
+
121
def predict_fn(texto):
    """Gradio callback: classify one text or a list of texts.

    Args:
        texto: A single item or a list of items; a non-list value is treated
            as a one-element batch.

    Returns:
        A list with one dict per input holding the original ``txt``, the
        ``probability`` rounded to three decimals, and ``risk`` set to
        ``"ALTO"`` when the predicted label is 1 and ``"BAJO"`` otherwise.
    """
    batch = texto if isinstance(texto, list) else [texto]

    def summarise(item):
        # Run the model once per item and reshape its result for the UI.
        outcome = predict_text(item)
        return {
            "txt": item,
            "probability": round(float(outcome["score"]), 3),
            "risk": "ALTO" if outcome["label"] == 1 else "BAJO",
        }

    return [summarise(item) for item in batch]
133
+
134
# Gradio UI: a single text box in, the JSON produced by predict_fn out.
iface = gr.Interface(
    fn=predict_fn,
    inputs="text",
    outputs="json",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c4d1f868a5464614f1a4530426aa076d8fb254e4b29a9e5c0986599415c90e9
3
+ size 21791632
requirements.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow==2.20.0
2
+ tf-keras==2.20.1
3
+ keras-tcn>=3.5.6
4
+ transformers>=4.44.0
5
+ huggingface_hub
6
+ sentencepiece
7
+ annotated-types==0.7.0
8
+ anyio==4.9.0
9
+ async-timeout==5.0.1
10
+ asyncpg==0.30.0
11
+ bcrypt==4.3.0
12
+ certifi==2025.8.3
13
+ charset-normalizer==3.4.3
14
+ click==8.1.8
15
+ dnspython==2.7.0
16
+ ecdsa==0.19.1
17
+ email_validator==2.2.0
18
+ exceptiongroup==1.2.2
19
+ fastapi==0.115.12
20
+ greenlet==3.1.1
21
+ h11==0.16.0
22
+ httpcore==1.0.9
23
+ httpx==0.28.1
24
+ idna==3.10
25
+ passlib==1.7.4
26
+ psycopg2-binary==2.9.10
27
+ pyasn1==0.6.1
28
+ pydantic==2.11.2
29
+ pydantic_core==2.33.1
30
+ PyJWT==2.10.1
31
+ python-dotenv==1.1.0
32
+ python-jose==3.5.0
33
+ requests==2.32.5
34
+ resend==2.16.0
35
+ rsa==4.9.1
36
+ six==1.17.0
37
+ sniffio==1.3.1
38
+ SQLAlchemy==2.0.40
39
+ starlette==0.46.1
40
+ typing-inspection==0.4.0
41
+ typing_extensions==4.13.1
42
+ urllib3==2.5.0
43
+ uvicorn==0.34.0