Anudeep05 commited on
Commit
a38e917
·
verified ·
1 Parent(s): dab5eb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -32
app.py CHANGED
@@ -3,14 +3,9 @@ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
3
  from decord import VideoReader, cpu
4
  import gradio as gr
5
 
6
- # -------------------------------
7
- # Device
8
- # -------------------------------
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # -------------------------------
12
  # Load processor and model
13
- # -------------------------------
14
  processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-small-finetuned-ssv2")
15
  model = VideoMAEForVideoClassification.from_pretrained(
16
  "MCG-NJU/videomae-small-finetuned-ssv2",
@@ -22,29 +17,15 @@ model.load_state_dict(checkpoint["model_state_dict"])
22
  model.to(device)
23
  model.eval()
24
 
25
- # -------------------------------
26
  # Class mapping
27
- # -------------------------------
28
  id2class = {
29
- 0: "AFGHANISTAN",
30
- 1: "AFRICA",
31
- 2: "ANDHRA_PRADESH",
32
- 3: "ARGENTINA",
33
- 4: "DELHI",
34
- 5: "DENMARK",
35
- 6: "ENGLAND",
36
- 7: "GANGTOK",
37
- 8: "GOA",
38
- 9: "GUJARAT",
39
- 10: "HARYANA",
40
- 11: "HIMACHAL_PRADESH",
41
- 12: "JAIPUR",
42
- 13: "JAMMU_AND_KASHMIR"
43
  }
44
 
45
- # -------------------------------
46
  # Video preprocessing
47
- # -------------------------------
48
  def preprocess_video(video_path, processor, num_frames=16):
49
  vr = VideoReader(video_path, ctx=cpu(0))
50
  total_frames = len(vr)
@@ -56,11 +37,8 @@ def preprocess_video(video_path, processor, num_frames=16):
56
  inputs = processor(list(video), return_tensors="pt")
57
  return inputs["pixel_values"][0]
58
 
59
- # -------------------------------
60
  # Prediction function
61
- # -------------------------------
62
  def predict_video(video_file):
63
- # video_file is a file-like object from Gradio
64
  video_path = video_file.name
65
  pixel_values = preprocess_video(video_path, processor)
66
  pixel_values = pixel_values.unsqueeze(0).to(device)
@@ -69,16 +47,13 @@ def predict_video(video_file):
69
  pred_index = torch.argmax(logits, dim=1).item()
70
  return id2class[pred_index]
71
 
72
- # -------------------------------
73
- # Gradio Interface
74
- # -------------------------------
75
  iface = gr.Interface(
76
  fn=predict_video,
77
- inputs=gr.Video(source="upload"), # corrected argument
78
  outputs="text",
79
  title="VideoMAE Classification API",
80
  description="Upload a video and get the predicted class."
81
  )
82
 
83
- # Expose API
84
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
3
  from decord import VideoReader, cpu
4
  import gradio as gr
5
 
 
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
 
8
  # Load processor and model
 
9
  processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-small-finetuned-ssv2")
10
  model = VideoMAEForVideoClassification.from_pretrained(
11
  "MCG-NJU/videomae-small-finetuned-ssv2",
 
17
  model.to(device)
18
  model.eval()
19
 
 
20
  # Class mapping
 
21
  id2class = {
22
+ 0: "AFGHANISTAN", 1: "AFRICA", 2: "ANDHRA_PRADESH", 3: "ARGENTINA",
23
+ 4: "DELHI", 5: "DENMARK", 6: "ENGLAND", 7: "GANGTOK",
24
+ 8: "GOA", 9: "GUJARAT", 10: "HARYANA", 11: "HIMACHAL_PRADESH",
25
+ 12: "JAIPUR", 13: "JAMMU_AND_KASHMIR"
 
 
 
 
 
 
 
 
 
 
26
  }
27
 
 
28
  # Video preprocessing
 
29
  def preprocess_video(video_path, processor, num_frames=16):
30
  vr = VideoReader(video_path, ctx=cpu(0))
31
  total_frames = len(vr)
 
37
  inputs = processor(list(video), return_tensors="pt")
38
  return inputs["pixel_values"][0]
39
 
 
40
  # Prediction function
 
41
  def predict_video(video_file):
 
42
  video_path = video_file.name
43
  pixel_values = preprocess_video(video_path, processor)
44
  pixel_values = pixel_values.unsqueeze(0).to(device)
 
47
  pred_index = torch.argmax(logits, dim=1).item()
48
  return id2class[pred_index]
49
 
50
+ # Gradio interface
 
 
51
  iface = gr.Interface(
52
  fn=predict_video,
53
+ inputs=gr.Video(), # just this
54
  outputs="text",
55
  title="VideoMAE Classification API",
56
  description="Upload a video and get the predicted class."
57
  )
58
 
59
+ iface.launch(share=True)