Alp11 commited on
Commit
d940ebd
·
verified ·
1 Parent(s): c571ab3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ import os
6
+
7
+ MODEL_NAME = os.getenv("MODEL_NAME", "jhu-clsp/mmBERT-base")
8
+
9
+ app = FastAPI(title="ModernBERT Embedding API", version="1.0.0")
10
+
11
+ print("Loading model:", MODEL_NAME)
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = AutoModel.from_pretrained(MODEL_NAME)
15
+ model.eval()
16
+
17
+
18
+ class EmbedRequest(BaseModel):
19
+ text: str
20
+
21
+
22
+ @app.get("/health")
23
+ def health():
24
+ return {"status": "ok", "model": MODEL_NAME}
25
+
26
+
27
+ @app.post("/embed")
28
+ def embed(req: EmbedRequest):
29
+ text = (req.text or "").strip()
30
+
31
+ if not text:
32
+ raise HTTPException(status_code=400, detail="Empty text")
33
+
34
+ with torch.no_grad():
35
+ inputs = tokenizer(
36
+ text,
37
+ padding=True,
38
+ truncation=True,
39
+ max_length=512,
40
+ return_tensors="pt",
41
+ )
42
+
43
+ outputs = model(**inputs)
44
+
45
+ mask = inputs["attention_mask"].unsqueeze(-1)
46
+ embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
47
+
48
+ emb = embeddings[0].tolist()
49
+
50
+ return {
51
+ "model": MODEL_NAME,
52
+ "dim": len(emb),
53
+ "preview_first_8": [round(x, 4) for x in emb[:8]],
54
+ "embedding": emb,
55
+ }