Ullaas committed
Commit 76c5a46 · verified · 1 Parent(s): da9b7b6

Update backend.py

Files changed (1)
  1. backend.py +91 -62
backend.py CHANGED
@@ -1,72 +1,84 @@
  import os
  from flask import Flask, request, jsonify
  from flask_cors import CORS
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
-
- # Path to your local model directory (update as needed)
- MODEL_PATH = "ibm-granite/granite-4.0-tiny-preview"
-
- # Load model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
- model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float32)
-
- def generate_text(prompt, max_tokens=200):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      model.to(device)
-     inputs = tokenizer(prompt, return_tensors="pt").to(device)
      outputs = model.generate(
-         **inputs,
          max_new_tokens=max_tokens,
          do_sample=True,
          temperature=0.7,
          top_p=0.9,
          pad_token_id=tokenizer.eos_token_id
      )
-     return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-
- def select_genre_and_tone(user_input):
-     prompt = f"""You are a creative assistant. The user wants to write a short animated story.
  Based on the following input, suggest a suitable genre and tone for the story.
- User Input: {user_input}
  Respond in this format:
  Genre: <genre>
  Tone: <tone>
- """
-     response = generate_text(prompt, max_tokens=64)
      genre, tone = None, None
      for line in response.splitlines():
          if "Genre:" in line:
              genre = line.split("Genre:")[1].strip()
          elif "Tone:" in line:
              tone = line.split("Tone:")[1].strip()
-     return genre or "Unknown", tone or "Unknown"
-
- def generate_outline(user_input, genre, tone):
-     prompt = f"""You are a creative writing assistant helping to write a short animated screenplay.
  The user wants to write a story with the following details:
- Genre: {genre}
- Tone: {tone}
- Idea: {user_input}
  Write a brief plot outline (3–5 sentences) for the story.
- """
-     return generate_text(prompt, max_tokens=128)
-
- def generate_scene(genre, tone, outline):
-     prompt = f"""You are a screenwriter.
  Based on the following plot outline, write a key scene from the story.
  Focus on a turning point or climax moment. Make the scene vivid, descriptive, and suitable for an animated short film.
- Genre: {genre}
- Tone: {tone}
- Outline: {outline}
  Write the scene in prose format (not screenplay format).
- """
-     return generate_text(prompt, max_tokens=128)
-
- def generate_dialogue(scene):
-     prompt = f"""You are a dialogue writer for an animated screenplay.
  Below is a scene from the story:
- {scene}
  Write the dialogue between the characters in screenplay format.
  Keep it short, expressive, and suitable for a short animated film.
  Use character names (you may invent them if needed), and format as:
@@ -74,32 +86,49 @@ CHARACTER:
  Dialogue line
  CHARACTER:
  Dialogue line
- """
-     return generate_text(prompt, max_tokens=128)
-
- # Flask app
  app = Flask(__name__)
  CORS(app)
-
  @app.route("/generate-story", methods=["POST"])
  def generate_story():
      data = request.get_json()
-     user_input = data.get("user_input", "")
      if not user_input:
-         return jsonify({"error": "Missing user_input"}), 400
-
-     genre, tone = select_genre_and_tone(user_input)
-     outline = generate_outline(user_input, genre, tone)
-     scene = generate_scene(genre, tone, outline)
-     dialogue = generate_dialogue(scene)
-
      return jsonify({
-         "genre": genre,
-         "tone": tone,
-         "outline": outline,
-         "scene": scene,
-         "dialogue": dialogue
      })
-
  if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=8000, debug=True)
 
  import os
+ import time
+ import torch
  from flask import Flask, request, jsonify
  from flask_cors import CORS
+ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
+ from langgraph.graph import StateGraph
+ # --- Model and Workflow Setup ---
+ model_id = "ibm-granite/granite-4.0-tiny-preview"
+ processor = AutoProcessor.from_pretrained(model_id)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+ def generate_with_granite(prompt: str, max_tokens: int = 200, use_gpu: bool = False) -> str:
+     device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
      model.to(device)
+     messages = [{"role": "user", "content": prompt}]
+     inputs = processor.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_tensors="pt"
+     ).to(device)
      outputs = model.generate(
+         input_ids=inputs,
          max_new_tokens=max_tokens,
          do_sample=True,
          temperature=0.7,
          top_p=0.9,
          pad_token_id=tokenizer.eos_token_id
      )
+     generated = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
+     return generated.strip()
+ def select_genre_node(state: dict) -> dict:
+     prompt = f"""
+ You are a creative assistant. The user wants to write a short animated story.
  Based on the following input, suggest a suitable genre and tone for the story.
+ User Input: {state['user_input']}
  Respond in this format:
  Genre: <genre>
  Tone: <tone>
+ """.strip()
+     response = generate_with_granite(prompt)
      genre, tone = None, None
      for line in response.splitlines():
          if "Genre:" in line:
              genre = line.split("Genre:")[1].strip()
          elif "Tone:" in line:
              tone = line.split("Tone:")[1].strip()
+     state["genre"] = genre
+     state["tone"] = tone
+     return state
+ def generate_outline_node(state: dict) -> dict:
+     prompt = f"""
+ You are a creative writing assistant helping to write a short animated screenplay.
  The user wants to write a story with the following details:
+ Genre: {state.get('genre')}
+ Tone: {state.get('tone')}
+ Idea: {state.get('user_input')}
  Write a brief plot outline (3–5 sentences) for the story.
+ """.strip()
+     response = generate_with_granite(prompt, max_tokens=250)
+     state["outline"] = response
+     return state
+ def generate_scene_node(state: dict) -> dict:
+     prompt = f"""
+ You are a screenwriter.
  Based on the following plot outline, write a key scene from the story.
  Focus on a turning point or climax moment. Make the scene vivid, descriptive, and suitable for an animated short film.
+ Genre: {state.get('genre')}
+ Tone: {state.get('tone')}
+ Outline: {state.get('outline')}
  Write the scene in prose format (not screenplay format).
+ """.strip()
+     response = generate_with_granite(prompt, max_tokens=300)
+     state["scene"] = response
+     return state
+ def write_dialogue_node(state: dict) -> dict:
+     prompt = f"""
+ You are a dialogue writer for an animated screenplay.
  Below is a scene from the story:
+ {state.get('scene')}
  Write the dialogue between the characters in screenplay format.
  Keep it short, expressive, and suitable for a short animated film.
  Use character names (you may invent them if needed), and format as:
  Dialogue line
  CHARACTER:
  Dialogue line
+ """.strip()
+     response = generate_with_granite(prompt, max_tokens=100)
+     state["dialogue"] = response
+     return state
+ def with_progress(fn, label, index, total):
+     def wrapper(state):
+         print(f"\n[{index}/{total}] Starting: {label}")
+         start = time.time()
+         result = fn(state)
+         duration = time.time() - start
+         print(f"[{index}/{total}] Completed: {label} in {duration:.2f} seconds")
+         return result
+     return wrapper
+ def build_workflow():
+     graph = StateGraph(dict)
+     graph.add_node("select_genre", with_progress(select_genre_node, "Select Genre", 1, 4))
+     graph.add_node("generate_outline", with_progress(generate_outline_node, "Generate Outline", 2, 4))
+     graph.add_node("generate_scene", with_progress(generate_scene_node, "Generate Scene", 3, 4))
+     graph.add_node("write_dialogue", with_progress(write_dialogue_node, "Write Dialogue", 4, 4))
+     graph.set_entry_point("select_genre")
+     graph.add_edge("select_genre", "generate_outline")
+     graph.add_edge("generate_outline", "generate_scene")
+     graph.add_edge("generate_scene", "write_dialogue")
+     graph.set_finish_point("write_dialogue")
+     return graph.compile()
+ workflow = build_workflow()
+ # --- Flask App ---
  app = Flask(__name__)
  CORS(app)
  @app.route("/generate-story", methods=["POST"])
  def generate_story():
      data = request.get_json()
+     user_input = data.get("user_input")
      if not user_input:
+         return jsonify({"error": "Missing 'user_input' in request."}), 400
+     initial_state = {"user_input": user_input}
+     final_state = workflow.invoke(initial_state)
      return jsonify({
+         "genre": final_state.get("genre"),
+         "tone": final_state.get("tone"),
+         "outline": final_state.get("outline"),
+         "scene": final_state.get("scene"),
+         "dialogue": final_state.get("dialogue")
      })
  if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), debug=True)
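
The request contract of /generate-story is unchanged by this commit: it still takes a JSON body with a "user_input" field and returns "genre", "tone", "outline", "scene", and "dialogue". A minimal client sketch for trying the updated backend (illustrative only: the story idea is made up, and it assumes the server is running locally on the default port 8000 with the requests package installed):

import requests

# Hypothetical story idea; any short prompt works.
payload = {"user_input": "A shy lighthouse keeper befriends a storm spirit"}

resp = requests.post("http://localhost:8000/generate-story", json=payload)
resp.raise_for_status()

result = resp.json()
for key in ("genre", "tone", "outline", "scene", "dialogue"):
    print(f"--- {key} ---")
    print(result.get(key))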