from typing import Any, Dict

import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Llama 4 checkpoint in bfloat16 on the GPU, together with
        # the matching processor, from the repository the endpoint serves.
        self.model = Llama4ForConditionalGeneration.from_pretrained(
            path,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
        )
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        # Inference Endpoints deliver the request payload as a dict; the
        # prompt is expected under the "inputs" key.
        inputs = data.pop("inputs", data)

        # Wrap the raw prompt in the chat format the processor expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": inputs},
                ],
            },
        ]

        # Render and tokenize the chat template, appending the generation
        # prompt so the model answers as the assistant.
        process_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to("cuda")

        # Greedy decoding, capped at 10 new tokens; raise max_new_tokens
        # for longer completions.
        outputs = self.model.generate(
            **process_inputs,
            max_new_tokens=10,
            do_sample=False,
        )

        # Slice off the prompt tokens and decode only the newly
        # generated text.
        response = self.processor.batch_decode(
            outputs[:, process_inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )[0]
        return response
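

# A minimal local smoke test (a sketch, not part of the deployed handler):
# it assumes a CUDA GPU with enough memory for the checkpoint, and
# MODEL_PATH below is a hypothetical placeholder for a Llama 4 repository.
# Inference Endpoints invoke __call__ with the same payload shape.
if __name__ == "__main__":
    MODEL_PATH = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # assumption: swap in your checkpoint
    handler = EndpointHandler(MODEL_PATH)
    print(handler({"inputs": "What is the capital of France?"}))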