AtAndDev committed
Commit 6e7a915 · verified · 1 parent: 9fd8dda

Upload ultravox_pipeline.py with huggingface_hub

Files changed (1):
1. ultravox_pipeline.py +128 -0
ultravox_pipeline.py ADDED
@@ -0,0 +1,128 @@
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import transformers
+
+ # We must use relative imports in this directory to allow uploading to the HF Hub.
+ # Even the "from . import X" pattern doesn't work (undocumented, and it's unclear why).
+ from .ultravox_model import UltravoxModel
+ from .ultravox_processing import UltravoxProcessor
+
+
+ class UltravoxPipeline(transformers.Pipeline):
+     def __init__(
+         self,
+         model: UltravoxModel,
+         tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
+         audio_processor: Optional[transformers.ProcessorMixin] = None,
+         **kwargs
+     ):
+         if tokenizer is None:
+             try:
+                 tokenizer = transformers.AutoTokenizer.from_pretrained(
+                     model.config._name_or_path
+                 )
+             except Exception:
+                 # Fall back to the text backbone's tokenizer if the combined
+                 # checkpoint doesn't ship one.
+                 tokenizer = transformers.AutoTokenizer.from_pretrained(
+                     model.config.text_model_id or model.config.text_config._name_or_path
+                 )
+
+         if audio_processor is None:
+             audio_processor = transformers.AutoProcessor.from_pretrained(
+                 model.config.audio_model_id or model.config.audio_config._name_or_path
+             )
+
+         super().__init__(model=model, tokenizer=tokenizer, **kwargs)
+
+         self.processor = UltravoxProcessor(
+             audio_processor=audio_processor,
+             tokenizer=tokenizer,
+             stack_factor=model.config.stack_factor,
+             audio_context_size=model.audio_tower_context_length,
+         )
+
+     def _sanitize_parameters(self, **kwargs):
+         # Split kwargs into (preprocess, forward, postprocess) parameter dicts;
+         # only generation-related kwargs are routed to _forward.
+         generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
+         generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
+         return {}, generation_kwargs, {}
+
+     def preprocess(self, inputs: Dict[str, Any]):
+         turns: list = inputs.get("turns", [])
+
+         audio = inputs.get("audio", None)
+         # Convert to float32 if needed.
+         if isinstance(audio, np.ndarray):
+             if audio.dtype == np.float64:
+                 audio = audio.astype(np.float32)
+             elif audio.dtype == np.int16:
+                 audio = audio.astype(np.float32) / np.float32(32768.0)
+             elif audio.dtype == np.int32:
+                 audio = audio.astype(np.float32) / np.float32(2147483648.0)
+
+         if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
+             prompt = inputs.get("prompt", "<|audio|>")
+             if "<|audio|>" not in prompt:
+                 logging.warning(
+                     "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
+                 )
+
+                 prompt += " <|audio|>"
+             turns.append({"role": "user", "content": prompt})
+
+         text = self.processor.tokenizer.apply_chat_template(
+             turns, add_generation_prompt=True, tokenize=False
+         )
+
+         if "sampling_rate" not in inputs and audio is not None:
+             logging.warning(
+                 "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
+             )
+
+         output = self.processor(
+             text=text,
+             audio=audio,
+             sampling_rate=inputs.get("sampling_rate", 16000),
+         )
+         if "audio_values" in output:
+             output["audio_values"] = output["audio_values"].to(self.model.dtype)
+
+         return output
+
+     def _forward(
+         self,
+         model_inputs: Dict[str, Any],
+         temperature: Optional[float] = None,
+         max_new_tokens: Optional[int] = None,
+         repetition_penalty: float = 1.1,
+     ) -> List[int]:
+         # A temperature of 0 (or None) disables sampling, i.e. greedy decoding.
+         temperature = temperature or None
+         do_sample = temperature is not None
+
+         terminators = [self.tokenizer.eos_token_id]
+         if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
+             terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
+
+         input_len = model_inputs["input_ids"].shape[1]
+
+         outputs = self.model.generate(
+             **model_inputs,
+             do_sample=do_sample,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             repetition_penalty=repetition_penalty,
+             eos_token_id=terminators,
+         )
+         # Return only the newly generated tokens, excluding the prompt.
+         return outputs[0][input_len:]
+
+     def postprocess(self, model_outputs) -> str:
+         output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
+         return output_text
+
+
+ transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
+     "ultravox-pipeline",
+     pipeline_class=UltravoxPipeline,
+     pt_model=transformers.AutoModel,
+     type="multimodal",
+ )
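
For reference, a minimal usage sketch of the pipeline registered above (not part of this commit). The repo id and the silent test clip are placeholders, not confirmed by this upload; the input-dict keys ("audio", "turns", "sampling_rate", optional "prompt") come from preprocess, and the only generation kwargs forwarded are temperature, max_new_tokens, and repetition_penalty, per _sanitize_parameters.

import numpy as np
import transformers

# Placeholder repo id -- substitute the Hub repo this file was uploaded to.
pipe = transformers.pipeline(model="fixie-ai/ultravox-v0_4", trust_remote_code=True)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
turns = [{"role": "system", "content": "You are a helpful assistant."}]

# preprocess() appends a user turn containing "<|audio|>" when the last turn
# is not from the user, so the turns list can omit it here.
result = pipe(
    {"audio": audio, "turns": turns, "sampling_rate": 16000},
    max_new_tokens=30,
)
print(result)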