Anudeep Tippabathuni committed on
Commit 47b7a70
1 Parent(s): c4a7cf5

Add model weights via Git LFS

Files changed (3)
  1. app.py +84 -0
  2. requirements.txt +5 -0
  3. videomae_best.pth +3 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import torch
+ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
+ from decord import VideoReader, cpu
+ import gradio as gr
+
+ # -------------------------------
+ # Device
+ # -------------------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # -------------------------------
+ # Load processor and model
+ # -------------------------------
+ processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-small-finetuned-ssv2")
+ model = VideoMAEForVideoClassification.from_pretrained(
+     "MCG-NJU/videomae-small-finetuned-ssv2",
+     num_labels=14,
+     ignore_mismatched_sizes=True  # swap the SSv2 head for a fresh 14-class head
+ )
+ # Restore the fine-tuned weights from the checkpoint added in this commit
+ checkpoint = torch.load("videomae_best.pth", map_location=device)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ model.to(device)
+ model.eval()
+
+ # -------------------------------
+ # Class mapping
+ # -------------------------------
+ id2class = {
+     0: "AFGHANISTAN",
+     1: "AFRICA",
+     2: "ANDHRA_PRADESH",
+     3: "ARGENTINA",
+     4: "DELHI",
+     5: "DENMARK",
+     6: "ENGLAND",
+     7: "GANGTOK",
+     8: "GOA",
+     9: "GUJARAT",
+     10: "HARYANA",
+     11: "HIMACHAL_PRADESH",
+     12: "JAIPUR",
+     13: "JAMMU_AND_KASHMIR"
+ }
+
+ # -------------------------------
+ # Video preprocessing
+ # -------------------------------
+ def preprocess_video(video_path, processor, num_frames=16):
+     vr = VideoReader(video_path, ctx=cpu(0))
+     total_frames = len(vr)
+     if total_frames < num_frames:
+         # Short clip: cycle through the available frames until we have enough
+         indices = [i % total_frames for i in range(num_frames)]
+     else:
+         # Sample num_frames frames evenly across the clip
+         indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
+     video = vr.get_batch(indices).asnumpy()
+     inputs = processor(list(video), return_tensors="pt")
+     return inputs["pixel_values"][0]
+
+ # -------------------------------
+ # Prediction function
+ # -------------------------------
+ def predict_video(video_path):
+     # gr.Video hands the function the uploaded clip's filepath as a string
+     pixel_values = preprocess_video(video_path, processor)
+     pixel_values = pixel_values.unsqueeze(0).to(device)  # add batch dimension
+     with torch.no_grad():
+         logits = model(pixel_values=pixel_values).logits
+     pred_index = torch.argmax(logits, dim=1).item()
+     return id2class[pred_index]
+
+ # -------------------------------
+ # Gradio Interface
+ # -------------------------------
+ iface = gr.Interface(
+     fn=predict_video,
+     inputs=gr.Video(),  # gr.Video has no `type` kwarg; it passes a filepath
+     outputs="text",
+     title="VideoMAE Classification API",
+     description="Upload a video and get the predicted class."
+ )
+
+ # Expose API (share=True is unnecessary when running inside a Space)
+ iface.launch(server_name="0.0.0.0", server_port=7860)
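
Since the Interface is exposed as an API, the /predict endpoint can also be exercised programmatically. Below is a minimal client-side sketch using the gradio_client package (an extra dependency not in requirements.txt); the localhost URL, the sample.mp4 path, and the /predict endpoint name are assumptions based on gr.Interface defaults, not part of this commit:

# Hypothetical client-side check; assumes `pip install gradio_client`
# and a local test clip sample.mp4.
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")  # or the public Space URL
label = client.predict(
    handle_file("sample.mp4"),  # sent as the gr.Video input
    api_name="/predict",        # gr.Interface's default endpoint name
)
print(label)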
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ transformers
+ decord
+ gradio
+ numpy
videomae_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:abe3dfa9deb07a43cf2dfb246a5b33fe11f5b1223237b210abee2aed11844d92
+ size 144487411
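
The pointer stores only the blob's SHA-256 digest and byte size; the actual ~144 MB weight file lives in LFS storage. A minimal sketch, using only the digest and size recorded above, can confirm that a downloaded copy is the real file rather than the pointer:

# Verify a downloaded videomae_best.pth against the LFS pointer above.
import hashlib
import os

EXPECTED_OID = "abe3dfa9deb07a43cf2dfb246a5b33fe11f5b1223237b210abee2aed11844d92"
EXPECTED_SIZE = 144487411

assert os.path.getsize("videomae_best.pth") == EXPECTED_SIZE, "size mismatch: got the pointer instead of the weights?"
h = hashlib.sha256()
with open("videomae_best.pth", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == EXPECTED_OID, "checksum mismatch"
print("videomae_best.pth matches the LFS pointer")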