Richard Guo
commited on
Commit
·
81aaa4e
1
Parent(s):
f47c911
huggingface cli requirement and webhook route
Browse files- main.py +38 -11
- requirements.txt +1 -0
main.py
CHANGED
|
@@ -1,25 +1,48 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from fastapi.responses import HTMLResponse
|
| 3 |
from fastapi.templating import Jinja2Templates
|
| 4 |
-
|
| 5 |
-
from uuid import uuid4
|
| 6 |
-
import time
|
| 7 |
-
import asyncio
|
| 8 |
|
| 9 |
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
|
| 10 |
-
from models import WebhookPayload
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
app = FastAPI()
|
| 14 |
# TODO: use task management queue
|
| 15 |
tasks = {}
|
| 16 |
templates = Jinja2Templates(directory="templates")
|
| 17 |
|
| 18 |
-
def upload_atlas_task(task_id,
|
|
|
|
|
|
|
|
|
|
| 19 |
dataset_dict = load_dataset_and_metadata(dataset_name)
|
| 20 |
-
map_url = upload_dataset_to_atlas(dataset_dict
|
| 21 |
tasks[task_id]['status'] = 'done'
|
| 22 |
tasks[task_id]['url'] = map_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
@app.on_event("startup")
|
| 25 |
async def startup_event():
|
|
@@ -47,7 +70,6 @@ async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(
|
|
| 47 |
tasks[task_id] = {'status': 'running'}
|
| 48 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
| 49 |
background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
|
| 50 |
-
|
| 51 |
return {'task_id': task_id}
|
| 52 |
|
| 53 |
@app.get("/status/{task_id}")
|
|
@@ -58,7 +80,12 @@ async def read_task(task_id: str):
|
|
| 58 |
return tasks[task_id]
|
| 59 |
|
| 60 |
@app.post("/webhook")
|
| 61 |
-
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
if not (
|
| 63 |
payload.event.action == "update"
|
| 64 |
and payload.event.scope.startswith("repo.content")
|
|
@@ -69,5 +96,5 @@ async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayloa
|
|
| 69 |
task_id = str(uuid4())
|
| 70 |
tasks[task_id] = {'status': 'running'}
|
| 71 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
| 72 |
-
background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name)
|
| 73 |
return {'task_id': task_id}
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from uuid import uuid4
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
|
| 8 |
from fastapi.responses import HTMLResponse
|
| 9 |
from fastapi.templating import Jinja2Templates
|
| 10 |
+
from huggingface_hub import create_discussion, comment_discussion
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
|
| 13 |
+
from models import WebhookPayload
|
| 14 |
|
| 15 |
+
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
|
| 16 |
+
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
|
| 17 |
|
| 18 |
app = FastAPI()
|
| 19 |
# TODO: use task management queue
|
| 20 |
tasks = {}
|
| 21 |
templates = Jinja2Templates(directory="templates")
|
| 22 |
|
| 23 |
+
def upload_atlas_task(task_id,
|
| 24 |
+
dataset_name,
|
| 25 |
+
webhook_payload: WebhookPayload = None,
|
| 26 |
+
webhook_notify: bool = False):
|
| 27 |
dataset_dict = load_dataset_and_metadata(dataset_name)
|
| 28 |
+
map_url = upload_dataset_to_atlas(dataset_dict)
|
| 29 |
tasks[task_id]['status'] = 'done'
|
| 30 |
tasks[task_id]['url'] = map_url
|
| 31 |
+
tasks[task_id]['finish_time'] = time.time()
|
| 32 |
+
|
| 33 |
+
if webhook_notify:
|
| 34 |
+
discussion = create_discussion(
|
| 35 |
+
repo_id=webhook_payload.repo.id,
|
| 36 |
+
title="Atlas Maps",
|
| 37 |
+
token=HUGGINGFACE_ACCESS_TOKEN,
|
| 38 |
+
)
|
| 39 |
+
comment_discussion(
|
| 40 |
+
repo_id=webhook_payload.repo.id,
|
| 41 |
+
discussion_num=discussion.num,
|
| 42 |
+
comment="Atlas Map: " + map_url,
|
| 43 |
+
token=HUGGINGFACE_ACCESS_TOKEN
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
|
| 47 |
@app.on_event("startup")
|
| 48 |
async def startup_event():
|
|
|
|
| 70 |
tasks[task_id] = {'status': 'running'}
|
| 71 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
| 72 |
background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
|
|
|
|
| 73 |
return {'task_id': task_id}
|
| 74 |
|
| 75 |
@app.get("/status/{task_id}")
|
|
|
|
| 80 |
return tasks[task_id]
|
| 81 |
|
| 82 |
@app.post("/webhook")
|
| 83 |
+
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload, x_webhook_secret: Optional[str] = Header(default=None)):
|
| 84 |
+
if x_webhook_secret is None:
|
| 85 |
+
raise HTTPException(401)
|
| 86 |
+
if x_webhook_secret != WEBHOOK_SECRET:
|
| 87 |
+
raise HTTPException(403)
|
| 88 |
+
|
| 89 |
if not (
|
| 90 |
payload.event.action == "update"
|
| 91 |
and payload.event.scope.startswith("repo.content")
|
|
|
|
| 96 |
task_id = str(uuid4())
|
| 97 |
tasks[task_id] = {'status': 'running'}
|
| 98 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
| 99 |
+
background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name, payload, True)
|
| 100 |
return {'task_id': task_id}
|
requirements.txt
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
datasets==2.13.0
|
| 2 |
fastapi[all]
|
|
|
|
| 3 |
nomic==2.0.3
|
| 4 |
pandas==1.5.3
|
| 5 |
pyarrow==12.0.1
|
|
|
|
| 1 |
datasets==2.13.0
|
| 2 |
fastapi[all]
|
| 3 |
+
huggingface-hub==0.16.4
|
| 4 |
nomic==2.0.3
|
| 5 |
pandas==1.5.3
|
| 6 |
pyarrow==12.0.1
|