from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
from tensorflow.keras.layers import (
    GlobalAveragePooling1D, GlobalMaxPooling1D, Concatenate,
    LayerNormalization, Input, LSTM, Lambda, Dense, Dropout,
    SpatialDropout1D, Bidirectional
)
from tensorflow.keras.models import Model
from transformers import TFAutoModel, AutoTokenizer
from tcn import TCN
import re
import os

bert_model_name = "dccuchile/bert-base-spanish-wwm-uncased"
MAX_LEN = 274
WEIGHTS_PATH = os.getenv("WEIGHTS_PATH", "model.h5")
THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))

# Tokenizer and BERT encoder are loaded once at import time; the encoder is
# frozen so only the LSTM/TCN head downstream would ever train.
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = TFAutoModel.from_pretrained(bert_model_name, output_hidden_states=False, output_attentions=False)
bert_model.trainable = False
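# Sketch of what the tokenizer feeds the model (assumes the defaults above):
# TF tensors keyed 'input_ids' and 'attention_mask', padded to MAX_LEN, which
# match the named Input layers defined below.
#   enc = tokenizer("hola mundo", padding='max_length', max_length=MAX_LEN, return_tensors='tf')
#   enc['input_ids'].shape       # (1, 274)
#   enc['attention_mask'].shape  # (1, 274)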
def tcn_model_with_bert(bert_model_name="google-bert/bert-base-multilingual-uncased", max_length=512):
    # NOTE: the bert_model_name argument is unused here; the Lambda below
    # closes over the frozen global bert_model loaded above.
    input_ids = Input(shape=(max_length,), dtype=tf.int32, name='input_ids')
    attention_mask = Input(shape=(max_length,), dtype=tf.int32, name='attention_mask')

    def extract_bert_embeddings(inputs):
        # Run the frozen encoder and expose its token-level embeddings
        # of shape (batch, max_length, 768) to the Keras graph.
        return tf.cast(
            bert_model({'input_ids': inputs[0], 'attention_mask': inputs[1]}).last_hidden_state,
            tf.float32
        )

    bert_output = Lambda(extract_bert_embeddings, output_shape=(max_length, 768))([input_ids, attention_mask])

    # Sequence head: spatial dropout -> LSTM -> bidirectional TCN.
    x = SpatialDropout1D(0.15)(bert_output)
    x = LSTM(128, activation='tanh', stateful=False, return_sequences=True, dropout=0.1)(x)
    x = LayerNormalization()(x)
    x = Bidirectional(TCN(128, dilations=[1, 2, 4, 8], kernel_size=5, return_sequences=True, activation='gelu', name='tcn1'))(x)

    # Pooled classification head with a single sigmoid output.
    gap = GlobalAveragePooling1D()(x)
    gmp = GlobalMaxPooling1D()(x)
    head = Concatenate()([gap, gmp])
    head = Dense(64, activation="gelu")(head)
    head = Dropout(0.2)(head)
    outp = Dense(1, activation="sigmoid")(head)

    model = Model(inputs=[input_ids, attention_mask], outputs=outp)
    model.compile(
        optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=0.01, clipnorm=1.0),
        loss="binary_crossentropy",
        metrics=['accuracy']
    )
    return model
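# Minimal shape sanity check (a sketch, assuming a small max_length to keep
# the graph cheap): the model maps (1, L) int32 ids to a (1, 1) probability.
#   m = tcn_model_with_bert(max_length=32)
#   ids = tf.zeros((1, 32), dtype=tf.int32)
#   mask = tf.ones((1, 32), dtype=tf.int32)
#   m.predict({'input_ids': ids, 'attention_mask': mask}, verbose=0).shape  # (1, 1)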
def preprocessing(text):
    # Normalize a raw message: drop URLs, @mentions and hashtags, strip
    # trailing punctuation, then lowercase.
    if not isinstance(text, str) or not text:
        return ""
    text = re.sub(r'\s*https?://\S+(\s+|$)', ' ', text).strip()  # URLs
    text = re.sub(r'\S*@\S*\s?', ' ', text).strip()              # @mentions
    text = re.sub(r'#\S*\s?', ' ', text).strip()                 # hashtags
    text = re.sub(r'[.?!¡¿]+$', '', text)                        # trailing punctuation
    text = text.lower().strip()
    return text
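# Example of the cleaning above (a sketch with a made-up input):
#   preprocessing("Me siento mal https://t.co/x @amigo")  ->  "me siento mal"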
model = tcn_model_with_bert(bert_model_name=bert_model_name, max_length=MAX_LEN)
if os.path.exists(WEIGHTS_PATH):
    try:
        # Preferred path: the file holds weights for the architecture above.
        model.load_weights(WEIGHTS_PATH)
    except Exception:
        # Fallback: the file is a full saved model; rebuild it from disk,
        # registering TCN as a custom layer.
        from tensorflow.keras.models import load_model
        model = load_model(WEIGHTS_PATH, custom_objects={"TCN": TCN})
def predict_text(text: str, max_len: int = MAX_LEN, threshold: float = THRESHOLD):
    preprocessed_text = preprocessing(text)
    enc = tokenizer(
        preprocessed_text,
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='tf'
    )
    probs = model.predict(
        {'input_ids': enc['input_ids'], 'attention_mask': enc['attention_mask']},
        verbose=0
    )
    score = float(probs[0][0])
    label = int(score >= threshold)
    return {
        "txt": text,
        "probability": round(score, 3),
        "risk": "ALTO" if label == 1 else "BAJO"
    }
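# Shape of the returned dict (the score itself depends on the loaded weights):
#   predict_text("texto de ejemplo")
#   -> {"txt": "texto de ejemplo", "probability": <rounded score>, "risk": "ALTO" or "BAJO"}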
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.post("/predict")  # route path is an assumption, chosen to match the handler's name
async def predict(payload: dict):
    # Accept either {"texto": "..."} or {"texto": ["...", "..."]}.
    textos = payload.get("texto", [])
    if not isinstance(textos, list):
        textos = [textos]
    details = [predict_text(t) for t in textos]
    return {"details": details}
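# Example request (a sketch, assuming the service runs locally on port 7860
# and the "/predict" route above):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"texto": ["primer texto", "segundo texto"]}'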
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)