File size: 1,148 Bytes
d0d4eb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from typing import Any, Dict, List
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration
class EndpointHandler():
    """Hugging Face Inference Endpoint handler for a Llama-4 chat model.

    Loads the model and processor once at construction time, then serves
    text-generation requests through ``__call__``.
    """

    def __init__(self, path: str = ""):
        # Load weights in bfloat16; device_map handles GPU placement.
        self.model = Llama4ForConditionalGeneration.from_pretrained(
            path,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
        )
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Any) -> str:
        """Generate a completion for the prompt carried in ``data``.

        Args:
            data: Request payload. Expected to be a dict with an ``"inputs"``
                key holding the prompt text; if ``"inputs"`` is absent,
                ``data`` itself is used as the prompt. An optional
                ``"parameters"`` dict is forwarded to ``model.generate``
                (e.g. ``max_new_tokens``, ``do_sample``, ``temperature``).

        Returns:
            The generated text with the prompt tokens stripped off.
        """
        inputs = data.pop("inputs", data)
        # Backward compatible: with no "parameters" supplied, generation
        # uses the same settings the handler previously hard-coded.
        parameters = data.pop("parameters", {}) if isinstance(data, dict) else {}
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": 10, "do_sample": False}
        gen_kwargs.update(parameters)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": inputs}
                ]
            },
        ]
        process_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)  # follow the model's placement instead of hard-coding 'cuda'
        outputs = self.model.generate(
            **process_inputs,
            **gen_kwargs,
        )
        # Decode only the newly generated tokens, skipping the prompt prefix.
        prompt_len = process_inputs["input_ids"].shape[-1]
        response = self.processor.batch_decode(
            outputs[:, prompt_len:],
            skip_special_tokens=True,
        )[0]
        return response
|