from typing import Any, Dict

import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration


class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for Llama 4."""

    def __init__(self, path: str = ""):
        # Load the weights in bfloat16 directly onto the GPU.
        self.model = Llama4ForConditionalGeneration.from_pretrained(
            path,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
        )
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        # Inference Endpoints sends a JSON payload; the prompt lives under "inputs".
        inputs = data.pop("inputs", data)

        # Wrap the raw prompt in the chat format the processor expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": inputs},
                ],
            },
        ]

        # Tokenize via the model's chat template and move the tensors to the GPU.
        process_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to("cuda")

        # Greedy decoding (do_sample=False); max_new_tokens=10 keeps
        # completions very short, so raise it for longer answers.
        outputs = self.model.generate(
            **process_inputs,
            max_new_tokens=10,
            do_sample=False,
        )

        # Slice off the prompt tokens and decode only the newly generated text.
        response = self.processor.batch_decode(
            outputs[:, process_inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )[0]
        return response
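

# Usage sketch, not part of the handler itself. In production, Inference
# Endpoints instantiates EndpointHandler and invokes it once per request;
# this block only mimics that flow for a local smoke test. It assumes a
# CUDA GPU and model weights in a hypothetical "./llama-4-model" directory.
if __name__ == "__main__":
    handler = EndpointHandler(path="./llama-4-model")  # hypothetical local path
    print(handler({"inputs": "What is the capital of France?"}))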