preserve only a few last kv
Files changed:
- README.md (+3 -1)
- audiocraft/builders.py (+8 -14)
- audiocraft/transformer.py (+13 -21)
- msinference.py (+1 -1)
README.md
```diff
@@ -67,7 +67,9 @@ CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
 
-
+```
+python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian
+```
 
 </details>
```
audiocraft/builders.py
```diff
@@ -11,15 +11,13 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
-N_REPEAT = 
+N_REPEAT = 3  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
-    x[i, :] = torch.roll(batch_elem, offset, dims=0)  # batch_elem = [400000, ]
-    return x
+    n = x.shape[0]
+    offset = np.random.randint(.24 * n, max(1, .74 * n))  # max() keeps high > 0 for tiny n
+    return torch.roll(x, offset, dims=0)
+
 
 def _delete_param(cfg, full_name):
     parts = full_name.split('.')
@@ -70,18 +68,14 @@ class AudioGen(nn.Module):
 
         # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
 
-        x = self.resample_fn(x)
-        x = x.repeat(1, N_REPEAT)
-        # less periodic - shift every batch elem
+        x = self.resample_fn(x)  # [N_REPEAT, duration]
+        x = x.repeat(1, N_REPEAT).reshape(-1)
 
         for _ in range(7):
            x = _shift(x)
 
         print(x.abs().max(), 'MAX')
         return x / (x.abs().max() + 1e-7)
```
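For context, here is a self-contained sketch of the tiling and random circular shift that `builders.py` now applies (the 16000-sample input and printed shape are illustrative, not from the repo). Note that `torch.roll` over the flattened 1-D tensor composes, so the seven consecutive `_shift` calls amount to a single roll by the summed offset:

```python
import numpy as np
import torch

N_REPEAT = 3  # virtual batch-size clones of the sound

def _shift(x):
    # circular shift of the whole 1-D signal by a random offset in [.24*n, .74*n)
    n = x.shape[0]
    offset = np.random.randint(.24 * n, max(1, .74 * n))  # float bounds are truncated to ints
    return torch.roll(x, offset, dims=0)

x = torch.randn(1, 16000)              # stand-in for the resample_fn output
x = x.repeat(1, N_REPEAT).reshape(-1)  # tile along time, then flatten: [N_REPEAT * 16000]
for _ in range(7):
    x = _shift(x)
x = x / (x.abs().max() + 1e-7)         # peak-normalise
print(x.shape)                         # torch.Size([48000])
```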
audiocraft/transformer.py
```diff
@@ -3,8 +3,8 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions,
-                         dim,
+def create_sin_embedding(positions,
+                         dim,
                          max_period = 10000,
                          dtype = torch.float32):
     """Create sinusoidal positional embedding, with shape `[B, T, C]`.
@@ -78,28 +78,20 @@ class StreamingMultiheadAttention(nn.Module):
 
 
         if self.k_history is not None:
+            # flush
+            if self.k_history.shape[2] > 71:
+                self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
+                self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
+            # fill new k/v
             self.k_history = torch.cat([self.k_history, k], 2)  # IF ctrl^c here during live demo it is non-atomic k!=v
             self.v_history = torch.cat([self.v_history, v], 2)  # thus it will try to continue with incompatible k/v dims!
-
-            # find LOWEST l2 norm of keys > https://arxiv.org/pdf/2406.11430v4
-            low_norm = (self.k_history * self.k_history).mean(3, keepdims=True).sum(1, keepdims=True)  # [bs, 24, T, 64] -> [bs, T]
-            _, _ix = torch.topk(low_norm, k=10, dim=2, largest=False)  # shows background music due to cfg - loses the txt conditioning if flushed!
-            _ix = _ix.repeat(1, 24, 1, 64)
-            self.k_history = torch.gather(self.k_history, 2, _ix)
-            self.v_history = torch.gather(self.v_history, 2, _ix)
         else:
-            # init on 1st token (for all 47 transf layers)
-            print(f'AudioGen kv cache Flush')
+            # init
             self.k_history = k
             self.v_history = v
+        # For self attn prepare
         k = self.k_history
         v = self.v_history
```
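The transformer change replaces norm-based kv eviction with a fixed trim: once the cache exceeds 71 timesteps, only the first 4 entries (which hold the text conditioning) and the most recent entry are preserved before the new k/v are appended. A minimal sketch of that trim on a dummy cache; `trim_kv` and its keyword names are illustrative, with the thresholds taken from the diff:

```python
import torch

def trim_kv(k_history, v_history, max_len=71, keep_head=4, keep_tail=1):
    # cache layout [batch, heads, T, head_dim]: once T > max_len, drop every
    # timestep between the first keep_head and the last keep_tail entries
    if k_history.shape[2] > max_len:
        k_history = torch.cat([k_history[:, :, :keep_head, :],
                               k_history[:, :, -keep_tail:, :]], 2)
        v_history = torch.cat([v_history[:, :, :keep_head, :],
                               v_history[:, :, -keep_tail:, :]], 2)
    return k_history, v_history

k = torch.randn(2, 24, 80, 64)  # 24 heads, 80 cached steps, head_dim 64
v = torch.randn(2, 24, 80, 64)
k, v = trim_kv(k, v)
print(k.shape)  # torch.Size([2, 24, 5, 64])
```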
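For comparison, the removed policy kept the 10 cached keys with the lowest L2 norm (low key-norm correlates with high attention mass, per https://arxiv.org/pdf/2406.11430v4), which could flush the text conditioning under classifier-free guidance. A sketch of that removed policy; `keep_low_norm` is an illustrative name, and it uses `keepdim` rather than the original `keepdims` alias:

```python
import torch

def keep_low_norm(k_history, v_history, keep=10):
    # score each timestep: mean squared key over head_dim, summed over heads
    # [bs, heads, T, head_dim] -> [bs, 1, T, 1]
    low_norm = (k_history * k_history).mean(3, keepdim=True).sum(1, keepdim=True)
    _, ix = torch.topk(low_norm, k=keep, dim=2, largest=False)  # lowest-norm timesteps
    ix = ix.repeat(1, k_history.shape[1], 1, k_history.shape[3])
    return torch.gather(k_history, 2, ix), torch.gather(v_history, 2, ix)

k = torch.randn(2, 24, 80, 64)
v = torch.randn(2, 24, 80, 64)
k, v = keep_low_norm(k, v)
print(k.shape)  # torch.Size([2, 24, 10, 64])
```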
msinference.py
```diff
@@ -390,7 +390,7 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ
 
     x = net_g(input_ids=inputs.input_ids.to(device),
               attention_mask=inputs.attention_mask.to(device),
-              speed = .94 + .
+              speed = .94 + .84 * np.random.rand()  # variable speed / sentence
               )[0, :]
 
     # crop the 1st audio - is PREFIX text 156000 samples to choose deu voice / VitsAttention()
```