Spaces:
Running
Running
gpt-omni
committed on
Commit
·
7ba9b1d
1
Parent(s):
e1adc1c
update
Browse files
app.py
CHANGED
|
@@ -128,7 +128,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
|
|
| 128 |
stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
|
| 129 |
return torch.stack([audio_feature, audio_feature]), stacked_inputids
|
| 130 |
|
| 131 |
-
|
| 132 |
@spaces.GPU
|
| 133 |
def next_token_batch(
|
| 134 |
model: GPT,
|
|
@@ -156,7 +156,7 @@ def next_token_batch(
|
|
| 156 |
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
| 157 |
return next_audio_tokens, next_t
|
| 158 |
|
| 159 |
-
|
| 160 |
def load_audio(path):
|
| 161 |
audio = whisper.load_audio(path)
|
| 162 |
duration_ms = (len(audio) / 16000) * 1000
|
|
@@ -164,7 +164,7 @@ def load_audio(path):
|
|
| 164 |
mel = whisper.log_mel_spectrogram(audio)
|
| 165 |
return mel, int(duration_ms / 20) + 1
|
| 166 |
|
| 167 |
-
|
| 168 |
@spaces.GPU
|
| 169 |
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
| 170 |
audio = reconstruct_tensors(snac_tokens, device)
|
|
@@ -190,7 +190,7 @@ def run_AT_batch_stream(
|
|
| 190 |
|
| 191 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 192 |
|
| 193 |
-
model.set_kv_cache(batch_size=2)
|
| 194 |
|
| 195 |
mel, leng = load_audio(audio_path)
|
| 196 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
|
|
@@ -295,7 +295,7 @@ def run_AT_batch_stream(
|
|
| 295 |
model.clear_kv_cache()
|
| 296 |
return list_output
|
| 297 |
|
| 298 |
-
|
| 299 |
for chunk in run_AT_batch_stream('./data/samples/output1.wav'):
|
| 300 |
pass
|
| 301 |
|
|
@@ -326,4 +326,4 @@ demo = gr.Interface(
|
|
| 326 |
# live=True,
|
| 327 |
)
|
| 328 |
demo.queue()
|
| 329 |
-
demo.launch()
|
|
|
|
| 128 |
stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
|
| 129 |
return torch.stack([audio_feature, audio_feature]), stacked_inputids
|
| 130 |
|
| 131 |
+
|
| 132 |
@spaces.GPU
|
| 133 |
def next_token_batch(
|
| 134 |
model: GPT,
|
|
|
|
| 156 |
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
| 157 |
return next_audio_tokens, next_t
|
| 158 |
|
| 159 |
+
|
| 160 |
def load_audio(path):
|
| 161 |
audio = whisper.load_audio(path)
|
| 162 |
duration_ms = (len(audio) / 16000) * 1000
|
|
|
|
| 164 |
mel = whisper.log_mel_spectrogram(audio)
|
| 165 |
return mel, int(duration_ms / 20) + 1
|
| 166 |
|
| 167 |
+
|
| 168 |
@spaces.GPU
|
| 169 |
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
| 170 |
audio = reconstruct_tensors(snac_tokens, device)
|
|
|
|
| 190 |
|
| 191 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 192 |
|
| 193 |
+
model.set_kv_cache(batch_size=2, device=device)
|
| 194 |
|
| 195 |
mel, leng = load_audio(audio_path)
|
| 196 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
|
|
|
|
| 295 |
model.clear_kv_cache()
|
| 296 |
return list_output
|
| 297 |
|
| 298 |
+
|
| 299 |
for chunk in run_AT_batch_stream('./data/samples/output1.wav'):
|
| 300 |
pass
|
| 301 |
|
|
|
|
| 326 |
# live=True,
|
| 327 |
)
|
| 328 |
demo.queue()
|
| 329 |
+
demo.launch()
|