```python
from typing import List

from diffusers.modular_pipelines import (
    PipelineState,
    ModularPipelineBlocks,
    InputParam,
    OutputParam,
)
from transformers import AutoModelForCausalLM, AutoTokenizer


SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)


class QwenPromptExpander(ModularPipelineBlocks):
    def __init__(self, model_id="Qwen/Qwen2.5-3B-Instruct", system_prompt=SYSTEM_PROMPT):
        super().__init__()

        self.system_prompt = system_prompt
        # Load the instruction-tuned LLM that rewrites the user prompt.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype="auto"
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @property
    def expected_components(self):
        # The block manages its own model and tokenizer, so it needs no pipeline components.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Prompt expanded by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Original prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        print(f"Original prompt: {old_prompt}")

        # Build a chat-style request and let the LLM expand the prompt.
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": old_prompt},
        ]
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
        # Drop the prompt tokens so only the newly generated text is decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        block_state.prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        block_state.old_prompt = old_prompt
        print(f"{block_state.prompt=}")
        self.set_block_state(state, block_state)

        return components, state
```
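To smoke-test the block on its own, it can be converted into a runnable pipeline. The sketch below is an assumption based on the Modular Diffusers API rather than something shown above: it assumes a single `ModularPipelineBlocks` instance exposes `init_pipeline()` and that the resulting pipeline accepts an `output=` argument naming which state variable to return.

```python
# Minimal usage sketch (assumptions: init_pipeline() on a single block and the
# `output=` keyword on the resulting ModularPipeline call).
expander = QwenPromptExpander()
pipe = expander.init_pipeline()

expanded = pipe(prompt="a cozy cabin in a snowy forest at dusk", output="prompt")
print(expanded)
```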