Delete src
Browse files- src/audioseal/__init__.py +0 -21
- src/audioseal/builder.py +0 -118
- src/audioseal/cards/audioseal_detector_16bits.yaml +0 -33
- src/audioseal/cards/audioseal_wm_16bits.yaml +0 -39
- src/audioseal/libs/__init__.py +0 -5
- src/audioseal/libs/audiocraft/__init__.py +0 -5
- src/audioseal/libs/audiocraft/modules/__init__.py +0 -8
- src/audioseal/libs/audiocraft/modules/conv.py +0 -337
- src/audioseal/libs/audiocraft/modules/lstm.py +0 -28
- src/audioseal/libs/audiocraft/modules/seanet.py +0 -426
- src/audioseal/loader.py +0 -227
- src/audioseal/models.py +0 -175
- src/audioseal/py.typed +0 -0
- src/scripts/checkpoints.py +0 -51
src/audioseal/__init__.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Watermarking and detection for speech audios
|
| 9 |
-
|
| 10 |
-
A Pytorch-based localized algorithm for proactive detection
|
| 11 |
-
of the watermarkings in AI-generated audios, with very fast
|
| 12 |
-
detector.
|
| 13 |
-
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
__version__ = "0.1.4"
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
from audioseal import builder
|
| 20 |
-
from audioseal.loader import AudioSeal
|
| 21 |
-
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/builder.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from dataclasses import asdict, dataclass, field, is_dataclass
|
| 8 |
-
from typing import Any, Dict, List, Optional
|
| 9 |
-
|
| 10 |
-
from omegaconf import DictConfig, OmegaConf
|
| 11 |
-
from torch import device, dtype
|
| 12 |
-
from typing_extensions import TypeAlias
|
| 13 |
-
|
| 14 |
-
from audioseal.libs import audiocraft
|
| 15 |
-
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
|
| 16 |
-
|
| 17 |
-
Device: TypeAlias = device
|
| 18 |
-
|
| 19 |
-
DataType: TypeAlias = dtype
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@dataclass
|
| 23 |
-
class SEANetConfig:
|
| 24 |
-
"""
|
| 25 |
-
Map common hparams of SEANet encoder and decoder.
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
channels: int
|
| 29 |
-
dimension: int
|
| 30 |
-
n_filters: int
|
| 31 |
-
n_residual_layers: int
|
| 32 |
-
ratios: List[int]
|
| 33 |
-
activation: str
|
| 34 |
-
activation_params: Dict[str, float]
|
| 35 |
-
norm: str
|
| 36 |
-
norm_params: Dict[str, Any]
|
| 37 |
-
kernel_size: int
|
| 38 |
-
last_kernel_size: int
|
| 39 |
-
residual_kernel_size: int
|
| 40 |
-
dilation_base: int
|
| 41 |
-
causal: bool
|
| 42 |
-
pad_mode: str
|
| 43 |
-
true_skip: bool
|
| 44 |
-
compress: int
|
| 45 |
-
lstm: int
|
| 46 |
-
disable_norm_outer_blocks: int
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
@dataclass
|
| 50 |
-
class DecoderConfig:
|
| 51 |
-
final_activation: Optional[str]
|
| 52 |
-
final_activation_params: Optional[dict]
|
| 53 |
-
trim_right_ratio: float
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
@dataclass
|
| 57 |
-
class DetectorConfig:
|
| 58 |
-
output_dim: int = 32
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
@dataclass
|
| 62 |
-
class AudioSealWMConfig:
|
| 63 |
-
nbits: int
|
| 64 |
-
seanet: SEANetConfig
|
| 65 |
-
decoder: DecoderConfig
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
@dataclass
|
| 69 |
-
class AudioSealDetectorConfig:
|
| 70 |
-
nbits: int
|
| 71 |
-
seanet: SEANetConfig
|
| 72 |
-
detector: DetectorConfig = field(default_factory=lambda: DetectorConfig())
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def as_dict(obj: Any) -> Dict[str, Any]:
|
| 76 |
-
if isinstance(obj, dict):
|
| 77 |
-
return obj
|
| 78 |
-
if is_dataclass(obj) and not isinstance(obj, type):
|
| 79 |
-
return asdict(obj)
|
| 80 |
-
elif isinstance(obj, DictConfig):
|
| 81 |
-
return OmegaConf.to_container(obj) # type: ignore
|
| 82 |
-
else:
|
| 83 |
-
raise NotImplementedError(f"Unsupported type for config: {type(obj)}")
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def create_generator(
|
| 87 |
-
config: AudioSealWMConfig,
|
| 88 |
-
*,
|
| 89 |
-
device: Optional[Device] = None,
|
| 90 |
-
dtype: Optional[DataType] = None,
|
| 91 |
-
) -> AudioSealWM:
|
| 92 |
-
"""Create a generator from hparams"""
|
| 93 |
-
|
| 94 |
-
# Currently the encoder hparams are the same as
|
| 95 |
-
# SEANet, but this can be changed in the future.
|
| 96 |
-
encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
|
| 97 |
-
encoder = encoder.to(device=device, dtype=dtype)
|
| 98 |
-
|
| 99 |
-
decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
|
| 100 |
-
decoder = audiocraft.modules.SEANetDecoder(**as_dict(decoder_config))
|
| 101 |
-
decoder = decoder.to(device=device, dtype=dtype)
|
| 102 |
-
|
| 103 |
-
msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
|
| 104 |
-
msgprocessor = msgprocessor.to(device=device, dtype=dtype)
|
| 105 |
-
|
| 106 |
-
return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def create_detector(
|
| 110 |
-
config: AudioSealDetectorConfig,
|
| 111 |
-
*,
|
| 112 |
-
device: Optional[Device] = None,
|
| 113 |
-
dtype: Optional[DataType] = None,
|
| 114 |
-
) -> AudioSealDetector:
|
| 115 |
-
detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
|
| 116 |
-
detector = AudioSealDetector(nbits=config.nbits, **detector_config)
|
| 117 |
-
detector = detector.to(device=device, dtype=dtype)
|
| 118 |
-
return detector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/cards/audioseal_detector_16bits.yaml
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
# @package __global__
|
| 2 |
-
|
| 3 |
-
name: audioseal_detector_16bits
|
| 4 |
-
model_type: seanet
|
| 5 |
-
checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/detector_base.pth"
|
| 6 |
-
nbits: 16
|
| 7 |
-
seanet:
|
| 8 |
-
activation: ELU
|
| 9 |
-
activation_params:
|
| 10 |
-
alpha: 1.0
|
| 11 |
-
causal: false
|
| 12 |
-
channels: 1
|
| 13 |
-
compress: 2
|
| 14 |
-
dilation_base: 2
|
| 15 |
-
dimension: 128
|
| 16 |
-
disable_norm_outer_blocks: 0
|
| 17 |
-
kernel_size: 7
|
| 18 |
-
last_kernel_size: 7
|
| 19 |
-
lstm: 2
|
| 20 |
-
n_filters: 32
|
| 21 |
-
n_residual_layers: 1
|
| 22 |
-
norm: weight_norm
|
| 23 |
-
norm_params: {}
|
| 24 |
-
pad_mode: constant
|
| 25 |
-
ratios:
|
| 26 |
-
- 8
|
| 27 |
-
- 5
|
| 28 |
-
- 4
|
| 29 |
-
- 2
|
| 30 |
-
residual_kernel_size: 3
|
| 31 |
-
true_skip: true
|
| 32 |
-
detector:
|
| 33 |
-
output_dim: 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/cards/audioseal_wm_16bits.yaml
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
name: audioseal_wm_16bits
|
| 8 |
-
model_type: seanet
|
| 9 |
-
checkpoint: "https://huggingface.co/facebook/audioseal/resolve/main/generator_base.pth"
|
| 10 |
-
nbits: 16
|
| 11 |
-
seanet:
|
| 12 |
-
activation: ELU
|
| 13 |
-
activation_params:
|
| 14 |
-
alpha: 1.0
|
| 15 |
-
causal: false
|
| 16 |
-
channels: 1
|
| 17 |
-
compress: 2
|
| 18 |
-
dilation_base: 2
|
| 19 |
-
dimension: 128
|
| 20 |
-
disable_norm_outer_blocks: 0
|
| 21 |
-
kernel_size: 7
|
| 22 |
-
last_kernel_size: 7
|
| 23 |
-
lstm: 2
|
| 24 |
-
n_filters: 32
|
| 25 |
-
n_residual_layers: 1
|
| 26 |
-
norm: weight_norm
|
| 27 |
-
norm_params: {}
|
| 28 |
-
pad_mode: constant
|
| 29 |
-
ratios:
|
| 30 |
-
- 8
|
| 31 |
-
- 5
|
| 32 |
-
- 4
|
| 33 |
-
- 2
|
| 34 |
-
residual_kernel_size: 3
|
| 35 |
-
true_skip: true
|
| 36 |
-
decoder:
|
| 37 |
-
final_activation: null
|
| 38 |
-
final_activation_params: null
|
| 39 |
-
trim_right_ratio: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/audiocraft/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/audiocraft/modules/__init__.py
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
from .seanet import SEANetDecoder, SEANetEncoder, SEANetEncoderKeepDimension
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/audiocraft/modules/conv.py
DELETED
|
@@ -1,337 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
#
|
| 7 |
-
# Vendor from https://github.com/facebookresearch/audiocraft
|
| 8 |
-
|
| 9 |
-
import math
|
| 10 |
-
import typing as tp
|
| 11 |
-
import warnings
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
from torch import nn
|
| 15 |
-
from torch.nn import functional as F
|
| 16 |
-
from torch.nn.utils import spectral_norm
|
| 17 |
-
|
| 18 |
-
try:
|
| 19 |
-
from torch.nn.utils.parametrizations import weight_norm
|
| 20 |
-
except ImportError:
|
| 21 |
-
# Old Pytorch
|
| 22 |
-
from torch.nn.utils import weight_norm
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
CONV_NORMALIZATIONS = frozenset(
|
| 26 |
-
["none", "weight_norm", "spectral_norm", "time_group_norm"]
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
|
| 31 |
-
assert norm in CONV_NORMALIZATIONS
|
| 32 |
-
if norm == "weight_norm":
|
| 33 |
-
return weight_norm(module)
|
| 34 |
-
elif norm == "spectral_norm":
|
| 35 |
-
return spectral_norm(module)
|
| 36 |
-
else:
|
| 37 |
-
# We already check was in CONV_NORMALIZATION, so any other choice
|
| 38 |
-
# doesn't need reparametrization.
|
| 39 |
-
return module
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_norm_module(
|
| 43 |
-
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
|
| 44 |
-
):
|
| 45 |
-
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
| 46 |
-
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
| 47 |
-
"""
|
| 48 |
-
assert norm in CONV_NORMALIZATIONS
|
| 49 |
-
if norm == "time_group_norm":
|
| 50 |
-
if causal:
|
| 51 |
-
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
| 52 |
-
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 53 |
-
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
| 54 |
-
else:
|
| 55 |
-
return nn.Identity()
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def get_extra_padding_for_conv1d(
|
| 59 |
-
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 60 |
-
) -> int:
|
| 61 |
-
"""See `pad_for_conv1d`."""
|
| 62 |
-
length = x.shape[-1]
|
| 63 |
-
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 64 |
-
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 65 |
-
return ideal_length - length
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def pad_for_conv1d(
|
| 69 |
-
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 70 |
-
):
|
| 71 |
-
"""Pad for a convolution to make sure that the last window is full.
|
| 72 |
-
Extra padding is added at the end. This is required to ensure that we can rebuild
|
| 73 |
-
an output of the same length, as otherwise, even with padding, some time steps
|
| 74 |
-
might get removed.
|
| 75 |
-
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
| 76 |
-
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
| 77 |
-
1 2 3 # (output frames of a convolution, last 0 is never used)
|
| 78 |
-
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
| 79 |
-
1 2 3 4 # once you removed padding, we are missing one time step !
|
| 80 |
-
"""
|
| 81 |
-
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 82 |
-
return F.pad(x, (0, extra_padding))
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def pad1d(
|
| 86 |
-
x: torch.Tensor,
|
| 87 |
-
paddings: tp.Tuple[int, int],
|
| 88 |
-
mode: str = "constant",
|
| 89 |
-
value: float = 0.0,
|
| 90 |
-
):
|
| 91 |
-
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
| 92 |
-
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
| 93 |
-
"""
|
| 94 |
-
length = x.shape[-1]
|
| 95 |
-
padding_left, padding_right = paddings
|
| 96 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 97 |
-
if mode == "reflect":
|
| 98 |
-
max_pad = max(padding_left, padding_right)
|
| 99 |
-
extra_pad = 0
|
| 100 |
-
if length <= max_pad:
|
| 101 |
-
extra_pad = max_pad - length + 1
|
| 102 |
-
x = F.pad(x, (0, extra_pad))
|
| 103 |
-
padded = F.pad(x, paddings, mode, value)
|
| 104 |
-
end = padded.shape[-1] - extra_pad
|
| 105 |
-
return padded[..., :end]
|
| 106 |
-
else:
|
| 107 |
-
return F.pad(x, paddings, mode, value)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
| 111 |
-
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
| 112 |
-
padding_left, padding_right = paddings
|
| 113 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 114 |
-
assert (padding_left + padding_right) <= x.shape[-1]
|
| 115 |
-
end = x.shape[-1] - padding_right
|
| 116 |
-
return x[..., padding_left:end]
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
class NormConv1d(nn.Module):
|
| 120 |
-
"""Wrapper around Conv1d and normalization applied to this conv
|
| 121 |
-
to provide a uniform interface across normalization approaches.
|
| 122 |
-
"""
|
| 123 |
-
|
| 124 |
-
def __init__(
|
| 125 |
-
self,
|
| 126 |
-
*args,
|
| 127 |
-
causal: bool = False,
|
| 128 |
-
norm: str = "none",
|
| 129 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 130 |
-
**kwargs,
|
| 131 |
-
):
|
| 132 |
-
super().__init__()
|
| 133 |
-
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
| 134 |
-
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
| 135 |
-
self.norm_type = norm
|
| 136 |
-
|
| 137 |
-
def forward(self, x):
|
| 138 |
-
x = self.conv(x)
|
| 139 |
-
x = self.norm(x)
|
| 140 |
-
return x
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
class NormConv2d(nn.Module):
|
| 144 |
-
"""Wrapper around Conv2d and normalization applied to this conv
|
| 145 |
-
to provide a uniform interface across normalization approaches.
|
| 146 |
-
"""
|
| 147 |
-
|
| 148 |
-
def __init__(
|
| 149 |
-
self,
|
| 150 |
-
*args,
|
| 151 |
-
norm: str = "none",
|
| 152 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 153 |
-
**kwargs,
|
| 154 |
-
):
|
| 155 |
-
super().__init__()
|
| 156 |
-
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
| 157 |
-
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
| 158 |
-
self.norm_type = norm
|
| 159 |
-
|
| 160 |
-
def forward(self, x):
|
| 161 |
-
x = self.conv(x)
|
| 162 |
-
x = self.norm(x)
|
| 163 |
-
return x
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
class NormConvTranspose1d(nn.Module):
|
| 167 |
-
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
| 168 |
-
to provide a uniform interface across normalization approaches.
|
| 169 |
-
"""
|
| 170 |
-
|
| 171 |
-
def __init__(
|
| 172 |
-
self,
|
| 173 |
-
*args,
|
| 174 |
-
causal: bool = False,
|
| 175 |
-
norm: str = "none",
|
| 176 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 177 |
-
**kwargs,
|
| 178 |
-
):
|
| 179 |
-
super().__init__()
|
| 180 |
-
self.convtr = apply_parametrization_norm(
|
| 181 |
-
nn.ConvTranspose1d(*args, **kwargs), norm
|
| 182 |
-
)
|
| 183 |
-
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
| 184 |
-
self.norm_type = norm
|
| 185 |
-
|
| 186 |
-
def forward(self, x):
|
| 187 |
-
x = self.convtr(x)
|
| 188 |
-
x = self.norm(x)
|
| 189 |
-
return x
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
class NormConvTranspose2d(nn.Module):
|
| 193 |
-
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
| 194 |
-
to provide a uniform interface across normalization approaches.
|
| 195 |
-
"""
|
| 196 |
-
|
| 197 |
-
def __init__(
|
| 198 |
-
self,
|
| 199 |
-
*args,
|
| 200 |
-
norm: str = "none",
|
| 201 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 202 |
-
**kwargs,
|
| 203 |
-
):
|
| 204 |
-
super().__init__()
|
| 205 |
-
self.convtr = apply_parametrization_norm(
|
| 206 |
-
nn.ConvTranspose2d(*args, **kwargs), norm
|
| 207 |
-
)
|
| 208 |
-
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
| 209 |
-
|
| 210 |
-
def forward(self, x):
|
| 211 |
-
x = self.convtr(x)
|
| 212 |
-
x = self.norm(x)
|
| 213 |
-
return x
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
class StreamableConv1d(nn.Module):
|
| 217 |
-
"""Conv1d with some builtin handling of asymmetric or causal padding
|
| 218 |
-
and normalization.
|
| 219 |
-
"""
|
| 220 |
-
|
| 221 |
-
def __init__(
|
| 222 |
-
self,
|
| 223 |
-
in_channels: int,
|
| 224 |
-
out_channels: int,
|
| 225 |
-
kernel_size: int,
|
| 226 |
-
stride: int = 1,
|
| 227 |
-
dilation: int = 1,
|
| 228 |
-
groups: int = 1,
|
| 229 |
-
bias: bool = True,
|
| 230 |
-
causal: bool = False,
|
| 231 |
-
norm: str = "none",
|
| 232 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 233 |
-
pad_mode: str = "reflect",
|
| 234 |
-
):
|
| 235 |
-
super().__init__()
|
| 236 |
-
# warn user on unusual setup between dilation and stride
|
| 237 |
-
if stride > 1 and dilation > 1:
|
| 238 |
-
warnings.warn(
|
| 239 |
-
"StreamableConv1d has been initialized with stride > 1 and dilation > 1"
|
| 240 |
-
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
|
| 241 |
-
)
|
| 242 |
-
self.conv = NormConv1d(
|
| 243 |
-
in_channels,
|
| 244 |
-
out_channels,
|
| 245 |
-
kernel_size,
|
| 246 |
-
stride,
|
| 247 |
-
dilation=dilation,
|
| 248 |
-
groups=groups,
|
| 249 |
-
bias=bias,
|
| 250 |
-
causal=causal,
|
| 251 |
-
norm=norm,
|
| 252 |
-
norm_kwargs=norm_kwargs,
|
| 253 |
-
)
|
| 254 |
-
self.causal = causal
|
| 255 |
-
self.pad_mode = pad_mode
|
| 256 |
-
|
| 257 |
-
def forward(self, x):
|
| 258 |
-
B, C, T = x.shape
|
| 259 |
-
kernel_size = self.conv.conv.kernel_size[0]
|
| 260 |
-
stride = self.conv.conv.stride[0]
|
| 261 |
-
dilation = self.conv.conv.dilation[0]
|
| 262 |
-
kernel_size = (
|
| 263 |
-
kernel_size - 1
|
| 264 |
-
) * dilation + 1 # effective kernel size with dilations
|
| 265 |
-
padding_total = kernel_size - stride
|
| 266 |
-
extra_padding = get_extra_padding_for_conv1d(
|
| 267 |
-
x, kernel_size, stride, padding_total
|
| 268 |
-
)
|
| 269 |
-
if self.causal:
|
| 270 |
-
# Left padding for causal
|
| 271 |
-
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
| 272 |
-
else:
|
| 273 |
-
# Asymmetric padding required for odd strides
|
| 274 |
-
padding_right = padding_total // 2
|
| 275 |
-
padding_left = padding_total - padding_right
|
| 276 |
-
x = pad1d(
|
| 277 |
-
x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
|
| 278 |
-
)
|
| 279 |
-
return self.conv(x)
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
class StreamableConvTranspose1d(nn.Module):
|
| 283 |
-
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
| 284 |
-
and normalization.
|
| 285 |
-
"""
|
| 286 |
-
|
| 287 |
-
def __init__(
|
| 288 |
-
self,
|
| 289 |
-
in_channels: int,
|
| 290 |
-
out_channels: int,
|
| 291 |
-
kernel_size: int,
|
| 292 |
-
stride: int = 1,
|
| 293 |
-
causal: bool = False,
|
| 294 |
-
norm: str = "none",
|
| 295 |
-
trim_right_ratio: float = 1.0,
|
| 296 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 297 |
-
):
|
| 298 |
-
super().__init__()
|
| 299 |
-
self.convtr = NormConvTranspose1d(
|
| 300 |
-
in_channels,
|
| 301 |
-
out_channels,
|
| 302 |
-
kernel_size,
|
| 303 |
-
stride,
|
| 304 |
-
causal=causal,
|
| 305 |
-
norm=norm,
|
| 306 |
-
norm_kwargs=norm_kwargs,
|
| 307 |
-
)
|
| 308 |
-
self.causal = causal
|
| 309 |
-
self.trim_right_ratio = trim_right_ratio
|
| 310 |
-
assert (
|
| 311 |
-
self.causal or self.trim_right_ratio == 1.0
|
| 312 |
-
), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
| 313 |
-
assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
|
| 314 |
-
|
| 315 |
-
def forward(self, x):
|
| 316 |
-
kernel_size = self.convtr.convtr.kernel_size[0]
|
| 317 |
-
stride = self.convtr.convtr.stride[0]
|
| 318 |
-
padding_total = kernel_size - stride
|
| 319 |
-
|
| 320 |
-
y = self.convtr(x)
|
| 321 |
-
|
| 322 |
-
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
| 323 |
-
# removed at the very end, when keeping only the right length for the output,
|
| 324 |
-
# as removing it here would require also passing the length at the matching layer
|
| 325 |
-
# in the encoder.
|
| 326 |
-
if self.causal:
|
| 327 |
-
# Trim the padding on the right according to the specified ratio
|
| 328 |
-
# if trim_right_ratio = 1.0, trim everything from right
|
| 329 |
-
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
| 330 |
-
padding_left = padding_total - padding_right
|
| 331 |
-
y = unpad1d(y, (padding_left, padding_right))
|
| 332 |
-
else:
|
| 333 |
-
# Asymmetric padding required for odd strides
|
| 334 |
-
padding_right = padding_total // 2
|
| 335 |
-
padding_left = padding_total - padding_right
|
| 336 |
-
y = unpad1d(y, (padding_left, padding_right))
|
| 337 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/audiocraft/modules/lstm.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
#
|
| 7 |
-
# Vendor from https://github.com/facebookresearch/audiocraft
|
| 8 |
-
|
| 9 |
-
from torch import nn
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class StreamableLSTM(nn.Module):
|
| 13 |
-
"""LSTM without worrying about the hidden state, nor the layout of the data.
|
| 14 |
-
Expects input as convolutional layout.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
| 18 |
-
super().__init__()
|
| 19 |
-
self.skip = skip
|
| 20 |
-
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
| 21 |
-
|
| 22 |
-
def forward(self, x):
|
| 23 |
-
x = x.permute(2, 0, 1)
|
| 24 |
-
y, _ = self.lstm(x)
|
| 25 |
-
if self.skip:
|
| 26 |
-
y = y + x
|
| 27 |
-
y = y.permute(1, 2, 0)
|
| 28 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/libs/audiocraft/modules/seanet.py
DELETED
|
@@ -1,426 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
#
|
| 7 |
-
# Vendor from https://github.com/facebookresearch/audiocraft
|
| 8 |
-
|
| 9 |
-
import math
|
| 10 |
-
import typing as tp
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
import torch.nn as nn
|
| 14 |
-
|
| 15 |
-
from audioseal.libs.audiocraft.modules.conv import (
|
| 16 |
-
StreamableConv1d,
|
| 17 |
-
StreamableConvTranspose1d,
|
| 18 |
-
)
|
| 19 |
-
from audioseal.libs.audiocraft.modules.lstm import StreamableLSTM
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class SEANetResnetBlock(nn.Module):
|
| 23 |
-
"""Residual block from SEANet model.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
dim (int): Dimension of the input/output.
|
| 27 |
-
kernel_sizes (list): List of kernel sizes for the convolutions.
|
| 28 |
-
dilations (list): List of dilations for the convolutions.
|
| 29 |
-
activation (str): Activation function.
|
| 30 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 31 |
-
norm (str): Normalization method.
|
| 32 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 33 |
-
causal (bool): Whether to use fully causal convolution.
|
| 34 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 35 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 36 |
-
true_skip (bool): Whether to use true skip connection or a simple
|
| 37 |
-
(streamable) convolution as the skip connection.
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
dim: int,
|
| 43 |
-
kernel_sizes: tp.List[int] = [3, 1],
|
| 44 |
-
dilations: tp.List[int] = [1, 1],
|
| 45 |
-
activation: str = "ELU",
|
| 46 |
-
activation_params: dict = {"alpha": 1.0},
|
| 47 |
-
norm: str = "none",
|
| 48 |
-
norm_params: tp.Dict[str, tp.Any] = {},
|
| 49 |
-
causal: bool = False,
|
| 50 |
-
pad_mode: str = "reflect",
|
| 51 |
-
compress: int = 2,
|
| 52 |
-
true_skip: bool = True,
|
| 53 |
-
):
|
| 54 |
-
super().__init__()
|
| 55 |
-
assert len(kernel_sizes) == len(
|
| 56 |
-
dilations
|
| 57 |
-
), "Number of kernel sizes should match number of dilations"
|
| 58 |
-
act = getattr(nn, activation)
|
| 59 |
-
hidden = dim // compress
|
| 60 |
-
block = []
|
| 61 |
-
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
| 62 |
-
in_chs = dim if i == 0 else hidden
|
| 63 |
-
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
| 64 |
-
block += [
|
| 65 |
-
act(**activation_params),
|
| 66 |
-
StreamableConv1d(
|
| 67 |
-
in_chs,
|
| 68 |
-
out_chs,
|
| 69 |
-
kernel_size=kernel_size,
|
| 70 |
-
dilation=dilation,
|
| 71 |
-
norm=norm,
|
| 72 |
-
norm_kwargs=norm_params,
|
| 73 |
-
causal=causal,
|
| 74 |
-
pad_mode=pad_mode,
|
| 75 |
-
),
|
| 76 |
-
]
|
| 77 |
-
self.block = nn.Sequential(*block)
|
| 78 |
-
self.shortcut: nn.Module
|
| 79 |
-
if true_skip:
|
| 80 |
-
self.shortcut = nn.Identity()
|
| 81 |
-
else:
|
| 82 |
-
self.shortcut = StreamableConv1d(
|
| 83 |
-
dim,
|
| 84 |
-
dim,
|
| 85 |
-
kernel_size=1,
|
| 86 |
-
norm=norm,
|
| 87 |
-
norm_kwargs=norm_params,
|
| 88 |
-
causal=causal,
|
| 89 |
-
pad_mode=pad_mode,
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
def forward(self, x):
|
| 93 |
-
return self.shortcut(x) + self.block(x)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class SEANetEncoder(nn.Module):
|
| 97 |
-
"""SEANet encoder.
|
| 98 |
-
|
| 99 |
-
Args:
|
| 100 |
-
channels (int): Audio channels.
|
| 101 |
-
dimension (int): Intermediate representation dimension.
|
| 102 |
-
n_filters (int): Base width for the model.
|
| 103 |
-
n_residual_layers (int): nb of residual layers.
|
| 104 |
-
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
| 105 |
-
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
| 106 |
-
that must match the decoder order. We use the decoder order as some models may only employ the decoder.
|
| 107 |
-
activation (str): Activation function.
|
| 108 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 109 |
-
norm (str): Normalization method.
|
| 110 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 111 |
-
kernel_size (int): Kernel size for the initial convolution.
|
| 112 |
-
last_kernel_size (int): Kernel size for the initial convolution.
|
| 113 |
-
residual_kernel_size (int): Kernel size for the residual layers.
|
| 114 |
-
dilation_base (int): How much to increase the dilation with each layer.
|
| 115 |
-
causal (bool): Whether to use fully causal convolution.
|
| 116 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 117 |
-
true_skip (bool): Whether to use true skip connection or a simple
|
| 118 |
-
(streamable) convolution as the skip connection in the residual network blocks.
|
| 119 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 120 |
-
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 121 |
-
disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
|
| 122 |
-
For the encoder, it corresponds to the N first blocks.
|
| 123 |
-
"""
|
| 124 |
-
|
| 125 |
-
def __init__(
|
| 126 |
-
self,
|
| 127 |
-
channels: int = 1,
|
| 128 |
-
dimension: int = 128,
|
| 129 |
-
n_filters: int = 32,
|
| 130 |
-
n_residual_layers: int = 3,
|
| 131 |
-
ratios: tp.List[int] = [8, 5, 4, 2],
|
| 132 |
-
activation: str = "ELU",
|
| 133 |
-
activation_params: dict = {"alpha": 1.0},
|
| 134 |
-
norm: str = "none",
|
| 135 |
-
norm_params: tp.Dict[str, tp.Any] = {},
|
| 136 |
-
kernel_size: int = 7,
|
| 137 |
-
last_kernel_size: int = 7,
|
| 138 |
-
residual_kernel_size: int = 3,
|
| 139 |
-
dilation_base: int = 2,
|
| 140 |
-
causal: bool = False,
|
| 141 |
-
pad_mode: str = "reflect",
|
| 142 |
-
true_skip: bool = True,
|
| 143 |
-
compress: int = 2,
|
| 144 |
-
lstm: int = 0,
|
| 145 |
-
disable_norm_outer_blocks: int = 0,
|
| 146 |
-
):
|
| 147 |
-
super().__init__()
|
| 148 |
-
self.channels = channels
|
| 149 |
-
self.dimension = dimension
|
| 150 |
-
self.n_filters = n_filters
|
| 151 |
-
self.ratios = list(reversed(ratios))
|
| 152 |
-
del ratios
|
| 153 |
-
self.n_residual_layers = n_residual_layers
|
| 154 |
-
self.hop_length = np.prod(self.ratios)
|
| 155 |
-
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
|
| 156 |
-
self.disable_norm_outer_blocks = disable_norm_outer_blocks
|
| 157 |
-
assert (
|
| 158 |
-
self.disable_norm_outer_blocks >= 0
|
| 159 |
-
and self.disable_norm_outer_blocks <= self.n_blocks
|
| 160 |
-
), (
|
| 161 |
-
"Number of blocks for which to disable norm is invalid."
|
| 162 |
-
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
act = getattr(nn, activation)
|
| 166 |
-
mult = 1
|
| 167 |
-
model: tp.List[nn.Module] = [
|
| 168 |
-
StreamableConv1d(
|
| 169 |
-
channels,
|
| 170 |
-
mult * n_filters,
|
| 171 |
-
kernel_size,
|
| 172 |
-
norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
|
| 173 |
-
norm_kwargs=norm_params,
|
| 174 |
-
causal=causal,
|
| 175 |
-
pad_mode=pad_mode,
|
| 176 |
-
)
|
| 177 |
-
]
|
| 178 |
-
# Downsample to raw audio scale
|
| 179 |
-
for i, ratio in enumerate(self.ratios):
|
| 180 |
-
block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
|
| 181 |
-
# Add residual layers
|
| 182 |
-
for j in range(n_residual_layers):
|
| 183 |
-
model += [
|
| 184 |
-
SEANetResnetBlock(
|
| 185 |
-
mult * n_filters,
|
| 186 |
-
kernel_sizes=[residual_kernel_size, 1],
|
| 187 |
-
dilations=[dilation_base**j, 1],
|
| 188 |
-
norm=block_norm,
|
| 189 |
-
norm_params=norm_params,
|
| 190 |
-
activation=activation,
|
| 191 |
-
activation_params=activation_params,
|
| 192 |
-
causal=causal,
|
| 193 |
-
pad_mode=pad_mode,
|
| 194 |
-
compress=compress,
|
| 195 |
-
true_skip=true_skip,
|
| 196 |
-
)
|
| 197 |
-
]
|
| 198 |
-
|
| 199 |
-
# Add downsampling layers
|
| 200 |
-
model += [
|
| 201 |
-
act(**activation_params),
|
| 202 |
-
StreamableConv1d(
|
| 203 |
-
mult * n_filters,
|
| 204 |
-
mult * n_filters * 2,
|
| 205 |
-
kernel_size=ratio * 2,
|
| 206 |
-
stride=ratio,
|
| 207 |
-
norm=block_norm,
|
| 208 |
-
norm_kwargs=norm_params,
|
| 209 |
-
causal=causal,
|
| 210 |
-
pad_mode=pad_mode,
|
| 211 |
-
),
|
| 212 |
-
]
|
| 213 |
-
mult *= 2
|
| 214 |
-
|
| 215 |
-
if lstm:
|
| 216 |
-
model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
|
| 217 |
-
|
| 218 |
-
model += [
|
| 219 |
-
act(**activation_params),
|
| 220 |
-
StreamableConv1d(
|
| 221 |
-
mult * n_filters,
|
| 222 |
-
dimension,
|
| 223 |
-
last_kernel_size,
|
| 224 |
-
norm=(
|
| 225 |
-
"none" if self.disable_norm_outer_blocks == self.n_blocks else norm
|
| 226 |
-
),
|
| 227 |
-
norm_kwargs=norm_params,
|
| 228 |
-
causal=causal,
|
| 229 |
-
pad_mode=pad_mode,
|
| 230 |
-
),
|
| 231 |
-
]
|
| 232 |
-
|
| 233 |
-
self.model = nn.Sequential(*model)
|
| 234 |
-
|
| 235 |
-
def forward(self, x):
|
| 236 |
-
return self.model(x)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class SEANetEncoderKeepDimension(SEANetEncoder):
|
| 240 |
-
"""
|
| 241 |
-
similar architecture to the SEANet encoder but with an extra step that
|
| 242 |
-
projects the output dimension to the same input dimension by repeating
|
| 243 |
-
the sequential
|
| 244 |
-
|
| 245 |
-
Args:
|
| 246 |
-
SEANetEncoder (_type_): _description_
|
| 247 |
-
"""
|
| 248 |
-
|
| 249 |
-
def __init__(self, *args, **kwargs):
|
| 250 |
-
|
| 251 |
-
self.output_dim = kwargs.pop("output_dim")
|
| 252 |
-
super().__init__(*args, **kwargs)
|
| 253 |
-
# Adding a reverse convolution layer
|
| 254 |
-
self.reverse_convolution = nn.ConvTranspose1d(
|
| 255 |
-
in_channels=self.dimension,
|
| 256 |
-
out_channels=self.output_dim,
|
| 257 |
-
kernel_size=math.prod(self.ratios),
|
| 258 |
-
stride=math.prod(self.ratios),
|
| 259 |
-
padding=0,
|
| 260 |
-
)
|
| 261 |
-
|
| 262 |
-
def forward(self, x):
|
| 263 |
-
orig_nframes = x.shape[-1]
|
| 264 |
-
x = self.model(x)
|
| 265 |
-
x = self.reverse_convolution(x)
|
| 266 |
-
# make sure dim didn't change
|
| 267 |
-
return x[:, :, :orig_nframes]
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
class SEANetDecoder(nn.Module):
|
| 271 |
-
"""SEANet decoder.
|
| 272 |
-
|
| 273 |
-
Args:
|
| 274 |
-
channels (int): Audio channels.
|
| 275 |
-
dimension (int): Intermediate representation dimension.
|
| 276 |
-
n_filters (int): Base width for the model.
|
| 277 |
-
n_residual_layers (int): nb of residual layers.
|
| 278 |
-
ratios (Sequence[int]): kernel size and stride ratios.
|
| 279 |
-
activation (str): Activation function.
|
| 280 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 281 |
-
final_activation (str): Final activation function after all convolutions.
|
| 282 |
-
final_activation_params (dict): Parameters to provide to the activation function.
|
| 283 |
-
norm (str): Normalization method.
|
| 284 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 285 |
-
kernel_size (int): Kernel size for the initial convolution.
|
| 286 |
-
last_kernel_size (int): Kernel size for the initial convolution.
|
| 287 |
-
residual_kernel_size (int): Kernel size for the residual layers.
|
| 288 |
-
dilation_base (int): How much to increase the dilation with each layer.
|
| 289 |
-
causal (bool): Whether to use fully causal convolution.
|
| 290 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 291 |
-
true_skip (bool): Whether to use true skip connection or a simple.
|
| 292 |
-
(streamable) convolution as the skip connection in the residual network blocks.
|
| 293 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 294 |
-
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 295 |
-
disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
|
| 296 |
-
For the decoder, it corresponds to the N last blocks.
|
| 297 |
-
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
|
| 298 |
-
If equal to 1.0, it means that all the trimming is done at the right.
|
| 299 |
-
"""
|
| 300 |
-
|
| 301 |
-
def __init__(
|
| 302 |
-
self,
|
| 303 |
-
channels: int = 1,
|
| 304 |
-
dimension: int = 128,
|
| 305 |
-
n_filters: int = 32,
|
| 306 |
-
n_residual_layers: int = 3,
|
| 307 |
-
ratios: tp.List[int] = [8, 5, 4, 2],
|
| 308 |
-
activation: str = "ELU",
|
| 309 |
-
activation_params: dict = {"alpha": 1.0},
|
| 310 |
-
final_activation: tp.Optional[str] = None,
|
| 311 |
-
final_activation_params: tp.Optional[dict] = None,
|
| 312 |
-
norm: str = "none",
|
| 313 |
-
norm_params: tp.Dict[str, tp.Any] = {},
|
| 314 |
-
kernel_size: int = 7,
|
| 315 |
-
last_kernel_size: int = 7,
|
| 316 |
-
residual_kernel_size: int = 3,
|
| 317 |
-
dilation_base: int = 2,
|
| 318 |
-
causal: bool = False,
|
| 319 |
-
pad_mode: str = "reflect",
|
| 320 |
-
true_skip: bool = True,
|
| 321 |
-
compress: int = 2,
|
| 322 |
-
lstm: int = 0,
|
| 323 |
-
disable_norm_outer_blocks: int = 0,
|
| 324 |
-
trim_right_ratio: float = 1.0,
|
| 325 |
-
):
|
| 326 |
-
super().__init__()
|
| 327 |
-
self.dimension = dimension
|
| 328 |
-
self.channels = channels
|
| 329 |
-
self.n_filters = n_filters
|
| 330 |
-
self.ratios = ratios
|
| 331 |
-
del ratios
|
| 332 |
-
self.n_residual_layers = n_residual_layers
|
| 333 |
-
self.hop_length = np.prod(self.ratios)
|
| 334 |
-
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
|
| 335 |
-
self.disable_norm_outer_blocks = disable_norm_outer_blocks
|
| 336 |
-
assert (
|
| 337 |
-
self.disable_norm_outer_blocks >= 0
|
| 338 |
-
and self.disable_norm_outer_blocks <= self.n_blocks
|
| 339 |
-
), (
|
| 340 |
-
"Number of blocks for which to disable norm is invalid."
|
| 341 |
-
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
act = getattr(nn, activation)
|
| 345 |
-
mult = int(2 ** len(self.ratios))
|
| 346 |
-
model: tp.List[nn.Module] = [
|
| 347 |
-
StreamableConv1d(
|
| 348 |
-
dimension,
|
| 349 |
-
mult * n_filters,
|
| 350 |
-
kernel_size,
|
| 351 |
-
norm=(
|
| 352 |
-
"none" if self.disable_norm_outer_blocks == self.n_blocks else norm
|
| 353 |
-
),
|
| 354 |
-
norm_kwargs=norm_params,
|
| 355 |
-
causal=causal,
|
| 356 |
-
pad_mode=pad_mode,
|
| 357 |
-
)
|
| 358 |
-
]
|
| 359 |
-
|
| 360 |
-
if lstm:
|
| 361 |
-
model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
|
| 362 |
-
|
| 363 |
-
# Upsample to raw audio scale
|
| 364 |
-
for i, ratio in enumerate(self.ratios):
|
| 365 |
-
block_norm = (
|
| 366 |
-
"none"
|
| 367 |
-
if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
|
| 368 |
-
else norm
|
| 369 |
-
)
|
| 370 |
-
# Add upsampling layers
|
| 371 |
-
model += [
|
| 372 |
-
act(**activation_params),
|
| 373 |
-
StreamableConvTranspose1d(
|
| 374 |
-
mult * n_filters,
|
| 375 |
-
mult * n_filters // 2,
|
| 376 |
-
kernel_size=ratio * 2,
|
| 377 |
-
stride=ratio,
|
| 378 |
-
norm=block_norm,
|
| 379 |
-
norm_kwargs=norm_params,
|
| 380 |
-
causal=causal,
|
| 381 |
-
trim_right_ratio=trim_right_ratio,
|
| 382 |
-
),
|
| 383 |
-
]
|
| 384 |
-
# Add residual layers
|
| 385 |
-
for j in range(n_residual_layers):
|
| 386 |
-
model += [
|
| 387 |
-
SEANetResnetBlock(
|
| 388 |
-
mult * n_filters // 2,
|
| 389 |
-
kernel_sizes=[residual_kernel_size, 1],
|
| 390 |
-
dilations=[dilation_base**j, 1],
|
| 391 |
-
activation=activation,
|
| 392 |
-
activation_params=activation_params,
|
| 393 |
-
norm=block_norm,
|
| 394 |
-
norm_params=norm_params,
|
| 395 |
-
causal=causal,
|
| 396 |
-
pad_mode=pad_mode,
|
| 397 |
-
compress=compress,
|
| 398 |
-
true_skip=true_skip,
|
| 399 |
-
)
|
| 400 |
-
]
|
| 401 |
-
|
| 402 |
-
mult //= 2
|
| 403 |
-
|
| 404 |
-
# Add final layers
|
| 405 |
-
model += [
|
| 406 |
-
act(**activation_params),
|
| 407 |
-
StreamableConv1d(
|
| 408 |
-
n_filters,
|
| 409 |
-
channels,
|
| 410 |
-
last_kernel_size,
|
| 411 |
-
norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
|
| 412 |
-
norm_kwargs=norm_params,
|
| 413 |
-
causal=causal,
|
| 414 |
-
pad_mode=pad_mode,
|
| 415 |
-
),
|
| 416 |
-
]
|
| 417 |
-
# Add optional final activation to decoder (eg. tanh)
|
| 418 |
-
if final_activation is not None:
|
| 419 |
-
final_act = getattr(nn, final_activation)
|
| 420 |
-
final_activation_params = final_activation_params or {}
|
| 421 |
-
model += [final_act(**final_activation_params)]
|
| 422 |
-
self.model = nn.Sequential(*model)
|
| 423 |
-
|
| 424 |
-
def forward(self, z):
|
| 425 |
-
y = self.model(z)
|
| 426 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/loader.py
DELETED
|
@@ -1,227 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
import os
|
| 9 |
-
from dataclasses import fields
|
| 10 |
-
from hashlib import sha1
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
from typing import ( # type: ignore[attr-defined]
|
| 13 |
-
Any,
|
| 14 |
-
Dict,
|
| 15 |
-
List,
|
| 16 |
-
Optional,
|
| 17 |
-
Tuple,
|
| 18 |
-
Type,
|
| 19 |
-
TypeVar,
|
| 20 |
-
Union,
|
| 21 |
-
cast,
|
| 22 |
-
)
|
| 23 |
-
from urllib.parse import urlparse # noqa: F401
|
| 24 |
-
|
| 25 |
-
import torch
|
| 26 |
-
from omegaconf import DictConfig, OmegaConf
|
| 27 |
-
|
| 28 |
-
import audioseal
|
| 29 |
-
from audioseal.builder import (
|
| 30 |
-
AudioSealDetectorConfig,
|
| 31 |
-
AudioSealWMConfig,
|
| 32 |
-
create_detector,
|
| 33 |
-
create_generator,
|
| 34 |
-
)
|
| 35 |
-
from audioseal.models import AudioSealDetector, AudioSealWM
|
| 36 |
-
|
| 37 |
-
AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class ModelLoadError(RuntimeError):
|
| 41 |
-
"""Raised when the model loading fails"""
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _get_path_from_env(var_name: str) -> Optional[Path]:
|
| 45 |
-
pathname = os.getenv(var_name)
|
| 46 |
-
if not pathname:
|
| 47 |
-
return None
|
| 48 |
-
|
| 49 |
-
try:
|
| 50 |
-
return Path(pathname)
|
| 51 |
-
except ValueError as ex:
|
| 52 |
-
raise RuntimeError(f"Expect valid pathname, get '{pathname}'.") from ex
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _get_cache_dir(env_names: List[str]):
|
| 56 |
-
"""Re-use cache dir from a list of existing caches"""
|
| 57 |
-
for env in env_names:
|
| 58 |
-
cache_dir = _get_path_from_env(env)
|
| 59 |
-
if cache_dir:
|
| 60 |
-
break
|
| 61 |
-
else:
|
| 62 |
-
cache_dir = Path("~/.cache").expanduser().resolve()
|
| 63 |
-
|
| 64 |
-
# Create a sub-dir to not mess up with existing caches
|
| 65 |
-
cache_dir = cache_dir / "audioseal"
|
| 66 |
-
cache_dir.mkdir(exist_ok=True, parents=True)
|
| 67 |
-
|
| 68 |
-
return cache_dir
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def load_model_checkpoint(
|
| 72 |
-
model_path: Union[Path, str],
|
| 73 |
-
device: Union[str, torch.device] = "cpu",
|
| 74 |
-
):
|
| 75 |
-
if Path(model_path).is_file():
|
| 76 |
-
return torch.load(model_path, map_location=device)
|
| 77 |
-
|
| 78 |
-
cache_dir = _get_cache_dir(
|
| 79 |
-
["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
|
| 80 |
-
)
|
| 81 |
-
parts = urlparse(str(model_path))
|
| 82 |
-
if parts.scheme == "https":
|
| 83 |
-
|
| 84 |
-
hash_ = sha1(parts.path.encode()).hexdigest()[:24]
|
| 85 |
-
return torch.hub.load_state_dict_from_url(
|
| 86 |
-
str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
|
| 87 |
-
)
|
| 88 |
-
elif str(model_path).startswith("facebook/audioseal/"):
|
| 89 |
-
hf_filename = str(model_path)[len("facebook/audioseal/") :]
|
| 90 |
-
|
| 91 |
-
try:
|
| 92 |
-
from huggingface_hub import hf_hub_download
|
| 93 |
-
except ModuleNotFoundError:
|
| 94 |
-
print(
|
| 95 |
-
f"The model path {model_path} seems to be a direct HF path, "
|
| 96 |
-
"but you do not install Huggingface_hub. Install with for example "
|
| 97 |
-
"`pip install huggingface_hub` to use this feature."
|
| 98 |
-
)
|
| 99 |
-
file = hf_hub_download(
|
| 100 |
-
repo_id="facebook/audioseal",
|
| 101 |
-
filename=hf_filename,
|
| 102 |
-
cache_dir=cache_dir,
|
| 103 |
-
library_name="audioseal",
|
| 104 |
-
library_version=audioseal.__version__,
|
| 105 |
-
)
|
| 106 |
-
return torch.load(file, map_location=device)
|
| 107 |
-
else:
|
| 108 |
-
raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def load_local_model_config(model_card: str) -> Optional[DictConfig]:
|
| 112 |
-
config_file = Path(__file__).parent / "cards" / (model_card + ".yaml")
|
| 113 |
-
if Path(config_file).is_file():
|
| 114 |
-
return cast(DictConfig, OmegaConf.load(config_file.resolve()))
|
| 115 |
-
else:
|
| 116 |
-
return None
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
class AudioSeal:
|
| 120 |
-
|
| 121 |
-
@staticmethod
|
| 122 |
-
def parse_model(
|
| 123 |
-
model_card_or_path: str,
|
| 124 |
-
model_type: Type[AudioSealT],
|
| 125 |
-
nbits: Optional[int] = None,
|
| 126 |
-
) -> Tuple[Dict[str, Any], AudioSealT]:
|
| 127 |
-
"""
|
| 128 |
-
Parse the information from the model card or checkpoint path using
|
| 129 |
-
the schema `model_type` that defines the model type
|
| 130 |
-
"""
|
| 131 |
-
# Get the raw checkpoint and config from the local model cards
|
| 132 |
-
config = load_local_model_config(model_card_or_path)
|
| 133 |
-
|
| 134 |
-
if config:
|
| 135 |
-
assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
|
| 136 |
-
config_dict = OmegaConf.to_container(config)
|
| 137 |
-
assert isinstance(
|
| 138 |
-
config_dict, dict
|
| 139 |
-
), f"Cannot parse config from {model_card_or_path}"
|
| 140 |
-
checkpoint = config_dict.pop("checkpoint")
|
| 141 |
-
checkpoint = load_model_checkpoint(checkpoint)
|
| 142 |
-
|
| 143 |
-
# Get the raw checkpoint and config from the checkpoint path
|
| 144 |
-
else:
|
| 145 |
-
config_dict = {}
|
| 146 |
-
checkpoint = load_model_checkpoint(model_card_or_path)
|
| 147 |
-
|
| 148 |
-
if "xp.cfg" in checkpoint:
|
| 149 |
-
config_dict = {**checkpoint["xp.cfg"], **config_dict} # type: ignore
|
| 150 |
-
|
| 151 |
-
model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits) # type: ignore
|
| 152 |
-
|
| 153 |
-
if "model" in checkpoint:
|
| 154 |
-
checkpoint = checkpoint["model"]
|
| 155 |
-
|
| 156 |
-
return checkpoint, model_config
|
| 157 |
-
|
| 158 |
-
@staticmethod
|
| 159 |
-
def parse_config(
|
| 160 |
-
config: Dict[str, Any],
|
| 161 |
-
config_type: Type[AudioSealT],
|
| 162 |
-
nbits: Optional[int] = None,
|
| 163 |
-
) -> AudioSealT:
|
| 164 |
-
|
| 165 |
-
assert "seanet" in config, f"missing seanet backbone config in {config}"
|
| 166 |
-
|
| 167 |
-
# Patch 1: Resolve the variables in the checkpoint
|
| 168 |
-
config = OmegaConf.create(config) # type: ignore
|
| 169 |
-
OmegaConf.resolve(config) # type: ignore
|
| 170 |
-
config = OmegaConf.to_container(config) # type: ignore
|
| 171 |
-
|
| 172 |
-
# Patch 2: Put decoder, encoder and detector outside seanet
|
| 173 |
-
seanet_config = config["seanet"]
|
| 174 |
-
for key_to_patch in ["encoder", "decoder", "detector"]:
|
| 175 |
-
if key_to_patch in seanet_config:
|
| 176 |
-
config_to_patch = config.get(key_to_patch) or {}
|
| 177 |
-
config[key_to_patch] = {
|
| 178 |
-
**config_to_patch,
|
| 179 |
-
**seanet_config.pop(key_to_patch),
|
| 180 |
-
}
|
| 181 |
-
|
| 182 |
-
config["seanet"] = seanet_config
|
| 183 |
-
|
| 184 |
-
# Patch 3: Put nbits into config if specified
|
| 185 |
-
if nbits and "nbits" not in config:
|
| 186 |
-
config["nbits"] = nbits
|
| 187 |
-
|
| 188 |
-
# remove attributes not related to the model_type
|
| 189 |
-
result_config = {}
|
| 190 |
-
assert config, f"Empty config"
|
| 191 |
-
for field in fields(config_type):
|
| 192 |
-
if field.name in config:
|
| 193 |
-
result_config[field.name] = config[field.name]
|
| 194 |
-
|
| 195 |
-
schema = OmegaConf.structured(config_type)
|
| 196 |
-
schema.merge_with(result_config)
|
| 197 |
-
return schema
|
| 198 |
-
|
| 199 |
-
@staticmethod
|
| 200 |
-
def load_generator(
|
| 201 |
-
model_card_or_path: str,
|
| 202 |
-
nbits: Optional[int] = None,
|
| 203 |
-
) -> AudioSealWM:
|
| 204 |
-
"""Load the AudioSeal generator from the model card"""
|
| 205 |
-
checkpoint, config = AudioSeal.parse_model(
|
| 206 |
-
model_card_or_path,
|
| 207 |
-
AudioSealWMConfig,
|
| 208 |
-
nbits=nbits,
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
model = create_generator(config)
|
| 212 |
-
model.load_state_dict(checkpoint)
|
| 213 |
-
return model
|
| 214 |
-
|
| 215 |
-
@staticmethod
|
| 216 |
-
def load_detector(
|
| 217 |
-
model_card_or_path: str,
|
| 218 |
-
nbits: Optional[int] = None,
|
| 219 |
-
) -> AudioSealDetector:
|
| 220 |
-
checkpoint, config = AudioSeal.parse_model(
|
| 221 |
-
model_card_or_path,
|
| 222 |
-
AudioSealDetectorConfig,
|
| 223 |
-
nbits=nbits,
|
| 224 |
-
)
|
| 225 |
-
model = create_detector(config)
|
| 226 |
-
model.load_state_dict(checkpoint)
|
| 227 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/models.py
DELETED
|
@@ -1,175 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import logging
|
| 8 |
-
from typing import Optional, Tuple
|
| 9 |
-
import librosa
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger("Audioseal")
|
| 16 |
-
|
| 17 |
-
COMPATIBLE_WARNING = """
|
| 18 |
-
AudioSeal is designed to work at a sample rate 16khz.
|
| 19 |
-
Implicit sampling rate usage is deprecated and will be removed in future version.
|
| 20 |
-
To remove this warning please add this argument to the function call:
|
| 21 |
-
sample_rate = your_sample_rate
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
class MsgProcessor(torch.nn.Module):
|
| 25 |
-
def __init__(self, nbits: int, hidden_size: int):
|
| 26 |
-
super().__init__()
|
| 27 |
-
assert nbits > 0, "MsgProcessor should not be built in 0bit watermarking"
|
| 28 |
-
self.nbits = nbits
|
| 29 |
-
self.hidden_size = hidden_size
|
| 30 |
-
self.msg_processor = torch.nn.Embedding(2 * nbits, hidden_size)
|
| 31 |
-
|
| 32 |
-
def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
|
| 33 |
-
indices = 2 * torch.arange(msg.shape[-1]).to(msg.device)
|
| 34 |
-
indices = indices.repeat(msg.shape[0], 1)
|
| 35 |
-
indices = (indices + msg).long()
|
| 36 |
-
msg_aux = self.msg_processor(indices)
|
| 37 |
-
msg_aux = msg_aux.sum(dim=-2)
|
| 38 |
-
msg_aux = msg_aux.unsqueeze(-1).repeat(1, 1, hidden.shape[2])
|
| 39 |
-
hidden = hidden + msg_aux
|
| 40 |
-
return hidden
|
| 41 |
-
|
| 42 |
-
def compute_stft_energy(audio: torch.Tensor, sr: int, n_fft: int = 2048, hop_length: int = 512) -> torch.Tensor:
|
| 43 |
-
batch_size = audio.size(0)
|
| 44 |
-
energy_values = []
|
| 45 |
-
|
| 46 |
-
for i in range(batch_size):
|
| 47 |
-
y = audio[i].cpu().numpy()
|
| 48 |
-
stft = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
|
| 49 |
-
frame_energy = torch.tensor(np.sum(stft ** 2, axis=0), device=audio.device)
|
| 50 |
-
energy_values.append(frame_energy)
|
| 51 |
-
|
| 52 |
-
energy_values = torch.stack(energy_values, dim=0)
|
| 53 |
-
return energy_values
|
| 54 |
-
|
| 55 |
-
def compute_adaptive_alpha_librosa(energy_values: torch.Tensor, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
|
| 56 |
-
normalized_energy = (energy_values - energy_values.min(dim=1, keepdim=True)[0]) / (
|
| 57 |
-
energy_values.max(dim=1, keepdim=True)[0] - energy_values.min(dim=1, keepdim=True)[0] + 1e-6
|
| 58 |
-
)
|
| 59 |
-
alpha_values = min_alpha + normalized_energy * (max_alpha - min_alpha)
|
| 60 |
-
return alpha_values
|
| 61 |
-
|
| 62 |
-
class AudioSealWM(torch.nn.Module):
|
| 63 |
-
def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module, msg_processor: Optional[torch.nn.Module] = None):
|
| 64 |
-
super().__init__()
|
| 65 |
-
self.encoder = encoder
|
| 66 |
-
self.decoder = decoder
|
| 67 |
-
self.msg_processor = msg_processor
|
| 68 |
-
self._message: Optional[torch.Tensor] = None
|
| 69 |
-
self._original_payload: Optional[torch.Tensor] = None
|
| 70 |
-
|
| 71 |
-
@property
|
| 72 |
-
def message(self) -> Optional[torch.Tensor]:
|
| 73 |
-
return self._message
|
| 74 |
-
|
| 75 |
-
@message.setter
|
| 76 |
-
def message(self, message: torch.Tensor) -> None:
|
| 77 |
-
self._message = message
|
| 78 |
-
|
| 79 |
-
def get_original_payload(self) -> Optional[torch.Tensor]:
|
| 80 |
-
return self._original_payload
|
| 81 |
-
|
| 82 |
-
def get_watermark(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 83 |
-
# Call the forward method manually here
|
| 84 |
-
return self.forward(x, sample_rate, message)
|
| 85 |
-
|
| 86 |
-
def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None, message: Optional[torch.Tensor] = None,
|
| 87 |
-
n_fft: int = 2048, hop_length: int = 512, min_alpha: float = 0.5, max_alpha: float = 1.5) -> torch.Tensor:
|
| 88 |
-
print("Forward method called!") # This should always print if forward is being executed
|
| 89 |
-
if sample_rate is None:
|
| 90 |
-
logger.warning(COMPATIBLE_WARNING)
|
| 91 |
-
sample_rate = 16_000
|
| 92 |
-
|
| 93 |
-
if sample_rate != 16000:
|
| 94 |
-
x_np = x.detach().cpu().numpy() # Ensure detached tensor is converted to NumPy array
|
| 95 |
-
resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
|
| 96 |
-
x = torch.tensor(resampled_x, device=x.device)
|
| 97 |
-
|
| 98 |
-
hidden = self.encoder(x)
|
| 99 |
-
|
| 100 |
-
if self.msg_processor is not None:
|
| 101 |
-
if message is None:
|
| 102 |
-
if self.message is None:
|
| 103 |
-
message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
|
| 104 |
-
else:
|
| 105 |
-
message = self.message.to(device=x.device)
|
| 106 |
-
else:
|
| 107 |
-
message = message.to(device=x.device)
|
| 108 |
-
|
| 109 |
-
hidden = self.msg_processor(hidden, message)
|
| 110 |
-
self._original_payload = message
|
| 111 |
-
|
| 112 |
-
watermark = self.decoder(hidden)
|
| 113 |
-
|
| 114 |
-
if sample_rate != 16000:
|
| 115 |
-
watermark_np = watermark.detach().cpu().numpy()
|
| 116 |
-
resampled_watermark = librosa.resample(watermark_np, orig_sr=16000, target_sr=sample_rate)
|
| 117 |
-
watermark = torch.tensor(resampled_watermark, device=watermark.device)
|
| 118 |
-
|
| 119 |
-
energy_values = compute_stft_energy(x, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
|
| 120 |
-
adaptive_alpha = compute_adaptive_alpha_librosa(energy_values, min_alpha=min_alpha, max_alpha=max_alpha)
|
| 121 |
-
|
| 122 |
-
# Adjust stretched_alpha to match the dimensions of watermark
|
| 123 |
-
num_frames = adaptive_alpha.size(1)
|
| 124 |
-
stretched_alpha = torch.repeat_interleave(adaptive_alpha, hop_length, dim=1)
|
| 125 |
-
stretched_alpha = stretched_alpha[:, :x.size(1)]
|
| 126 |
-
|
| 127 |
-
# Make sure dimensions align
|
| 128 |
-
if stretched_alpha.dim() < watermark.dim():
|
| 129 |
-
stretched_alpha = stretched_alpha.unsqueeze(-1) # Add extra dimension
|
| 130 |
-
|
| 131 |
-
stretched_alpha = stretched_alpha.expand_as(watermark) # Match dimensions
|
| 132 |
-
print(f"stretched_alpha shape: {stretched_alpha.shape} for debugging")
|
| 133 |
-
|
| 134 |
-
watermarked_audio = x + stretched_alpha * watermark
|
| 135 |
-
|
| 136 |
-
return watermarked_audio
|
| 137 |
-
|
| 138 |
-
class AudioSealDetector(torch.nn.Module):
|
| 139 |
-
def __init__(self, *args, nbits: int = 0, **kwargs):
|
| 140 |
-
super().__init__()
|
| 141 |
-
encoder = SEANetEncoderKeepDimension(*args, **kwargs)
|
| 142 |
-
last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
|
| 143 |
-
self.detector = torch.nn.Sequential(encoder, last_layer)
|
| 144 |
-
self.nbits = nbits
|
| 145 |
-
|
| 146 |
-
def detect_watermark(self, x: torch.Tensor, sample_rate: Optional[int] = None, message_threshold: float = 0.5) -> Tuple[float, torch.Tensor]:
|
| 147 |
-
result, message = self.forward(x, sample_rate=sample_rate)
|
| 148 |
-
print("Forward method in detector called!")
|
| 149 |
-
detected = (torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1])
|
| 150 |
-
detect_prob = detected.cpu().item()
|
| 151 |
-
message = torch.gt(message, message_threshold).int()
|
| 152 |
-
return detect_prob, message
|
| 153 |
-
|
| 154 |
-
def decode_message(self, result: torch.Tensor) -> torch.Tensor:
|
| 155 |
-
assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
|
| 156 |
-
result.dim() == 2 and result.shape[0] == self.nbits
|
| 157 |
-
), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
|
| 158 |
-
decoded_message = result.mean(dim=-1)
|
| 159 |
-
return torch.sigmoid(decoded_message)
|
| 160 |
-
|
| 161 |
-
def forward(self, x: torch.Tensor, sample_rate: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 162 |
-
if sample_rate is None:
|
| 163 |
-
logger.warning(COMPATIBLE_WARNING)
|
| 164 |
-
sample_rate = 16_000
|
| 165 |
-
|
| 166 |
-
if sample_rate != 16000:
|
| 167 |
-
x_np = x.detach().cpu().numpy()
|
| 168 |
-
resampled_x = librosa.resample(x_np, orig_sr=sample_rate, target_sr=16000)
|
| 169 |
-
x = torch.tensor(resampled_x, device=x.device)
|
| 170 |
-
|
| 171 |
-
result = self.detector(x)
|
| 172 |
-
result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
|
| 173 |
-
message = self.decode_message(result[:, 2:, :])
|
| 174 |
-
return result[:, :2, :], message
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/audioseal/py.typed
DELETED
|
File without changes
|
src/scripts/checkpoints.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def convert(checkpoint: str, outdir: str, suffix: str = "base"):
|
| 14 |
-
"""Convert the checkpoint to generator and detector"""
|
| 15 |
-
outdir_path = Path(outdir)
|
| 16 |
-
ckpt = torch.load(checkpoint)
|
| 17 |
-
|
| 18 |
-
# keep inference-related params only
|
| 19 |
-
infer_cfg = {
|
| 20 |
-
"seanet": ckpt["xp.cfg"]["seanet"],
|
| 21 |
-
"channels": ckpt["xp.cfg"]["channels"],
|
| 22 |
-
"dtype": ckpt["xp.cfg"]["dtype"],
|
| 23 |
-
"sample_rate": ckpt["xp.cfg"]["sample_rate"],
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
|
| 27 |
-
detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}
|
| 28 |
-
|
| 29 |
-
for layer in ckpt["model"].keys():
|
| 30 |
-
if layer.startswith("detector"):
|
| 31 |
-
new_layer = layer[9:]
|
| 32 |
-
detector_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore
|
| 33 |
-
elif layer == "msg_processor.msg_processor.0.weight":
|
| 34 |
-
generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt[ # type: ignore
|
| 35 |
-
"model"
|
| 36 |
-
][
|
| 37 |
-
layer
|
| 38 |
-
]
|
| 39 |
-
else:
|
| 40 |
-
assert layer.startswith("generator"), f"Invalid layer: {layer}"
|
| 41 |
-
new_layer = layer[10:]
|
| 42 |
-
generator_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore
|
| 43 |
-
|
| 44 |
-
torch.save(generator_ckpt, outdir_path / (f"checkpoint_generator_{suffix}.pth"))
|
| 45 |
-
torch.save(detector_ckpt, outdir_path / (f"checkpoint_detector_{suffix}.pth"))
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
if __name__ == "__main__":
|
| 49 |
-
import fire
|
| 50 |
-
|
| 51 |
-
fire.Fire(convert)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|