Anudeep05 committed on
Commit 5e272cd · verified · 1 Parent(s): a38e917

Update app.py

Files changed (1)
  1. app.py +46 -11
app.py CHANGED
@@ -3,57 +3,92 @@ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
 from decord import VideoReader, cpu
 import gradio as gr

+# -------------------------------
+# Device
+# -------------------------------
 device = "cuda" if torch.cuda.is_available() else "cpu"

+# -------------------------------
 # Load processor and model
-processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-small-finetuned-ssv2")
+# -------------------------------
+processor = VideoMAEImageProcessor.from_pretrained(
+    "MCG-NJU/videomae-small-finetuned-ssv2"
+)
+
 model = VideoMAEForVideoClassification.from_pretrained(
     "MCG-NJU/videomae-small-finetuned-ssv2",
     num_labels=14,
     ignore_mismatched_sizes=True
 )
+
 checkpoint = torch.load("videomae_best.pth", map_location=device)
 model.load_state_dict(checkpoint["model_state_dict"])
 model.to(device)
 model.eval()

+# -------------------------------
 # Class mapping
+# -------------------------------
 id2class = {
-    0: "AFGHANISTAN", 1: "AFRICA", 2: "ANDHRA_PRADESH", 3: "ARGENTINA",
-    4: "DELHI", 5: "DENMARK", 6: "ENGLAND", 7: "GANGTOK",
-    8: "GOA", 9: "GUJARAT", 10: "HARYANA", 11: "HIMACHAL_PRADESH",
-    12: "JAIPUR", 13: "JAMMU_AND_KASHMIR"
+    0: "AFGHANISTAN",
+    1: "AFRICA",
+    2: "ANDHRA_PRADESH",
+    3: "ARGENTINA",
+    4: "DELHI",
+    5: "DENMARK",
+    6: "ENGLAND",
+    7: "GANGTOK",
+    8: "GOA",
+    9: "GUJARAT",
+    10: "HARYANA",
+    11: "HIMACHAL_PRADESH",
+    12: "JAIPUR",
+    13: "JAMMU_AND_KASHMIR"
 }

+# -------------------------------
 # Video preprocessing
-def preprocess_video(video_path, processor, num_frames=16):
+# -------------------------------
+def preprocess_video(video_file, processor, num_frames=16):
+    """
+    Preprocess a video file-like object for VideoMAE.
+    """
+    video_path = video_file.name
     vr = VideoReader(video_path, ctx=cpu(0))
     total_frames = len(vr)
+
     if total_frames < num_frames:
         indices = [i % total_frames for i in range(num_frames)]
     else:
         indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
+
     video = vr.get_batch(indices).asnumpy()
     inputs = processor(list(video), return_tensors="pt")
     return inputs["pixel_values"][0]

+# -------------------------------
 # Prediction function
+# -------------------------------
 def predict_video(video_file):
-    video_path = video_file.name
-    pixel_values = preprocess_video(video_path, processor)
+    pixel_values = preprocess_video(video_file, processor)
     pixel_values = pixel_values.unsqueeze(0).to(device)
+
     with torch.no_grad():
         logits = model(pixel_values=pixel_values).logits
     pred_index = torch.argmax(logits, dim=1).item()
+
     return id2class[pred_index]

-# Gradio interface
+# -------------------------------
+# Gradio Interface
+# -------------------------------
 iface = gr.Interface(
     fn=predict_video,
-    inputs=gr.Video(),  # just this
+    inputs=gr.File(file_types=[".mp4"]),  # Accept any MP4 file
     outputs="text",
     title="VideoMAE Classification API",
-    description="Upload a video and get the predicted class."
+    description="Upload a .mp4 video file to get the predicted class."
 )

+# Launch Space (public URL)
 iface.launch(share=True)
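
Note on the sampling logic in preprocess_video, which this commit leaves unchanged: clips shorter than num_frames are padded by cycling through their frames, while longer clips are sampled at evenly spaced positions. A minimal standalone sketch of just that index selection; the helper name sample_indices is ours, not anything in app.py:

import torch

def sample_indices(total_frames, num_frames=16):
    # Short clip: cycle through the available frames until num_frames indices exist.
    if total_frames < num_frames:
        return [i % total_frames for i in range(num_frames)]
    # Long clip: num_frames evenly spaced indices across the whole clip.
    return torch.linspace(0, total_frames - 1, num_frames).long().tolist()

print(sample_indices(5))    # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0]
print(sample_indices(100))  # [0, 6, 13, ..., 92, 99]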
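
With the input switched from gr.Video() to gr.File(file_types=[".mp4"]), the handler receives the upload as an object whose .name attribute holds its temporary path (Gradio 3.x file semantics; note that in Gradio 4.x gr.File defaults to returning a plain filepath string), which is what the rewritten preprocess_video(video_file, ...) reads. A minimal sketch of calling the deployed app programmatically, assuming gradio_client is installed and the app is running; the Space id below is hypothetical:

from gradio_client import Client, handle_file

client = Client("Anudeep05/videomae-classification")  # hypothetical Space id
result = client.predict(
    handle_file("sample.mp4"),  # any local .mp4 clip
    api_name="/predict",        # default endpoint for a single-function Interface
)
print(result)  # one of the 14 labels, e.g. "GOA"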