Ullaas committed on
Commit
a708a9e
·
verified ·
1 Parent(s): 0a5d729

Upload 2 files

Browse files
Files changed (2) hide show
  1. backend/backend.py +134 -0
  2. backend/requirements.txt +5 -0
backend/backend.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import functools
import os
import time

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from langgraph.graph import StateGraph
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
8
# --- Model and Workflow Setup ---
# Load the Granite tiny-preview instruct model once at import time; all
# request handlers share these module-level objects.
model_id = "ibm-granite/granite-4.0-tiny-preview"
# NOTE(review): for a text-only causal LM, AutoProcessor typically resolves to
# the tokenizer itself. Both are kept here because apply_chat_template is
# called on the processor while decoding uses the tokenizer — confirm they are
# interchangeable for this checkpoint.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Loaded on CPU by default; generate_with_granite moves it per call.
model = AutoModelForCausalLM.from_pretrained(model_id)
13
def generate_with_granite(prompt: str, max_tokens: int = 200, use_gpu: bool = False) -> str:
    """Generate a single-turn chat completion with the shared Granite model.

    Args:
        prompt: User message sent as a one-message chat conversation.
        max_tokens: Maximum number of new tokens to sample.
        use_gpu: Move the model to CUDA when available; otherwise CPU.

    Returns:
        The generated continuation (prompt tokens excluded), with special
        tokens stripped and surrounding whitespace removed.
    """
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    # NOTE: mutates the module-level model's device placement; concurrent
    # requests with differing use_gpu values would thrash devices.
    model.to(device)
    messages = [{"role": "user", "content": prompt}]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(device)
    # Inference only: disable autograd so generate() does not build a graph
    # and retain activations (significant memory saving on long generations).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
    # Slice off the prompt tokens; decode only the newly generated suffix.
    generated = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return generated.strip()
33
def select_genre_node(state: dict) -> dict:
    """Suggest a genre and tone for the user's story idea.

    Reads state['user_input'] (KeyError if absent), queries the model, and
    writes state['genre'] / state['tone'] — None when a label is not found
    in the model's reply.
    """
    prompt = "\n".join([
        "You are a creative assistant. The user wants to write a short animated story.",
        "Based on the following input, suggest a suitable genre and tone for the story.",
        f"User Input: {state['user_input']}",
        "Respond in this format:",
        "Genre: <genre>",
        "Tone: <tone>",
    ])
    reply = generate_with_granite(prompt)
    genre = None
    tone = None
    # Scan the reply line by line for the "Genre:" / "Tone:" labels; later
    # occurrences overwrite earlier ones, matching a plain last-wins parse.
    for raw in reply.splitlines():
        if "Genre:" in raw:
            genre = raw.split("Genre:")[1].strip()
            continue
        if "Tone:" in raw:
            tone = raw.split("Tone:")[1].strip()
    state["genre"] = genre
    state["tone"] = tone
    return state
52
def generate_outline_node(state: dict) -> dict:
    """Ask the model for a brief plot outline and store it in state['outline']."""
    prompt = "\n".join([
        "You are a creative writing assistant helping to write a short animated screenplay.",
        "The user wants to write a story with the following details:",
        f"Genre: {state.get('genre')}",
        f"Tone: {state.get('tone')}",
        f"Idea: {state.get('user_input')}",
        "Write a brief plot outline (3–5 sentences) for the story.",
    ])
    # Slightly higher token budget than the default: outlines run longer.
    state["outline"] = generate_with_granite(prompt, max_tokens=250)
    return state
64
def generate_scene_node(state: dict) -> dict:
    """Ask the model for one key prose scene and store it in state['scene']."""
    prompt = "\n".join([
        "You are a screenwriter.",
        "Based on the following plot outline, write a key scene from the story.",
        "Focus on a turning point or climax moment. Make the scene vivid, descriptive, and suitable for an animated short film.",
        f"Genre: {state.get('genre')}",
        f"Tone: {state.get('tone')}",
        f"Outline: {state.get('outline')}",
        "Write the scene in prose format (not screenplay format).",
    ])
    # Scenes are the longest artifact in the pipeline — largest token budget.
    state["scene"] = generate_with_granite(prompt, max_tokens=300)
    return state
77
def write_dialogue_node(state: dict) -> dict:
    """Ask the model for screenplay-format dialogue for the generated scene."""
    prompt = "\n".join([
        "You are a dialogue writer for an animated screenplay.",
        "Below is a scene from the story:",
        f"{state.get('scene')}",
        "Write the dialogue between the characters in screenplay format.",
        "Keep it short, expressive, and suitable for a short animated film.",
        "Use character names (you may invent them if needed), and format as:",
        "CHARACTER:",
        "Dialogue line",
        "CHARACTER:",
        "Dialogue line",
    ])
    # Short budget on purpose: dialogue snippets should stay tight.
    state["dialogue"] = generate_with_granite(prompt, max_tokens=100)
    return state
93
def with_progress(fn, label, index, total):
    """Wrap a workflow node so each invocation logs start/finish and timing.

    Args:
        fn: Node callable taking and returning the workflow state dict.
        label: Human-readable step name used in the log lines.
        index: 1-based position of this step in the pipeline.
        total: Total number of steps, for the "[index/total]" display.

    Returns:
        A callable with the same signature (and, via functools.wraps, the
        same metadata) as *fn*.
    """
    @functools.wraps(fn)  # preserve fn's name/docstring for debugging
    def wrapper(state):
        print(f"\n[{index}/{total}] Starting: {label}")
        # perf_counter is monotonic — immune to wall-clock adjustments,
        # unlike time.time(), so durations can never come out negative.
        start = time.perf_counter()
        result = fn(state)
        duration = time.perf_counter() - start
        print(f"[{index}/{total}] Completed: {label} in {duration:.2f} seconds")
        return result
    return wrapper
102
def build_workflow():
    """Compile the linear four-step story pipeline into a runnable graph."""
    # (node name, node callable, progress label) in execution order.
    steps = [
        ("select_genre", select_genre_node, "Select Genre"),
        ("generate_outline", generate_outline_node, "Generate Outline"),
        ("generate_scene", generate_scene_node, "Generate Scene"),
        ("write_dialogue", write_dialogue_node, "Write Dialogue"),
    ]
    builder = StateGraph(dict)
    total = len(steps)
    for position, (name, node, label) in enumerate(steps, start=1):
        builder.add_node(name, with_progress(node, label, position, total))
    builder.set_entry_point(steps[0][0])
    # Chain each step to its successor: a strictly linear pipeline.
    for (src, _, _), (dst, _, _) in zip(steps, steps[1:]):
        builder.add_edge(src, dst)
    builder.set_finish_point(steps[-1][0])
    return builder.compile()
114
# Compile the story-generation graph once at import time; request handlers
# reuse this single compiled workflow.
workflow = build_workflow()
# --- Flask App ---
app = Flask(__name__)
# NOTE(review): CORS(app) with default arguments — presumably intended to let
# a separately-hosted frontend call this API; confirm the allowed origins.
CORS(app)
118
@app.route("/generate-story", methods=["POST"])
def generate_story():
    """Run the story pipeline for a JSON body of the form {"user_input": ...}.

    Returns:
        400 with a JSON error when the body is missing, malformed, non-JSON,
        or lacks a truthy 'user_input'; otherwise 200 with the pipeline's
        genre/tone/outline/scene/dialogue fields.
    """
    # silent=True: non-JSON or malformed bodies yield None instead of raising
    # (which would produce an HTML error page); "or {}" keeps .get() safe so
    # the client always receives the JSON 400 below.
    data = request.get_json(silent=True) or {}
    user_input = data.get("user_input")
    if not user_input:
        return jsonify({"error": "Missing 'user_input' in request."}), 400
    initial_state = {"user_input": user_input}
    final_state = workflow.invoke(initial_state)
    # Expose only the known pipeline outputs; .get() tolerates absent keys.
    return jsonify({
        "genre": final_state.get("genre"),
        "tone": final_state.get("tone"),
        "outline": final_state.get("outline"),
        "scene": final_state.get("scene"),
        "dialogue": final_state.get("dialogue")
    })
133
if __name__ == "__main__":
    # Dev-server entry point; the PORT env var overrides the default 8000.
    # NOTE(review): debug=True enables the Werkzeug debugger/auto-reloader —
    # combined with host="0.0.0.0" this must never face the public internet;
    # confirm this service only runs behind a trusted boundary.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), debug=True)
backend/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ torch
4
+ transformers
5
+ langgraph