PreciousMposa committed (verified) · commit 519d358 · 1 parent: 16f476f

Upload 107 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the remaining files.
.dockerignore ADDED
@@ -0,0 +1,27 @@
+ **/__pycache__
+ **/.venv
+ **/.classpath
+ **/.dockerignore
+ **/.env
+ **/.git
+ **/.gitignore
+ **/.project
+ **/.settings
+ **/.toolstarget
+ **/.vs
+ **/.vscode
+ **/*.*proj.user
+ **/*.dbmdl
+ **/*.jfm
+ **/bin
+ **/charts
+ **/docker-compose*
+ **/compose*
+ **/Dockerfile*
+ **/node_modules
+ **/npm-debug.log
+ **/obj
+ **/secrets.dev.yaml
+ **/values.dev.yaml
+ LICENSE
+ README.md
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demucs.png filter=lfs diff=lfs merge=lfs -text
+ test.mp3 filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug.md ADDED
@@ -0,0 +1,33 @@
+ ---
+ name: 🐛 Bug Report
+ about: Submit a bug report to help us improve
+ labels: 'bug'
+ ---
+
+ ## 🐛 Bug Report
+
+ (A clear and concise description of what the bug is)
+
+ ## To Reproduce
+
+ (Write your steps here:)
+
+ 1. Step 1...
+ 1. Step 2...
+ 1. Step 3...
+
+ ## Expected behavior
+
+ (Write what you thought would happen.)
+
+ ## Actual Behavior
+
+ (Write what happened. Add screenshots, if applicable.)
+
+ ## Your Environment
+
+ <!-- Include as many relevant details about the environment you experienced the bug in -->
+
+ - Python and PyTorch version:
+ - Operating system and version (desktop or mobile):
+ - Hardware (gpu or cpu, amount of RAM etc.):
.github/ISSUE_TEMPLATE/question.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ name: "❓Questions/Help/Support"
+ about: If you have a question about the paper, code or algorithm, please ask here!
+ labels: question
+
+ ---
+
+ ## ❓ Questions
+
+ (Please ask your question here.)
.github/workflows/linter.yml ADDED
@@ -0,0 +1,36 @@
+ name: linter
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+     branches: [ main ]
+   workflow_dispatch:
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+     steps:
+     - uses: actions/checkout@v2
+     - uses: actions/setup-python@v2
+       with:
+         python-version: 3.8
+
+     - uses: actions/cache@v2
+       with:
+         path: env
+         key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+     - name: Install dependencies
+       run: |
+         python3 -m venv env
+         . env/bin/activate
+         python -m pip install --upgrade pip
+         pip install -r requirements.txt
+         pip install '.[dev]'
+
+
+     - name: Run linter
+       run: |
+         . env/bin/activate
+         make linter
.github/workflows/tests.yml ADDED
@@ -0,0 +1,36 @@
+ name: tests
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+     branches: [ main ]
+   workflow_dispatch:
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+     steps:
+     - uses: actions/checkout@v2
+     - uses: actions/setup-python@v2
+       with:
+         python-version: 3.8
+
+     - uses: actions/cache@v2
+       with:
+         path: env
+         key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+     - name: Install dependencies
+       run: |
+         sudo apt-get update
+         sudo apt-get install -y ffmpeg
+         python3 -m venv env
+         . env/bin/activate
+         python -m pip install --upgrade pip
+         pip install -r requirements.txt
+
+     - name: Run separation test
+       run: |
+         . env/bin/activate
+         make test_eval
.gitignore ADDED
@@ -0,0 +1,17 @@
+ *.egg-info
+ __pycache__
+ Session.vim
+ /build
+ /dist
+ /lab
+ /metadata
+ /notebooks
+ /outputs
+ /release
+ /release_models
+ /separated
+ /tests
+ /trash
+ /misc
+ /mdx
+ .mypy_cache
.vscode/launch.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "configurations": [
+     {
+       "name": "Containers: Python - Fastapi",
+       "type": "docker",
+       "request": "launch",
+       "preLaunchTask": "docker-run: debug",
+       "python": {
+         "pathMappings": [
+           {
+             "localRoot": "${workspaceFolder}",
+             "remoteRoot": "/app"
+           }
+         ],
+         "projectType": "fastapi"
+       }
+     }
+   ]
+ }
.vscode/tasks.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "version": "2.0.0",
+   "tasks": [
+     {
+       "type": "docker-build",
+       "label": "docker-build",
+       "platform": "python",
+       "dockerBuild": {
+         "tag": "demucs:latest",
+         "dockerfile": "${workspaceFolder}/Dockerfile",
+         "context": "${workspaceFolder}",
+         "pull": true
+       }
+     },
+     {
+       "type": "docker-run",
+       "label": "docker-run: debug",
+       "dependsOn": [
+         "docker-build"
+       ],
+       "python": {
+         "args": [
+           "predict:app",
+           "--host",
+           "0.0.0.0",
+           "--port",
+           "8000"
+         ],
+         "module": "uvicorn"
+       }
+     }
+   ]
+ }
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,76 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <[email protected]>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,23 @@
+ # Contributing to Demucs
+
+ ## Pull Requests
+
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ Demucs is the implementation of a research paper.
+ Therefore, we do not plan on accepting many pull requests for new features.
+ We certainly welcome them for bug fixes.
+
+
+ ## Issues
+
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+
+ ## License
+ By contributing to this repository, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
Demucs.ipynb ADDED
@@ -0,0 +1,153 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "Be9yoh-ILfRr"
8
+ },
9
+ "source": [
10
+ "# Hybrid Demucs\n",
11
+ "\n",
12
+ "Feel free to use the Colab version:\n",
13
+ "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {
20
+ "colab": {
21
+ "base_uri": "https://localhost:8080/",
22
+ "height": 139
23
+ },
24
+ "colab_type": "code",
25
+ "executionInfo": {
26
+ "elapsed": 12277,
27
+ "status": "ok",
28
+ "timestamp": 1583778134659,
29
+ "user": {
30
+ "displayName": "Marllus Lustosa",
31
+ "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64",
32
+ "userId": "14811735256675200480"
33
+ },
34
+ "user_tz": 180
35
+ },
36
+ "id": "kOjIPLlzhPfn",
37
+ "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3"
38
+ },
39
+ "outputs": [],
40
+ "source": [
41
+ "!pip install -U demucs\n",
42
+ "# or for local development, if you have a clone of Demucs\n",
43
+ "# pip install -e ."
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "colab": {},
51
+ "colab_type": "code",
52
+ "id": "5lYOzKKCKAbJ"
53
+ },
54
+ "outputs": [],
55
+ "source": [
56
+ "# You can use the `demucs` command line to separate tracks\n",
57
+ "!demucs test.mp3"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "# You can also load directly the pretrained models,\n",
67
+ "# for instance for the MDX 2021 winning model of Track A:\n",
68
+ "from demucs import pretrained\n",
69
+ "model = pretrained.get_model('mdx')"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# Because `model` is a bag of 4 models, you cannot directly call it on your data,\n",
79
+ "# but the `apply_model` will know what to do of it.\n",
80
+ "import torch\n",
81
+ "from demucs.apply import apply_model\n",
82
+ "x = torch.randn(1, 2, 44100 * 10) # ten seconds of white noise for the demo\n",
83
+ "out = apply_model(model, x)[0] # shape is [S, C, T] with S the number of sources\n",
84
+ "\n",
85
+ "# So let see, where is all the white noise content is going ?\n",
86
+ "for name, source in zip(model.sources, out):\n",
87
+ " print(name, source.std() / x.std())\n",
88
+ "# The outputs are quite weird to be fair, not what I would have expected."
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# now let's take a single model from the bag, and let's test it on a pure cosine\n",
98
+ "freq = 440 # in Hz\n",
99
+ "sr = model.samplerate\n",
100
+ "t = torch.arange(10 * sr).float() / sr\n",
101
+ "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n",
102
+ "sub_model = model.models[3]\n",
103
+ "out = sub_model(x)[0]\n",
104
+ "\n",
105
+ "# Same question where does it go?\n",
106
+ "for name, source in zip(model.sources, out):\n",
107
+ " print(name, source.std() / x.std())\n",
108
+ " \n",
109
+ "# Well now it makes much more sense, all the energy is going\n",
110
+ "# in the `other` source.\n",
111
+ "# Feel free to try lower pitch (try 80 Hz) to see what happens !"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "# For training or more fun, refer to the Demucs README on our repo\n",
121
+ "# https://github.com/facebookresearch/demucs/tree/main/demucs"
122
+ ]
123
+ }
124
+ ],
125
+ "metadata": {
126
+ "accelerator": "GPU",
127
+ "colab": {
128
+ "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx",
129
+ "collapsed_sections": [],
130
+ "name": "Demucs.ipynb",
131
+ "provenance": []
132
+ },
133
+ "kernelspec": {
134
+ "display_name": "Python 3",
135
+ "language": "python",
136
+ "name": "python3"
137
+ },
138
+ "language_info": {
139
+ "codemirror_mode": {
140
+ "name": "ipython",
141
+ "version": 3
142
+ },
143
+ "file_extension": ".py",
144
+ "mimetype": "text/x-python",
145
+ "name": "python",
146
+ "nbconvert_exporter": "python",
147
+ "pygments_lexer": "ipython3",
148
+ "version": "3.8.8"
149
+ }
150
+ },
151
+ "nbformat": 4,
152
+ "nbformat_minor": 1
153
+ }
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ # Use Python 3.9 slim base
+ FROM python:3.9-slim
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y ffmpeg git && apt-get clean
+
+ # Set work directory
+ WORKDIR /app
+
+ # Install Python packages
+ RUN pip install --upgrade pip
+ RUN pip install torch torchaudio
+ RUN pip install fastapi uvicorn
+ RUN pip install git+https://github.com/facebookresearch/demucs
+
+ # Copy your inference script into the container
+ COPY predict.py .
+
+ # Expose port for FastAPI
+ EXPOSE 8000
+
+ # Run the FastAPI app
+ CMD ["uvicorn", "predict:app", "--host", "0.0.0.0", "--port", "8000"]
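The `predict.py` copied by this Dockerfile is part of the same 107-file upload but falls outside this truncated 50-file view. Purely as a hedged sketch, a minimal FastAPI wrapper consistent with the `uvicorn predict:app` command above could look like the following; the `/separate` route name, the temporary paths, and the choice to call `demucs.separate.main` are illustrative assumptions, not the committed implementation.

```python
# Hypothetical sketch of predict.py; endpoint name, paths and model choice are assumptions.
import shutil
import tempfile
from pathlib import Path

from fastapi import FastAPI, UploadFile
from fastapi.responses import FileResponse

import demucs.separate

app = FastAPI()


@app.post("/separate")
async def separate(file: UploadFile) -> FileResponse:
    # Write the upload to disk so the Demucs CLI entry point can read it.
    workdir = Path(tempfile.mkdtemp())
    input_path = workdir / (file.filename or "input.mp3")
    with input_path.open("wb") as f:
        shutil.copyfileobj(file.file, f)

    # Same call style as in the README: pass parsed command-line arguments to main().
    demucs.separate.main(["-n", "htdemucs", "-o", str(workdir / "out"), str(input_path)])

    # Demucs writes stems under OUTPUT/MODEL_NAME/TRACK_NAME; return one of them.
    stem = workdir / "out" / "htdemucs" / input_path.stem / "vocals.wav"
    return FileResponse(stem, media_type="audio/wav")
```

Locally, such a script would be served with the same command as the image's `CMD`: `uvicorn predict:app --host 0.0.0.0 --port 8000`.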
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,13 @@
+ recursive-exclude env *
+ recursive-include conf *.yaml
+ include Makefile
+ include LICENSE
+ include demucs.png
+ include outputs.tar.gz
+ include test.mp3
+ include requirements.txt
+ include requirements_minimal.txt
+ include mypy.ini
+ include demucs/py.typed
+ include demucs/remote/*.txt
+ include demucs/remote/*.yaml
Makefile ADDED
@@ -0,0 +1,36 @@
+ all: linter tests
+
+ linter:
+ 	flake8 demucs
+ 	mypy demucs
+
+ tests: test_train test_eval
+
+ test_train: tests/musdb
+ 	_DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \
+ 	  dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \
+ 	  demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \
+ 	  test.shifts=0
+
+ test_eval:
+ 	python3 -m demucs -n demucs_unittest test.mp3
+ 	python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3
+ 	python3 -m demucs -n demucs_unittest --mp3 test.mp3
+ 	python3 -m demucs -n demucs_unittest --flac --int24 test.mp3
+ 	python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3
+ 	python3 -m demucs -n demucs_unittest --segment 8 test.mp3
+ 	python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3
+ 	python3 -m demucs --list-models
+
+ tests/musdb:
+ 	test -e tests || mkdir tests
+ 	python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)'
+ 	musdbconvert tests/tmp tests/musdb
+
+ dist:
+ 	python3 setup.py sdist
+
+ clean:
+ 	rm -r dist build *.egg-info
+
+ .PHONY: linter dist test_train test_eval
README.md CHANGED
@@ -1,14 +1,319 @@
- ---
- title: Audio
- emoji: 📈
- colorFrom: pink
- colorTo: blue
- sdk: gradio
- sdk_version: 5.35.0
- app_file: app.py
- pinned: false
- license: unknown
- short_description: audio processor
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Demucs Music Source Separation
+
+ [![Support Ukraine](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://opensource.fb.com/support-ukraine)
+ ![tests badge](https://github.com/facebookresearch/demucs/workflows/tests/badge.svg)
+ ![linter badge](https://github.com/facebookresearch/demucs/workflows/linter/badge.svg)
+
+
+ **Important:** As I am no longer working at Meta, **this repository is not maintained anymore**.
+ I've created a fork at [github.com/adefossez/demucs](https://github.com/adefossez/demucs); only important bug fixes will be processed on the new repo.
+ Please do not open issues for feature requests or if Demucs doesn't work perfectly for your use case :)
+
+ This is the 4th release of Demucs (v4), featuring Hybrid Transformer-based source separation.
+ **For the classic Hybrid Demucs (v3):** [Go to this commit][demucs_v3].
+ If you are experiencing issues and want the old Demucs back, please file an issue, and then you can get back to Demucs v3 with
+ `git checkout v3`. You can also go back to [Demucs v2][demucs_v2].
+
+
+ Demucs is a state-of-the-art music source separation model, currently capable of separating
+ drums, bass, and vocals from the rest of the accompaniment.
+ Demucs is based on a U-Net convolutional architecture inspired by [Wave-U-Net][waveunet].
+ The v4 version features [Hybrid Transformer Demucs][htdemucs], a hybrid spectrogram/waveform separation model using Transformers.
+ It is based on [Hybrid Demucs][hybrid_paper] (also provided in this repo), with the innermost layers
+ replaced by a cross-domain Transformer Encoder. This Transformer uses self-attention within each domain,
+ and cross-attention across domains.
+ The model achieves an SDR of 9.00 dB on the MUSDB HQ test set. Moreover, when using sparse attention
+ kernels to extend its receptive field and per-source fine-tuning, it achieves a state-of-the-art 9.20 dB of SDR.
+
+ Samples are available [on our sample page](https://ai.honu.io/papers/htdemucs/index.html).
+ Check out [our paper][htdemucs] for more information.
+ It has been trained on the [MUSDB HQ][musdb] dataset + an extra training dataset of 800 songs.
+ This model separates drums, bass, vocals, and other stems for any song.
+
+
+ As Hybrid Transformer Demucs is brand new, it is not activated by default; you can activate it in the usual
+ commands described hereafter with `-n htdemucs_ft`.
+ The single, non fine-tuned model is provided as `-n htdemucs`, and the retrained baseline
+ as `-n hdemucs_mmi`. The Sparse Hybrid Transformer model described in our paper is not provided, as it
+ requires custom CUDA code that is not ready for release yet.
+ We are also releasing an experimental 6-source model that adds `guitar` and `piano` sources.
+ Quick testing seems to show okay quality for `guitar`, but a lot of bleeding and artifacts for the `piano` source.
+
+
+ <p align="center">
+ <img src="./demucs.png" alt="Schema representing the structure of Hybrid Transformer Demucs,
+ with a dual U-Net structure, one branch for the temporal domain,
+ and one branch for the spectral domain. There is a cross-domain Transformer between the Encoders and Decoders."
+ width="800px"></p>
+
+
+
+ ## Important news if you are already using Demucs
+
+ See the [release notes](./docs/release.md) for more details.
+
+ - 22/02/2023: added support for the [SDX 2023 Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023),
+   see the dedicated [doc page](./docs/sdx23.md)
+ - 07/12/2022: Demucs v4 is now on PyPI. The **htdemucs** model is now used by default. Also releasing
+   a 6-source model (adding `guitar` and `piano`, although the latter doesn't work so well at the moment).
+ - 16/11/2022: Added the new **Hybrid Transformer Demucs v4** models.
+   Adding support for the [torchaudio implementation of HDemucs](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html).
+ - 30/08/2022: added reproducibility and ablation grids, along with an updated version of the paper.
+ - 17/08/2022: Releasing v3.0.5: Set split segment length to reduce memory. Compatible with PyTorch 1.12.
+ - 24/02/2022: Releasing v3.0.4: split into two stems (i.e. karaoke mode).
+   Export as float32 or int24.
+ - 17/12/2021: Releasing v3.0.3: bug fixes (thanks @keunwoochoi), memory drastically
+   reduced on GPU (thanks @famzah) and new multi-core evaluation on CPU (`-j` flag).
+ - 12/11/2021: Releasing **Demucs v3** with hybrid domain separation. Strong improvements
+   on all sources. This is the model that won the Sony MDX challenge.
+ - 11/05/2021: Adding support for MusDB-HQ and arbitrary wav sets, for the MDX challenge. For more information
+   on joining the challenge with Demucs see [the Demucs MDX instructions](docs/mdx.md)
+
+
+ ## Comparison with other models
+
+ We provide hereafter a summary of the different metrics presented in the paper.
+ You can also compare Hybrid Demucs (v3), [KUIELAB-MDX-Net][kuielab], [Spleeter][spleeter], Open-Unmix, Demucs (v1), and Conv-Tasnet on one of my favorite
+ songs on my [soundcloud playlist][soundcloud].
+
+ ### Comparison of accuracy
+
+ `Overall SDR` is the mean of the SDR for each of the 4 sources, `MOS Quality` is a rating from 1 to 5
+ of the naturalness and absence of artifacts given by human listeners (5 = no artifacts), `MOS Contamination`
+ is a rating from 1 to 5 with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper]
+ for more details.
+
+ | Model                        | Domain      | Extra data?       | Overall SDR | MOS Quality | MOS Contamination |
+ |------------------------------|-------------|-------------------|-------------|-------------|-------------------|
+ | [Wave-U-Net][waveunet]       | waveform    | no                | 3.2         | -           | -                 |
+ | [Open-Unmix][openunmix]      | spectrogram | no                | 5.3         | -           | -                 |
+ | [D3Net][d3net]               | spectrogram | no                | 6.0         | -           | -                 |
+ | [Conv-Tasnet][demucs_v2]     | waveform    | no                | 5.7         | -           | -                 |
+ | [Demucs (v2)][demucs_v2]     | waveform    | no                | 6.3         | 2.37        | 2.36              |
+ | [ResUNetDecouple+][decouple] | spectrogram | no                | 6.7         | -           | -                 |
+ | [KUIELAB-MDX-Net][kuielab]   | hybrid      | no                | 7.5         | **2.86**    | 2.55              |
+ | [Band-Split RNN][bandsplit]  | spectrogram | no                | **8.2**     | -           | -                 |
+ | **Hybrid Demucs (v3)**       | hybrid      | no                | 7.7         | **2.83**    | **3.04**          |
+ | [MMDenseLSTM][mmdenselstm]   | spectrogram | 804 songs         | 6.0         | -           | -                 |
+ | [D3Net][d3net]               | spectrogram | 1.5k songs        | 6.7         | -           | -                 |
+ | [Spleeter][spleeter]         | spectrogram | 25k songs         | 5.9         | -           | -                 |
+ | [Band-Split RNN][bandsplit]  | spectrogram | 1.7k (mixes only) | **9.0**     | -           | -                 |
+ | **HT Demucs f.t. (v4)**      | hybrid      | 800 songs         | **9.0**     | -           | -                 |
+
+
+
+ ## Requirements
+
+ You will need at least Python 3.8. See `requirements_minimal.txt` for requirements for separation only,
+ and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model.
+
+ ### For Windows users
+
+ Every time you see `python3`, replace it with `python.exe`. You should always run commands from the
+ Anaconda console.
+
+ ### For musicians
+
+ If you just want to use Demucs to separate tracks, you can install it with
+
+ ```bash
+ python3 -m pip install -U demucs
+ ```
+
+ For bleeding edge versions, you can install directly from this repo using
+ ```bash
+ python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs
+ ```
+
+ Advanced OS support is provided on the following pages; **you must read the page for your OS before posting an issue**:
+ - **If you are using Windows:** [Windows support](docs/windows.md).
+ - **If you are using macOS:** [macOS support](docs/mac.md).
+ - **If you are using Linux:** [Linux support](docs/linux.md).
+
+ ### For machine learning scientists
+
+ If you have anaconda installed, you can run from the root of this repository:
+
+ ```bash
+ conda env update -f environment-cpu.yml  # if you don't have GPUs
+ conda env update -f environment-cuda.yml # if you have GPUs
+ conda activate demucs
+ pip install -e .
+ ```
+
+ This will create a `demucs` environment with all the dependencies installed.
+
+ You will also need to install [soundstretch/soundtouch](https://www.surina.net/soundtouch/soundstretch.html): on macOS you can do `brew install sound-touch`,
+ and on Ubuntu `sudo apt-get install soundstretch`. This is used for the
+ pitch/tempo augmentation.
+
+
+ ### Running in Docker
+
+ Thanks to @xserrat, there is now a Docker image definition ready for using Demucs. This can ensure all libraries are correctly installed without interfering with the host OS. See his repo [Docker Facebook Demucs](https://github.com/xserrat/docker-facebook-demucs) for more information.
+
+
+ ### Running from Colab
+
+ I made a Colab to easily separate tracks with Demucs. Note that
+ transfer speeds with Colab are a bit slow for large media files,
+ but it will allow you to use Demucs without installing anything.
+
+ [Demucs on Google Colab](https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing)
+
+ ### Web Demo
+
+ Integrated into [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/demucs)
+
+ ### Graphical Interface
+
+ @CarlGao4 has released a GUI for Demucs: [CarlGao4/Demucs-Gui](https://github.com/CarlGao4/Demucs-Gui). Downloads for Windows and macOS are available [here](https://github.com/CarlGao4/Demucs-Gui/releases). Use the [FossHub mirror](https://fosshub.com/Demucs-GUI.html) to speed up your download.
+
+ @Anjok07 is providing a self-contained GUI in [UVR (Ultimate Vocal Remover)](https://github.com/facebookresearch/demucs/issues/334) that supports Demucs.
+
+ ### Other providers
+
+ Audiostrip is providing free online separation with Demucs on their website [https://audiostrip.co.uk/](https://audiostrip.co.uk/).
+
+ [MVSep](https://mvsep.com/) also provides free online separation; select `Demucs3 model B` for the best quality.
+
+ [Neutone](https://neutone.space/) provides a realtime Demucs model in their free VST/AU plugin that can be used in your favorite DAW.
+
+
+ ## Separating tracks
+
+ In order to try Demucs, you can just run from any folder (as long as you properly installed it)
+
+ ```bash
+ demucs PATH_TO_AUDIO_FILE_1 [PATH_TO_AUDIO_FILE_2 ...]   # for Demucs
+ # If you used `pip install --user` you might need to replace demucs with python3 -m demucs
+ python3 -m demucs --mp3 --mp3-bitrate BITRATE PATH_TO_AUDIO_FILE_1  # output files saved as MP3
+ #        use --mp3-preset to change encoder preset, 2 for best quality, 7 for fastest
+ # If your filename contains spaces, don't forget to quote it!
+ demucs "my music/my favorite track.mp3"
+ # You can select different models with `-n`; mdx_q is the quantized model, smaller but maybe a bit less accurate.
+ demucs -n mdx_q myfile.mp3
+ # If you only want to separate vocals out of an audio file, use `--two-stems=vocals` (you can also set it to drums or bass)
+ demucs --two-stems=vocals myfile.mp3
+ ```
+
+
+ If you have a GPU but you run out of memory, please use `--segment SEGMENT` to reduce the length of each split. `SEGMENT` should be changed to an integer describing the length of each segment in seconds.
+ A segment length of at least 10 is recommended (the bigger the number, the more memory is required, but quality may increase). Note that the Hybrid Transformer models only support a maximum segment length of 7.8 seconds.
+ Setting the environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` is also helpful. If this still does not help, please add `-d cpu` to the command line. See the section hereafter for more details on the memory requirements for GPU acceleration.
+
+ Separated tracks are stored in the `separated/MODEL_NAME/TRACK_NAME` folder. There you will find four stereo wav files sampled at 44.1 kHz: `drums.wav`, `bass.wav`,
+ `other.wav`, `vocals.wav` (or `.mp3` if you used the `--mp3` option).
+
+ All audio formats supported by `torchaudio` can be processed (i.e. wav, mp3, flac, ogg/vorbis on Linux/macOS, etc.). On Windows, `torchaudio` has limited support, so we rely on `ffmpeg`, which should support pretty much anything.
+ Audio is resampled on the fly if necessary.
+ The output will be a wav file encoded as int16.
+ You can save as float32 wav files with `--float32`, or 24-bit integer wav with `--int24`.
+ You can pass `--mp3` to save as mp3 instead, and set the bitrate (in kbps) with `--mp3-bitrate` (default is 320).
+
+ It can happen that the output would need clipping, in particular due to some separation artifacts.
+ Demucs will automatically rescale each output stem so as to avoid clipping. This can however break
+ the relative volume between stems. If instead you prefer hard clipping, pass `--clip-mode clamp`.
+ You can also try to reduce the volume of the input mixture before feeding it to Demucs.
+
+
+ Other pre-trained models can be selected with the `-n` flag.
+ The list of pre-trained models is:
+ - `htdemucs`: first version of Hybrid Transformer Demucs. Trained on MusDB + 800 songs. Default model.
+ - `htdemucs_ft`: fine-tuned version of `htdemucs`; separation will take 4 times longer
+   but might be a bit better. Same training set as `htdemucs`.
+ - `htdemucs_6s`: 6-source version of `htdemucs`, with `piano` and `guitar` added as sources.
+   Note that the `piano` source is not working great at the moment.
+ - `hdemucs_mmi`: Hybrid Demucs v3, retrained on MusDB + 800 songs.
+ - `mdx`: trained only on MusDB HQ, winning model on track A at the [MDX][mdx] challenge.
+ - `mdx_extra`: trained with extra training data (**including the MusDB test set**), ranked 2nd on track B
+   of the [MDX][mdx] challenge.
+ - `mdx_q`, `mdx_extra_q`: quantized versions of the previous models. Smaller download and storage
+   but quality can be slightly worse.
+ - `SIG`: where `SIG` is a single model from the [model zoo](docs/training.md#model-zoo).
+
+ The `--two-stems=vocals` option allows separating vocals from the rest of the accompaniment (i.e., karaoke mode).
+ `vocals` can be changed to any source in the selected model.
+ This will mix the files after separating the mix fully, so this won't be faster or use less memory.
+
+ The `--shifts=SHIFTS` option performs multiple predictions with random shifts (a.k.a. the *shift trick*) of the input and averages them. This makes prediction `SHIFTS` times
+ slower. Don't use it unless you have a GPU.
+
+ The `--overlap` option controls the amount of overlap between prediction windows. The default is 0.25 (i.e. 25%), which is probably fine.
+ It can probably be reduced to 0.1 to improve speed a bit.
+
+
+ The `-j` flag allows specifying a number of parallel jobs (e.g. `demucs -j 2 myfile.mp3`).
+ This will multiply the RAM used by the same amount, so be careful!
+
+ ### Memory requirements for GPU acceleration
+
+ If you want to use GPU acceleration, you will need at least 3GB of RAM on your GPU for `demucs`. However, about 7GB of RAM will be required if you use the default arguments. Add `--segment SEGMENT` to change the size of each split. If you only have 3GB of memory, set SEGMENT to 8 (though quality may be worse if this argument is too small). Setting the environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` can help users with even less RAM, such as 2GB (I separated a 4-minute track using only 1.5GB), but this will make the separation slower.
+
+ If you do not have enough memory on your GPU, simply add `-d cpu` to the command line to use the CPU. With Demucs, processing time should be roughly equal to 1.5 times the duration of the track.
+
+ ## Calling from another Python program
+
+ The main function provides an `opt` parameter as a simple API. You can just pass the parsed command line as this parameter:
+ ```python
+ # Assume that your command is `demucs --mp3 --two-stems vocals -n mdx_extra "track with space.mp3"`
+ # The following code is the same as the command above:
+ import demucs.separate
+ demucs.separate.main(["--mp3", "--two-stems", "vocals", "-n", "mdx_extra", "track with space.mp3"])
+
+ # Or like this
+ import demucs.separate
+ import shlex
+ demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"'))
+ ```
+
+ To use more complicated APIs, see the [API docs](docs/api.md).
+
+ ## Training Demucs
+
+ If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md).
+
+ ## MDX Challenge reproduction
+
+ In order to reproduce the results from the Track A and Track B submissions, check out the [MDX Hybrid Demucs submission repo][mdx_submission].
+
+
+
+ ## How to cite
+
+ ```
+ @inproceedings{rouard2022hybrid,
+   title={Hybrid Transformers for Music Source Separation},
+   author={Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
+   booktitle={ICASSP 23},
+   year={2023}
+ }
+
+ @inproceedings{defossez2021hybrid,
+   title={Hybrid Spectrogram and Waveform Source Separation},
+   author={D{\'e}fossez, Alexandre},
+   booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation},
+   year={2021}
+ }
+ ```
+
+ ## License
+
+ Demucs is released under the MIT license as found in the [LICENSE](LICENSE) file.
+
+ [hybrid_paper]: https://arxiv.org/abs/2111.03600
+ [waveunet]: https://github.com/f90/Wave-U-Net
+ [musdb]: https://sigsep.github.io/datasets/musdb.html
+ [openunmix]: https://github.com/sigsep/open-unmix-pytorch
+ [mmdenselstm]: https://arxiv.org/abs/1805.02410
+ [demucs_v2]: https://github.com/facebookresearch/demucs/tree/v2
+ [demucs_v3]: https://github.com/facebookresearch/demucs/tree/v3
+ [spleeter]: https://github.com/deezer/spleeter
+ [soundcloud]: https://soundcloud.com/honualx/sets/source-separation-in-the-waveform-domain
+ [d3net]: https://arxiv.org/abs/2010.01733
+ [mdx]: https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021
+ [kuielab]: https://github.com/kuielab/mdx-net-submission
+ [decouple]: https://arxiv.org/abs/2109.05418
+ [mdx_submission]: https://github.com/adefossez/mdx21_demucs
+ [bandsplit]: https://arxiv.org/abs/2209.15174
+ [htdemucs]: https://arxiv.org/abs/2211.08553
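As a hedged illustration of the higher-level interface mentioned in the "Calling from another Python program" section of the README above: the `demucs.api.Separator` class, whose constructor is documented in `demucs/api.py` later in this upload, can be driven roughly as follows. The `separate_audio_file` method name, its `(original, stems)` return shape, and the `samplerate` attribute are assumptions to check against `docs/api.md`.

```python
# Hedged sketch of demucs.api usage; the method names below are assumptions, see docs/api.md.
from demucs.api import Separator, save_audio

# Constructor arguments follow the Separator docstring shipped in demucs/api.py.
separator = Separator(model="htdemucs", shifts=1, overlap=0.25, split=True)

# Assumed call: returns the original waveform and a dict mapping source names to tensors.
origin, stems = separator.separate_audio_file("track with space.mp3")

for name, source in stems.items():
    # save_audio is listed among the module's public functions in the api.py docstring.
    save_audio(source, f"{name}.wav", samplerate=separator.samplerate)
```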
app.py CHANGED
@@ -1,37 +1,37 @@
- import os
- import shutil
- import gradio as gr
- from demucs.separate import main
-
- def separate_stems(audio_file):
-     input_path = "input.mp3"
-     shutil.copy(audio_file, input_path)
-
-     output_dir = "output"
-     if os.path.exists(output_dir):
-         shutil.rmtree(output_dir)
-     os.makedirs(output_dir, exist_ok=True)
-
-     # Run Demucs
-     main(["-n", "htdemucs", "-o", output_dir, input_path])
-
-     # Build list of stems to return
-     base = os.path.splitext(os.path.basename(input_path))[0]
-     stem_path = os.path.join(output_dir, "htdemucs", base)
-     stems = [os.path.join(stem_path, f"{stem}.mp3") for stem in ["vocals", "drums", "bass", "other"]]
-     return stems
-
- demo = gr.Interface(
-     fn=separate_stems,
-     inputs=gr.Audio(type="filepath", label="Upload Song"),
-     outputs=[
-         gr.Audio(label="Vocals"),
-         gr.Audio(label="Drums"),
-         gr.Audio(label="Bass"),
-         gr.Audio(label="Other"),
-     ],
-     title="Demucs v4 Stem Separator",
-     description="Upload a song to separate vocals, drums, bass, and other using Facebook's Demucs model.",
- )
-
- demo.launch()

+ import os
+ import shutil
+ import gradio as gr
+ from demucs.separate import main
+
+ def separate_stems(audio_file):
+     input_path = "input.mp3"
+     shutil.copy(audio_file, input_path)
+
+     output_dir = "output"
+     if os.path.exists(output_dir):
+         shutil.rmtree(output_dir)
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Run Demucs
+     main(["-n", "htdemucs", "-o", output_dir, input_path])
+
+     # Build list of stems to return
+     base = os.path.splitext(os.path.basename(input_path))[0]
+     stem_path = os.path.join(output_dir, "htdemucs", base)
+     stems = [os.path.join(stem_path, f"{stem}.mp3") for stem in ["vocals", "drums", "bass", "other"]]
+     return stems
+
+ demo = gr.Interface(
+     fn=separate_stems,
+     inputs=gr.Audio(type="filepath", label="Upload Song"),
+     outputs=[
+         gr.Audio(label="Vocals"),
+         gr.Audio(label="Drums"),
+         gr.Audio(label="Bass"),
+         gr.Audio(label="Other"),
+     ],
+     title="Demucs v4 Stem Separator",
+     description="Upload a song to separate vocals, drums, bass, and other using Facebook's Demucs model.",
+ )
+
+ demo.launch()
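One detail worth flagging in `separate_stems` above: `demucs.separate.main` is invoked without `--mp3`, and per the README in this same commit the default output is int16 wav, so the generated stems will be `vocals.wav`, `drums.wav`, etc., while the returned paths end in `.mp3`. A hedged one-line adjustment, assuming MP3 output is the intent, would be:

```python
# Request MP3 encoding so the output files match the .mp3 paths built below
# (alternatively, keep the call as-is and build the stem paths with a .wav extension).
main(["--mp3", "-n", "htdemucs", "-o", output_dir, input_path])
```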
conf/config.yaml ADDED
@@ -0,0 +1,304 @@
1
+ defaults:
2
+ - _self_
3
+ - dset: musdb44
4
+ - svd: default
5
+ - variant: default
6
+ - override hydra/hydra_logging: colorlog
7
+ - override hydra/job_logging: colorlog
8
+
9
+ dummy:
10
+ dset:
11
+ musdb: /checkpoint/defossez/datasets/musdbhq
12
+ musdb_samplerate: 44100
13
+ use_musdb: true # set to false to not use musdb as training data.
14
+ wav: # path to custom wav dataset
15
+ wav2: # second custom wav dataset
16
+ segment: 11
17
+ shift: 1
18
+ train_valid: false
19
+ full_cv: true
20
+ samplerate: 44100
21
+ channels: 2
22
+ normalize: true
23
+ metadata: ./metadata
24
+ sources: ['drums', 'bass', 'other', 'vocals']
25
+ valid_samples: # valid dataset size
26
+ backend: null # if provided select torchaudio backend.
27
+
28
+ test:
29
+ save: False
30
+ best: True
31
+ workers: 2
32
+ every: 20
33
+ split: true
34
+ shifts: 1
35
+ overlap: 0.25
36
+ sdr: true
37
+ metric: 'loss' # metric used for best model selection on the valid set, can also be nsdr
38
+ nonhq: # path to non hq MusDB for evaluation
39
+
40
+ epochs: 360
41
+ batch_size: 64
42
+ max_batches: # limit the number of batches per epoch, useful for debugging
43
+ # or if your dataset is gigantic.
44
+ optim:
45
+ lr: 3e-4
46
+ momentum: 0.9
47
+ beta2: 0.999
48
+ loss: l1 # l1 or mse
49
+ optim: adam
50
+ weight_decay: 0
51
+ clip_grad: 0
52
+
53
+ seed: 42
54
+ debug: false
55
+ valid_apply: true
56
+ flag:
57
+ save_every:
58
+ weights: [1., 1., 1., 1.] # weights over each source for the training/valid loss.
59
+
60
+ augment:
61
+ shift_same: false
62
+ repitch:
63
+ proba: 0.2
64
+ max_tempo: 12
65
+ remix:
66
+ proba: 1
67
+ group_size: 4
68
+ scale:
69
+ proba: 1
70
+ min: 0.25
71
+ max: 1.25
72
+ flip: true
73
+
74
+ continue_from: # continue from other XP, give the XP Dora signature.
75
+ continue_pretrained: # signature of a pretrained XP, this cannot be a bag of models.
76
+ pretrained_repo: # repo for pretrained model (default is official AWS)
77
+ continue_best: true
78
+ continue_opt: false
79
+
80
+ misc:
81
+ num_workers: 10
82
+ num_prints: 4
83
+ show: false
84
+ verbose: false
85
+
86
+ # List of decay for EMA at batch or epoch level, e.g. 0.999.
87
+ # Batch level EMA are kept on GPU for speed.
88
+ ema:
89
+ epoch: []
90
+ batch: []
91
+
92
+ use_train_segment: true # to remove
93
+ model_segment: # override the segment parameter for the model, usually 4 times the training segment.
94
+ model: demucs # see demucs/train.py for the possibilities, and config for each model hereafter.
95
+ demucs: # see demucs/demucs.py for a detailed description
96
+ # Channels
97
+ channels: 64
98
+ growth: 2
99
+ # Main structure
100
+ depth: 6
101
+ rewrite: true
102
+ lstm_layers: 0
103
+ # Convolutions
104
+ kernel_size: 8
105
+ stride: 4
106
+ context: 1
107
+ # Activations
108
+ gelu: true
109
+ glu: true
110
+ # Normalization
111
+ norm_groups: 4
112
+ norm_starts: 4
113
+ # DConv residual branch
114
+ dconv_depth: 2
115
+ dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both.
116
+ dconv_comp: 4
117
+ dconv_attn: 4
118
+ dconv_lstm: 4
119
+ dconv_init: 1e-4
120
+ # Pre/post treatment
121
+ resample: true
122
+ normalize: false
123
+ # Weight init
124
+ rescale: 0.1
125
+
126
+ hdemucs: # see demucs/hdemucs.py for a detailed description
127
+ # Channels
128
+ channels: 48
129
+ channels_time:
130
+ growth: 2
131
+ # STFT
132
+ nfft: 4096
133
+ wiener_iters: 0
134
+ end_iters: 0
135
+ wiener_residual: false
136
+ cac: true
137
+ # Main structure
138
+ depth: 6
139
+ rewrite: true
140
+ hybrid: true
141
+ hybrid_old: false
142
+ # Frequency Branch
143
+ multi_freqs: []
144
+ multi_freqs_depth: 3
145
+ freq_emb: 0.2
146
+ emb_scale: 10
147
+ emb_smooth: true
148
+ # Convolutions
149
+ kernel_size: 8
150
+ stride: 4
151
+ time_stride: 2
152
+ context: 1
153
+ context_enc: 0
154
+ # normalization
155
+ norm_starts: 4
156
+ norm_groups: 4
157
+ # DConv residual branch
158
+ dconv_mode: 1
159
+ dconv_depth: 2
160
+ dconv_comp: 4
161
+ dconv_attn: 4
162
+ dconv_lstm: 4
163
+ dconv_init: 1e-3
164
+ # Weight init
165
+ rescale: 0.1
166
+
167
+ # Torchaudio implementation of HDemucs
168
+ torch_hdemucs:
169
+ # Channels
170
+ channels: 48
171
+ growth: 2
172
+ # STFT
173
+ nfft: 4096
174
+ # Main structure
175
+ depth: 6
176
+ freq_emb: 0.2
177
+ emb_scale: 10
178
+ emb_smooth: true
179
+ # Convolutions
180
+ kernel_size: 8
181
+ stride: 4
182
+ time_stride: 2
183
+ context: 1
184
+ context_enc: 0
185
+ # normalization
186
+ norm_starts: 4
187
+ norm_groups: 4
188
+ # DConv residual branch
189
+ dconv_depth: 2
190
+ dconv_comp: 4
191
+ dconv_attn: 4
192
+ dconv_lstm: 4
193
+ dconv_init: 1e-3
194
+
195
+ htdemucs: # see demucs/htdemucs.py for a detailed description
196
+ # Channels
197
+ channels: 48
198
+ channels_time:
199
+ growth: 2
200
+ # STFT
201
+ nfft: 4096
202
+ wiener_iters: 0
203
+ end_iters: 0
204
+ wiener_residual: false
205
+ cac: true
206
+ # Main structure
207
+ depth: 4
208
+ rewrite: true
209
+ # Frequency Branch
210
+ multi_freqs: []
211
+ multi_freqs_depth: 3
212
+ freq_emb: 0.2
213
+ emb_scale: 10
214
+ emb_smooth: true
215
+ # Convolutions
216
+ kernel_size: 8
217
+ stride: 4
218
+ time_stride: 2
219
+ context: 1
220
+ context_enc: 0
221
+ # normalization
222
+ norm_starts: 4
223
+ norm_groups: 4
224
+ # DConv residual branch
225
+ dconv_mode: 1
226
+ dconv_depth: 2
227
+ dconv_comp: 8
228
+ dconv_init: 1e-3
229
+ # Before the Transformer
230
+ bottom_channels: 0
231
+ # CrossTransformer
232
+ # ------ Common to all
233
+ # Regular parameters
234
+ t_layers: 5
235
+ t_hidden_scale: 4.0
236
+ t_heads: 8
237
+ t_dropout: 0.0
238
+ t_layer_scale: True
239
+ t_gelu: True
240
+ # ------------- Positional Embedding
241
+ t_emb: sin
242
+ t_max_positions: 10000 # for the scaled embedding
243
+ t_max_period: 10000.0
244
+ t_weight_pos_embed: 1.0
245
+ t_cape_mean_normalize: True
246
+ t_cape_augment: True
247
+ t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
248
+ t_sin_random_shift: 0
249
+ # ------------- norm before a transformer encoder
250
+ t_norm_in: True
251
+ t_norm_in_group: False
252
+ # ------------- norm inside the encoder
253
+ t_group_norm: False
254
+ t_norm_first: True
255
+ t_norm_out: True
256
+ # ------------- optim
257
+ t_weight_decay: 0.0
258
+ t_lr:
259
+ # ------------- sparsity
260
+ t_sparse_self_attn: False
261
+ t_sparse_cross_attn: False
262
+ t_mask_type: diag
263
+ t_mask_random_seed: 42
264
+ t_sparse_attn_window: 400
265
+ t_global_window: 100
266
+ t_sparsity: 0.95
267
+ t_auto_sparsity: False
268
+ # Cross Encoder First (False)
269
+ t_cross_first: False
270
+ # Weight init
271
+ rescale: 0.1
272
+
273
+ svd: # see svd.py for documentation
274
+ penalty: 0
275
+ min_size: 0.1
276
+ dim: 1
277
+ niters: 2
278
+ powm: false
279
+ proba: 1
280
+ conv_only: false
281
+ convtr: false
282
+ bs: 1
283
+
284
+ quant: # quantization hyper params
285
+ diffq: # diffq penalty, typically 1e-4 or 3e-4
286
+ qat: # use QAT with a fixed number of bits (not as good as diffq)
287
+ min_size: 0.2
288
+ group_size: 8
289
+
290
+ dora:
291
+ dir: outputs
292
+ exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend']
293
+
294
+ slurm:
295
+ time: 4320
296
+ constraint: volta32gb
297
+ setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6']
298
+
299
+ # Hydra config
300
+ hydra:
301
+ job_logging:
302
+ formatters:
303
+ colorlog:
304
+ datefmt: "%m-%d %H:%M:%S"
conf/dset/aetl.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # @package _global_
2
+
3
+ # automix dataset with Musdb, extra training data and the test set of Musdb.
4
+ # This used even more remixes than auto_extra_test.
5
+ dset:
6
+ wav: /checkpoint/defossez/datasets/aetl
7
+ samplerate: 44100
8
+ channels: 2
9
+ epochs: 320
10
+ max_batches: 500
11
+
12
+ augment:
13
+ shift_same: true
14
+ scale:
15
+ proba: 0.
16
+ remix:
17
+ proba: 0
18
+ repitch:
19
+ proba: 0
conf/dset/auto_extra_test.yaml ADDED
@@ -0,0 +1,18 @@
1
+ # @package _global_
2
+
3
+ # automix dataset with Musdb, extra training data and the test set of Musdb.
4
+ dset:
5
+ wav: /checkpoint/defossez/datasets/automix_extra_test2
6
+ samplerate: 44100
7
+ channels: 2
8
+ epochs: 320
9
+ max_batches: 500
10
+
11
+ augment:
12
+ shift_same: true
13
+ scale:
14
+ proba: 0.
15
+ remix:
16
+ proba: 0
17
+ repitch:
18
+ proba: 0
conf/dset/auto_mus.yaml ADDED
@@ -0,0 +1,20 @@
1
+ # @package _global_
2
+
3
+ # Automix dataset based on musdb train set.
4
+ dset:
5
+ wav: /checkpoint/defossez/datasets/automix_musdb
6
+ samplerate: 44100
7
+ channels: 2
8
+ epochs: 360
9
+ max_batches: 300
10
+ test:
11
+ every: 4
12
+
13
+ augment:
14
+ shift_same: true
15
+ scale:
16
+ proba: 0.5
17
+ remix:
18
+ proba: 0
19
+ repitch:
20
+ proba: 0
conf/dset/extra44.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ # Musdb + extra tracks
4
+ dset:
5
+ wav: /checkpoint/defossez/datasets/allstems_44/
6
+ samplerate: 44100
7
+ channels: 2
8
+ epochs: 320
conf/dset/extra_mmi_goodclean.yaml ADDED
@@ -0,0 +1,12 @@
1
+ # @package _global_
2
+
3
+ # Musdb + extra tracks
4
+ dset:
5
+ wav: /checkpoint/defossez/datasets/allstems_44/
6
+ wav2: /checkpoint/defossez/datasets/mmi44_goodclean
7
+ samplerate: 44100
8
+ channels: 2
9
+ wav2_weight: null
10
+ wav2_valid: false
11
+ valid_samples: 100
12
+ epochs: 1200
conf/dset/extra_test.yaml ADDED
@@ -0,0 +1,12 @@
1
+ # @package _global_
2
+
3
+ # Musdb + extra tracks + test set from musdb.
4
+ dset:
5
+ wav: /checkpoint/defossez/datasets/allstems_test_44/
6
+ samplerate: 44100
7
+ channels: 2
8
+ epochs: 320
9
+ max_batches: 700
10
+ test:
11
+ sdr: false
12
+ every: 500
conf/dset/musdb44.yaml ADDED
@@ -0,0 +1,5 @@
1
+ # @package _global_
2
+
3
+ dset:
4
+ samplerate: 44100
5
+ channels: 2
conf/dset/sdx23_bleeding.yaml ADDED
@@ -0,0 +1,10 @@
1
+ # @package _global_
2
+
3
+ # Musdb + extra tracks
4
+ dset:
5
+ wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/
6
+ use_musdb: false
7
+ samplerate: 44100
8
+ channels: 2
9
+ backend: soundfile # must use soundfile as some mixture would clip with sox.
10
+ epochs: 320
conf/dset/sdx23_labelnoise.yaml ADDED
@@ -0,0 +1,10 @@
1
+ # @package _global_
2
+
3
+ # Musdb + extra tracks
4
+ dset:
5
+ wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0
6
+ use_musdb: false
7
+ samplerate: 44100
8
+ channels: 2
9
+ backend: soundfile # must use soundfile as some mixture would clip with sox.
10
+ epochs: 320
conf/svd/base.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # @package _global_
2
+
3
+ svd:
4
+ penalty: 0
5
+ min_size: 1
6
+ dim: 50
7
+ niters: 4
8
+ powm: false
9
+ proba: 1
10
+ conv_only: false
11
+ convtr: false # ideally this should be true, but some models were trained with this to false.
12
+
13
+ optim:
14
+ beta2: 0.9998
conf/svd/base2.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # @package _global_
2
+
3
+ svd:
4
+ penalty: 0
5
+ min_size: 1
6
+ dim: 100
7
+ niters: 4
8
+ powm: false
9
+ proba: 1
10
+ conv_only: false
11
+ convtr: true
12
+
13
+ optim:
14
+ beta2: 0.9998
conf/svd/default.yaml ADDED
@@ -0,0 +1 @@
1
+ # @package _global_
conf/variant/default.yaml ADDED
@@ -0,0 +1 @@
1
+ # @package _global_
conf/variant/example.yaml ADDED
@@ -0,0 +1,5 @@
1
+ # @package _global_
2
+
3
+ model: hdemucs
4
+ hdemucs:
5
+ channels: 32
conf/variant/finetune.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # @package _global_
2
+
3
+ epochs: 4
4
+ batch_size: 16
5
+ optim:
6
+ lr: 0.0006
7
+ test:
8
+ every: 1
9
+ sdr: false
10
+ dset:
11
+ segment: 28
12
+ shift: 2
13
+
14
+ augment:
15
+ scale:
16
+ proba: 0
17
+ shift_same: true
18
+ remix:
19
+ proba: 0
demucs.png ADDED
Git LFS Details

  • SHA256: 7f8a53c1bbaa6c0268d358cd4cb9c2f1128907758aeb10a79789f7bbf61ded95
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB
demucs/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ __version__ = "4.1.0a2"
demucs/__main__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .separate import main
+
+ if __name__ == '__main__':
+     main()
demucs/api.py ADDED
@@ -0,0 +1,392 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """API methods for demucs
8
+
9
+ Classes
10
+ -------
11
+ `demucs.api.Separator`: The base separator class
12
+
13
+ Functions
14
+ ---------
15
+ `demucs.api.save_audio`: Save an audio
16
+ `demucs.api.list_models`: Get models list
17
+
18
+ Examples
19
+ --------
20
+ See the end of this module (if __name__ == "__main__")
21
+ """
22
+
23
+ import subprocess
24
+
25
+ import torch as th
26
+ import torchaudio as ta
27
+
28
+ from dora.log import fatal
29
+ from pathlib import Path
30
+ from typing import Optional, Callable, Dict, Tuple, Union
31
+
32
+ from .apply import apply_model, _replace_dict
33
+ from .audio import AudioFile, convert_audio, save_audio
34
+ from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
35
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo
36
+
37
+
38
+ class LoadAudioError(Exception):
39
+ pass
40
+
41
+
42
+ class LoadModelError(Exception):
43
+ pass
44
+
45
+
46
+ class _NotProvided:
47
+ pass
48
+
49
+
50
+ NotProvided = _NotProvided()
51
+
52
+
53
+ class Separator:
54
+ def __init__(
55
+ self,
56
+ model: str = "htdemucs",
57
+ repo: Optional[Path] = None,
58
+ device: str = "cuda" if th.cuda.is_available() else "cpu",
59
+ shifts: int = 1,
60
+ overlap: float = 0.25,
61
+ split: bool = True,
62
+ segment: Optional[int] = None,
63
+ jobs: int = 0,
64
+ progress: bool = False,
65
+ callback: Optional[Callable[[dict], None]] = None,
66
+ callback_arg: Optional[dict] = None,
67
+ ):
68
+ """
69
+ `class Separator`
70
+ =================
71
+
72
+ Parameters
73
+ ----------
74
+ model: Pretrained model name or signature. Default is htdemucs.
75
+ repo: Folder containing all pre-trained models for use.
76
+ segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
77
+ not specified, will use the command line option.
78
+ shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
79
+ apply the opposite shift to the output. This is repeated `shifts` times and all \
80
+ predictions are averaged. This effectively makes the model time equivariant and \
81
+ improves SDR by up to 0.2 points. If not specified, will use the command line option.
82
+ split: If True, the input will be broken down into small chunks (length set by `segment`) \
83
+ and predictions will be performed individually on each and concatenated. Useful for \
84
+ model with large memory footprint like Tasnet. If not specified, will use the command \
85
+ line option.
86
+ overlap: The overlap between the splits. If not specified, will use the command line \
87
+ option.
88
+ device (torch.device, str, or None): If provided, device on which to execute the \
89
+ computation, otherwise `wav.device` is assumed. When `device` is different from \
90
+ `wav.device`, only local computations will be on `device`, while the entire tracks \
91
+ will be stored on `wav.device`. If not specified, will use the command line option.
92
+ jobs: Number of jobs. This can increase memory usage but will be much faster when \
93
+ multiple cores are available. If not specified, will use the command line option.
94
+ callback: A function that will be called when the separation of a chunk starts or finishes. \
95
+ The argument passed to the function will be a dict. For more information, please see \
96
+ the Callback section.
97
+ callback_arg: A dict containing private parameters to be passed to callback function. For \
98
+ more information, please see the Callback section.
99
+ progress: If true, show a progress bar.
100
+
101
+ Callback
102
+ --------
103
+ The function will be called with only one positional parameter whose type is `dict`. The
104
+ `callback_arg` will be combined with information of current separation progress. The
105
+ progress information will override the values in `callback_arg` if same key has been used.
106
+ To abort the separation, raise `KeyboardInterrupt`.
107
+
108
+ Progress information contains several keys (These keys will always exist):
109
+ - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
110
+ - `shift_idx`: The index of shifts. Starts from 0.
111
+ - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
112
+ mean that it is at the 441000 second of the audio, but the "frame" of the tensor.
113
+ - `state`: Could be `"start"` or `"end"`.
114
+ - `audio_length`: Length of the audio (in "frame" of the tensor).
115
+ - `models`: Count of submodels in the model.
116
+ """
117
+ self._name = model
118
+ self._repo = repo
119
+ self._load_model()
120
+ self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
121
+ segment=segment, jobs=jobs, progress=progress, callback=callback,
122
+ callback_arg=callback_arg)
123
+
124
+ def update_parameter(
125
+ self,
126
+ device: Union[str, _NotProvided] = NotProvided,
127
+ shifts: Union[int, _NotProvided] = NotProvided,
128
+ overlap: Union[float, _NotProvided] = NotProvided,
129
+ split: Union[bool, _NotProvided] = NotProvided,
130
+ segment: Optional[Union[int, _NotProvided]] = NotProvided,
131
+ jobs: Union[int, _NotProvided] = NotProvided,
132
+ progress: Union[bool, _NotProvided] = NotProvided,
133
+ callback: Optional[
134
+ Union[Callable[[dict], None], _NotProvided]
135
+ ] = NotProvided,
136
+ callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
137
+ ):
138
+ """
139
+ Update the parameters of separation.
140
+
141
+ Parameters
142
+ ----------
143
+ segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
144
+ not specified, will use the command line option.
145
+ shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
146
+ apply the oppositve shift to the output. This is repeated `shifts` time and all \
147
+ predictions are averaged. This effectively makes the model time equivariant and \
148
+ improves SDR by up to 0.2 points. If not specified, will use the command line option.
149
+ split: If True, the input will be broken down into small chunks (length set by `segment`) \
150
+ and predictions will be performed individually on each and concatenated. Useful for \
151
+ models with a large memory footprint like Tasnet. If not specified, will use the command \
152
+ line option.
153
+ overlap: The overlap between the splits. If not specified, will use the command line \
154
+ option.
155
+ device (torch.device, str, or None): If provided, device on which to execute the \
156
+ computation, otherwise `wav.device` is assumed. When `device` is different from \
157
+ `wav.device`, only local computations will be on `device`, while the entire tracks \
158
+ will be stored on `wav.device`. If not specified, will use the command line option.
159
+ jobs: Number of jobs. This can increase memory usage but will be much faster when \
160
+ multiple cores are available. If not specified, will use the command line option.
161
+ callback: A function that will be called when the separation of a chunk starts or finishes. \
162
+ The argument passed to the function will be a dict. For more information, please see \
163
+ the Callback section.
164
+ callback_arg: A dict containing private parameters to be passed to callback function. For \
165
+ more information, please see the Callback section.
166
+ progress: If true, show a progress bar.
167
+
168
+ Callback
169
+ --------
170
+ The function will be called with only one positional parameter whose type is `dict`. The
171
+ `callback_arg` will be combined with information of current separation progress. The
172
+ progress information will override the values in `callback_arg` if same key has been used.
173
+ To abort the separation, raise `KeyboardInterrupt`.
174
+
175
+ Progress information contains several keys (These keys will always exist):
176
+ - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
177
+ - `shift_idx`: The index of shifts. Starts from 0.
178
+ - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't
179
+ mean that it is at the 441000th second of the audio, but at frame 441000 of the tensor.
180
+ - `state`: Could be `"start"` or `"end"`.
181
+ - `audio_length`: Length of the audio (in "frame" of the tensor).
182
+ - `models`: Count of submodels in the model.
183
+ """
184
+ if not isinstance(device, _NotProvided):
185
+ self._device = device
186
+ if not isinstance(shifts, _NotProvided):
187
+ self._shifts = shifts
188
+ if not isinstance(overlap, _NotProvided):
189
+ self._overlap = overlap
190
+ if not isinstance(split, _NotProvided):
191
+ self._split = split
192
+ if not isinstance(segment, _NotProvided):
193
+ self._segment = segment
194
+ if not isinstance(jobs, _NotProvided):
195
+ self._jobs = jobs
196
+ if not isinstance(progress, _NotProvided):
197
+ self._progress = progress
198
+ if not isinstance(callback, _NotProvided):
199
+ self._callback = callback
200
+ if not isinstance(callback_arg, _NotProvided):
201
+ self._callback_arg = callback_arg
202
+
203
+ def _load_model(self):
204
+ self._model = get_model(name=self._name, repo=self._repo)
205
+ if self._model is None:
206
+ raise LoadModelError("Failed to load model")
207
+ self._audio_channels = self._model.audio_channels
208
+ self._samplerate = self._model.samplerate
209
+
210
+ def _load_audio(self, track: Path):
211
+ errors = {}
212
+ wav = None
213
+
214
+ try:
215
+ wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
216
+ channels=self._audio_channels)
217
+ except FileNotFoundError:
218
+ errors["ffmpeg"] = "FFmpeg is not installed."
219
+ except subprocess.CalledProcessError:
220
+ errors["ffmpeg"] = "FFmpeg could not read the file."
221
+
222
+ if wav is None:
223
+ try:
224
+ wav, sr = ta.load(str(track))
225
+ except RuntimeError as err:
226
+ errors["torchaudio"] = err.args[0]
227
+ else:
228
+ wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
229
+
230
+ if wav is None:
231
+ raise LoadAudioError(
232
+ "\n".join(
233
+ "When trying to load using {}, got the following error: {}".format(
234
+ backend, error
235
+ )
236
+ for backend, error in errors.items()
237
+ )
238
+ )
239
+ return wav
240
+
241
+ def separate_tensor(
242
+ self, wav: th.Tensor, sr: Optional[int] = None
243
+ ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
244
+ """
245
+ Separate a loaded tensor.
246
+
247
+ Parameters
248
+ ----------
249
+ wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \
250
+ while the second is the waveform of each channel. Type should be float32. \
251
+ e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
252
+ sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \
253
+ model.
254
+
255
+ Returns
256
+ -------
257
+ A tuple, whose first element is the original wave and second element is a dict, whose keys
258
+ are the name of stems and values are separated waves. The original wave will have already
259
+ been resampled.
260
+
261
+ Notes
262
+ -----
263
+ Use this function with caution. This function does not perform any data validation.
264
+ """
265
+ if sr is not None and sr != self.samplerate:
266
+ wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
267
+ ref = wav.mean(0)
268
+ wav -= ref.mean()
269
+ wav /= ref.std() + 1e-8
270
+ out = apply_model(
271
+ self._model,
272
+ wav[None],
273
+ segment=self._segment,
274
+ shifts=self._shifts,
275
+ split=self._split,
276
+ overlap=self._overlap,
277
+ device=self._device,
278
+ num_workers=self._jobs,
279
+ callback=self._callback,
280
+ callback_arg=_replace_dict(
281
+ self._callback_arg, ("audio_length", wav.shape[1])
282
+ ),
283
+ progress=self._progress,
284
+ )
285
+ if out is None:
286
+ raise KeyboardInterrupt
287
+ out *= ref.std() + 1e-8
288
+ out += ref.mean()
289
+ wav *= ref.std() + 1e-8
290
+ wav += ref.mean()
291
+ return (wav, dict(zip(self._model.sources, out[0])))
292
+
293
+ def separate_audio_file(self, file: Path):
294
+ """
295
+ Separate an audio file. The method will automatically read the file.
296
+
297
+ Parameters
298
+ ----------
299
+ file: Path of the file to be separated.
300
+
301
+ Returns
302
+ -------
303
+ A tuple, whose first element is the original wave and second element is a dict, whose keys
304
+ are the name of stems and values are separated waves. The original wave will have already
305
+ been resampled.
306
+ """
307
+ return self.separate_tensor(self._load_audio(file), self.samplerate)
308
+
309
+ @property
310
+ def samplerate(self):
311
+ return self._samplerate
312
+
313
+ @property
314
+ def audio_channels(self):
315
+ return self._audio_channels
316
+
317
+ @property
318
+ def model(self):
319
+ return self._model
320
+
321
+
322
+ def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
323
+ """
324
+ List the available models. Please remember that not all the returned models can be
325
+ successfully loaded.
326
+
327
+ Parameters
328
+ ----------
329
+ repo: The repo whose models are to be listed.
330
+
331
+ Returns
332
+ -------
333
+ A dict with two keys ("single" for single models and "bag" for bags of models). The values
334
+ are dicts mapping model names to their files (local paths or remote URLs).
335
+ """
336
+ model_repo: ModelOnlyRepo
337
+ if repo is None:
338
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
339
+ model_repo = RemoteRepo(models)
340
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
341
+ else:
342
+ if not repo.is_dir():
343
+ fatal(f"{repo} must exist and be a directory.")
344
+ model_repo = LocalRepo(repo)
345
+ bag_repo = BagOnlyRepo(repo, model_repo)
346
+ return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}
347
+
348
+
349
+ if __name__ == "__main__":
350
+ # Test API functions
351
+ # two-stem not supported
352
+
353
+ from .separate import get_parser
354
+
355
+ args = get_parser().parse_args()
356
+ separator = Separator(
357
+ model=args.name,
358
+ repo=args.repo,
359
+ device=args.device,
360
+ shifts=args.shifts,
361
+ overlap=args.overlap,
362
+ split=args.split,
363
+ segment=args.segment,
364
+ jobs=args.jobs,
365
+ callback=print
366
+ )
367
+ out = args.out / args.name
368
+ out.mkdir(parents=True, exist_ok=True)
369
+ for file in args.tracks:
370
+ separated = separator.separate_audio_file(file)[1]
371
+ if args.mp3:
372
+ ext = "mp3"
373
+ elif args.flac:
374
+ ext = "flac"
375
+ else:
376
+ ext = "wav"
377
+ kwargs = {
378
+ "samplerate": separator.samplerate,
379
+ "bitrate": args.mp3_bitrate,
380
+ "clip": args.clip_mode,
381
+ "as_float": args.float32,
382
+ "bits_per_sample": 24 if args.int24 else 16,
383
+ }
384
+ for stem, source in separated.items():
385
+ stem = out / args.filename.format(
386
+ track=Path(file).name.rsplit(".", 1)[0],
387
+ trackext=Path(file).name.rsplit(".", 1)[-1],
388
+ stem=stem,
389
+ ext=ext,
390
+ )
391
+ stem.parent.mkdir(parents=True, exist_ok=True)
392
+ save_audio(source, str(stem), **kwargs)
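
A short usage sketch of the API defined above. It only relies on names exported by demucs/api.py (`Separator`, `save_audio`); the input is the test.mp3 shipped with this upload and the model name matches the class default:

import torch as th
from demucs.api import Separator, save_audio

separator = Separator(model="htdemucs",
                      device="cuda" if th.cuda.is_available() else "cpu",
                      progress=True)
origin, separated = separator.separate_audio_file("test.mp3")
for stem, source in separated.items():
    # each `source` is a (channels, samples) float32 tensor at separator.samplerate
    save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)
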
demucs/apply.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Code to apply a model to a mix. It will handle chunking with overlaps and
8
+ interpolation between chunks, as well as the "shift trick".
9
+ """
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import copy
12
+ import random
13
+ from threading import Lock
14
+ import typing as tp
15
+
16
+ import torch as th
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ import tqdm
20
+
21
+ from .demucs import Demucs
22
+ from .hdemucs import HDemucs
23
+ from .htdemucs import HTDemucs
24
+ from .utils import center_trim, DummyPoolExecutor
25
+
26
+ Model = tp.Union[Demucs, HDemucs, HTDemucs]
27
+
28
+
29
+ class BagOfModels(nn.Module):
30
+ def __init__(self, models: tp.List[Model],
31
+ weights: tp.Optional[tp.List[tp.List[float]]] = None,
32
+ segment: tp.Optional[float] = None):
33
+ """
34
+ Represents a bag of models with specific weights.
35
+ You should call `apply_model` rather than calling directly the forward here for
36
+ optimal performance.
37
+
38
+ Args:
39
+ models (list[nn.Module]): list of Demucs/HDemucs models.
40
+ weights (list[list[float]]): list of weights. If None, assumed to
41
+ be all ones, otherwise it should be a list of N list (N number of models),
42
+ each containing S floats (S number of sources).
43
+ segment (None or float): overrides the `segment` attribute of each model
44
+ (this is performed in place, be careful if you reuse the models passed).
45
+ """
46
+ super().__init__()
47
+ assert len(models) > 0
48
+ first = models[0]
49
+ for other in models:
50
+ assert other.sources == first.sources
51
+ assert other.samplerate == first.samplerate
52
+ assert other.audio_channels == first.audio_channels
53
+ if segment is not None:
54
+ if not isinstance(other, HTDemucs) and segment > other.segment:
55
+ other.segment = segment
56
+
57
+ self.audio_channels = first.audio_channels
58
+ self.samplerate = first.samplerate
59
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ @property
71
+ def max_allowed_segment(self) -> float:
72
+ max_allowed_segment = float('inf')
73
+ for model in self.models:
74
+ if isinstance(model, HTDemucs):
75
+ max_allowed_segment = min(max_allowed_segment, float(model.segment))
76
+ return max_allowed_segment
77
+
78
+ def forward(self, x):
79
+ raise NotImplementedError("Call `apply_model` on this.")
80
+
81
+
82
+ class TensorChunk:
83
+ def __init__(self, tensor, offset=0, length=None):
84
+ total_length = tensor.shape[-1]
85
+ assert offset >= 0
86
+ assert offset < total_length
87
+
88
+ if length is None:
89
+ length = total_length - offset
90
+ else:
91
+ length = min(total_length - offset, length)
92
+
93
+ if isinstance(tensor, TensorChunk):
94
+ self.tensor = tensor.tensor
95
+ self.offset = offset + tensor.offset
96
+ else:
97
+ self.tensor = tensor
98
+ self.offset = offset
99
+ self.length = length
100
+ self.device = tensor.device
101
+
102
+ @property
103
+ def shape(self):
104
+ shape = list(self.tensor.shape)
105
+ shape[-1] = self.length
106
+ return shape
107
+
108
+ def padded(self, target_length):
109
+ delta = target_length - self.length
110
+ total_length = self.tensor.shape[-1]
111
+ assert delta >= 0
112
+
113
+ start = self.offset - delta // 2
114
+ end = start + target_length
115
+
116
+ correct_start = max(0, start)
117
+ correct_end = min(total_length, end)
118
+
119
+ pad_left = correct_start - start
120
+ pad_right = end - correct_end
121
+
122
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
123
+ assert out.shape[-1] == target_length
124
+ return out
125
+
126
+
127
+ def tensor_chunk(tensor_or_chunk):
128
+ if isinstance(tensor_or_chunk, TensorChunk):
129
+ return tensor_or_chunk
130
+ else:
131
+ assert isinstance(tensor_or_chunk, th.Tensor)
132
+ return TensorChunk(tensor_or_chunk)
133
+
134
+
135
+ def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
136
+ if _dict is None:
137
+ _dict = {}
138
+ else:
139
+ _dict = copy.copy(_dict)
140
+ for key, value in subs:
141
+ _dict[key] = value
142
+ return _dict
143
+
144
+
145
+ def apply_model(model: tp.Union[BagOfModels, Model],
146
+ mix: tp.Union[th.Tensor, TensorChunk],
147
+ shifts: int = 1, split: bool = True,
148
+ overlap: float = 0.25, transition_power: float = 1.,
149
+ progress: bool = False, device=None,
150
+ num_workers: int = 0, segment: tp.Optional[float] = None,
151
+ pool=None, lock=None,
152
+ callback: tp.Optional[tp.Callable[[dict], None]] = None,
153
+ callback_arg: tp.Optional[dict] = None) -> th.Tensor:
154
+ """
155
+ Apply model to a given mixture.
156
+
157
+ Args:
158
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
159
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
160
+ all predictions are averaged. This effectively makes the model time equivariant
161
+ and improves SDR by up to 0.2 points.
162
+ split (bool): if True, the input will be broken down in 8 seconds extracts
163
+ and predictions will be performed individually on each and concatenated.
164
+ Useful for model with large memory footprint like Tasnet.
165
+ progress (bool): if True, show a progress bar (requires split=True)
166
+ device (torch.device, str, or None): if provided, device on which to
167
+ execute the computation, otherwise `mix.device` is assumed.
168
+ When `device` is different from `mix.device`, only local computations will
169
+ be on `device`, while the entire tracks will be stored on `mix.device`.
170
+ num_workers (int): if non zero, device is 'cpu', how many threads to
171
+ use in parallel.
172
+ segment (float or None): override the model segment parameter.
173
+ """
174
+ if device is None:
175
+ device = mix.device
176
+ else:
177
+ device = th.device(device)
178
+ if pool is None:
179
+ if num_workers > 0 and device.type == 'cpu':
180
+ pool = ThreadPoolExecutor(num_workers)
181
+ else:
182
+ pool = DummyPoolExecutor()
183
+ if lock is None:
184
+ lock = Lock()
185
+ callback_arg = _replace_dict(
186
+ callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
187
+ )
188
+ kwargs: tp.Dict[str, tp.Any] = {
189
+ 'shifts': shifts,
190
+ 'split': split,
191
+ 'overlap': overlap,
192
+ 'transition_power': transition_power,
193
+ 'progress': progress,
194
+ 'device': device,
195
+ 'pool': pool,
196
+ 'segment': segment,
197
+ 'lock': lock,
198
+ }
199
+ out: tp.Union[float, th.Tensor]
200
+ res: tp.Union[float, th.Tensor]
201
+ if isinstance(model, BagOfModels):
202
+ # Special treatment for bag of model.
203
+ # We explicitely apply multiple times `apply_model` so that the random shifts
204
+ # are different for each model.
205
+ estimates: tp.Union[float, th.Tensor] = 0.
206
+ totals = [0.] * len(model.sources)
207
+ callback_arg["models"] = len(model.models)
208
+ for sub_model, model_weights in zip(model.models, model.weights):
209
+ kwargs["callback"] = ((
210
+ lambda d, i=callback_arg["model_idx_in_bag"]: callback(
211
+ _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
212
+ )
213
+ original_model_device = next(iter(sub_model.parameters())).device
214
+ sub_model.to(device)
215
+
216
+ res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
217
+ out = res
218
+ sub_model.to(original_model_device)
219
+ for k, inst_weight in enumerate(model_weights):
220
+ out[:, k, :, :] *= inst_weight
221
+ totals[k] += inst_weight
222
+ estimates += out
223
+ del out
224
+ callback_arg["model_idx_in_bag"] += 1
225
+
226
+ assert isinstance(estimates, th.Tensor)
227
+ for k in range(estimates.shape[1]):
228
+ estimates[:, k, :, :] /= totals[k]
229
+ return estimates
230
+
231
+ if "models" not in callback_arg:
232
+ callback_arg["models"] = 1
233
+ model.to(device)
234
+ model.eval()
235
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
236
+ batch, channels, length = mix.shape
237
+ if shifts:
238
+ kwargs['shifts'] = 0
239
+ max_shift = int(0.5 * model.samplerate)
240
+ mix = tensor_chunk(mix)
241
+ assert isinstance(mix, TensorChunk)
242
+ padded_mix = mix.padded(length + 2 * max_shift)
243
+ out = 0.
244
+ for shift_idx in range(shifts):
245
+ offset = random.randint(0, max_shift)
246
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
247
+ kwargs["callback"] = (
248
+ (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
249
+ if callback else None)
250
+ )
251
+ res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
252
+ shifted_out = res
253
+ out += shifted_out[..., max_shift - offset:]
254
+ out /= shifts
255
+ assert isinstance(out, th.Tensor)
256
+ return out
257
+ elif split:
258
+ kwargs['split'] = False
259
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
260
+ sum_weight = th.zeros(length, device=mix.device)
261
+ if segment is None:
262
+ segment = model.segment
263
+ assert segment is not None and segment > 0.
264
+ segment_length: int = int(model.samplerate * segment)
265
+ stride = int((1 - overlap) * segment_length)
266
+ offsets = range(0, length, stride)
267
+ scale = float(format(stride / model.samplerate, ".2f"))
268
+ # We start from a triangle shaped weight, with maximal weight in the middle
269
+ # of the segment. Then we normalize and take to the power `transition_power`.
270
+ # Large values of transition power will lead to sharper transitions.
271
+ weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
272
+ th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
273
+ assert len(weight) == segment_length
274
+ # If the overlap < 50%, this will translate to linear transition when
275
+ # transition_power is 1.
276
+ weight = (weight / weight.max())**transition_power
277
+ futures = []
278
+ for offset in offsets:
279
+ chunk = TensorChunk(mix, offset, segment_length)
280
+ future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
281
+ callback=(lambda d, i=offset:
282
+ callback(_replace_dict(d, ("segment_offset", i)))
283
+ if callback else None))
284
+ futures.append((future, offset))
285
+ offset += segment_length
286
+ if progress:
287
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
288
+ for future, offset in futures:
289
+ try:
290
+ chunk_out = future.result() # type: th.Tensor
291
+ except Exception:
292
+ pool.shutdown(wait=True, cancel_futures=True)
293
+ raise
294
+ chunk_length = chunk_out.shape[-1]
295
+ out[..., offset:offset + segment_length] += (
296
+ weight[:chunk_length] * chunk_out).to(mix.device)
297
+ sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
298
+ assert sum_weight.min() > 0
299
+ out /= sum_weight
300
+ assert isinstance(out, th.Tensor)
301
+ return out
302
+ else:
303
+ valid_length: int
304
+ if isinstance(model, HTDemucs) and segment is not None:
305
+ valid_length = int(segment * model.samplerate)
306
+ elif hasattr(model, 'valid_length'):
307
+ valid_length = model.valid_length(length) # type: ignore
308
+ else:
309
+ valid_length = length
310
+ mix = tensor_chunk(mix)
311
+ assert isinstance(mix, TensorChunk)
312
+ padded_mix = mix.padded(valid_length).to(device)
313
+ with lock:
314
+ if callback is not None:
315
+ callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore
316
+ with th.no_grad():
317
+ out = model(padded_mix)
318
+ with lock:
319
+ if callback is not None:
320
+ callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore
321
+ assert isinstance(out, th.Tensor)
322
+ return center_trim(out, length)
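
`apply_model` can also be used directly, without the `Separator` wrapper from the API module. A sketch assuming a pretrained model fetched through `demucs.pretrained.get_model` (part of this upload, not shown in this view) and a random stereo mix; the callback keys are the ones documented in the docstring above:

import torch as th
from demucs.apply import apply_model
from demucs.pretrained import get_model

model = get_model("htdemucs")
mix = th.randn(1, model.audio_channels, 10 * model.samplerate)  # (batch, channels, time)

def on_progress(info):
    # `state` is "start" or "end"; offsets and lengths are in frames of the tensor.
    print(info["state"], info["segment_offset"], "/", info["audio_length"])

out = apply_model(model, mix, split=True, overlap=0.25, progress=False,
                  callback=on_progress,
                  callback_arg={"audio_length": mix.shape[-1]})
print(out.shape)  # (batch, len(model.sources), channels, time)
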
demucs/audio.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import json
7
+ import subprocess as sp
8
+ from pathlib import Path
9
+
10
+ import lameenc
11
+ import julius
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio as ta
15
+ import typing as tp
16
+
17
+ from .utils import temp_filenames
18
+
19
+
20
+ def _read_info(path):
21
+ stdout_data = sp.check_output([
22
+ 'ffprobe', "-loglevel", "panic",
23
+ str(path), '-print_format', 'json', '-show_format', '-show_streams'
24
+ ])
25
+ return json.loads(stdout_data.decode('utf-8'))
26
+
27
+
28
+ class AudioFile:
29
+ """
30
+ Allows reading audio from any format supported by ffmpeg, as well as resampling or
31
+ converting to mono on the fly. See :method:`read` for more details.
32
+ """
33
+ def __init__(self, path: Path):
34
+ self.path = Path(path)
35
+ self._info = None
36
+
37
+ def __repr__(self):
38
+ features = [("path", self.path)]
39
+ features.append(("samplerate", self.samplerate()))
40
+ features.append(("channels", self.channels()))
41
+ features.append(("streams", len(self)))
42
+ features_str = ", ".join(f"{name}={value}" for name, value in features)
43
+ return f"AudioFile({features_str})"
44
+
45
+ @property
46
+ def info(self):
47
+ if self._info is None:
48
+ self._info = _read_info(self.path)
49
+ return self._info
50
+
51
+ @property
52
+ def duration(self):
53
+ return float(self.info['format']['duration'])
54
+
55
+ @property
56
+ def _audio_streams(self):
57
+ return [
58
+ index for index, stream in enumerate(self.info["streams"])
59
+ if stream["codec_type"] == "audio"
60
+ ]
61
+
62
+ def __len__(self):
63
+ return len(self._audio_streams)
64
+
65
+ def channels(self, stream=0):
66
+ return int(self.info['streams'][self._audio_streams[stream]]['channels'])
67
+
68
+ def samplerate(self, stream=0):
69
+ return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
70
+
71
+ def read(self,
72
+ seek_time=None,
73
+ duration=None,
74
+ streams=slice(None),
75
+ samplerate=None,
76
+ channels=None):
77
+ """
78
+ Slightly more efficient implementation than stempeg,
79
+ in particular, this will extract all stems at once
80
+ rather than having to loop over one file multiple times
81
+ for each stream.
82
+
83
+ Args:
84
+ seek_time (float): seek time in seconds or None if no seeking is needed.
85
+ duration (float): duration in seconds to extract or None to extract until the end.
86
+ streams (slice, int or list): streams to extract, can be a single int, a list or
87
+ a slice. If it is a slice or list, the output will be of size [S, C, T]
88
+ with S the number of streams, C the number of channels and T the number of samples.
89
+ If it is an int, the output will be [C, T].
90
+ samplerate (int): if provided, will resample on the fly. If None, no resampling will
91
+ be done. Original sampling rate can be obtained with :method:`samplerate`.
92
+ channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
93
+ as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
94
+ See https://sound.stackexchange.com/a/42710.
95
+ Our definition of mono is simply the average of the two channels. Any other
96
+ value will be ignored.
97
+ """
98
+ streams = np.array(range(len(self)))[streams]
99
+ single = not isinstance(streams, np.ndarray)
100
+ if single:
101
+ streams = [streams]
102
+
103
+ if duration is None:
104
+ target_size = None
105
+ query_duration = None
106
+ else:
107
+ target_size = int((samplerate or self.samplerate()) * duration)
108
+ query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
109
+
110
+ with temp_filenames(len(streams)) as filenames:
111
+ command = ['ffmpeg', '-y']
112
+ command += ['-loglevel', 'panic']
113
+ if seek_time:
114
+ command += ['-ss', str(seek_time)]
115
+ command += ['-i', str(self.path)]
116
+ for stream, filename in zip(streams, filenames):
117
+ command += ['-map', f'0:{self._audio_streams[stream]}']
118
+ if query_duration is not None:
119
+ command += ['-t', str(query_duration)]
120
+ command += ['-threads', '1']
121
+ command += ['-f', 'f32le']
122
+ if samplerate is not None:
123
+ command += ['-ar', str(samplerate)]
124
+ command += [filename]
125
+
126
+ sp.run(command, check=True)
127
+ wavs = []
128
+ for filename in filenames:
129
+ wav = np.fromfile(filename, dtype=np.float32)
130
+ wav = torch.from_numpy(wav)
131
+ wav = wav.view(-1, self.channels()).t()
132
+ if channels is not None:
133
+ wav = convert_audio_channels(wav, channels)
134
+ if target_size is not None:
135
+ wav = wav[..., :target_size]
136
+ wavs.append(wav)
137
+ wav = torch.stack(wavs, dim=0)
138
+ if single:
139
+ wav = wav[0]
140
+ return wav
141
+
142
+
143
+ def convert_audio_channels(wav, channels=2):
144
+ """Convert audio to the given number of channels."""
145
+ *shape, src_channels, length = wav.shape
146
+ if src_channels == channels:
147
+ pass
148
+ elif channels == 1:
149
+ # Case 1:
150
+ # The caller asked 1-channel audio, but the stream have multiple
151
+ # channels, downmix all channels.
152
+ wav = wav.mean(dim=-2, keepdim=True)
153
+ elif src_channels == 1:
154
+ # Case 2:
155
+ # The caller asked for multiple channels, but the input file have
156
+ # one single channel, replicate the audio over all channels.
157
+ wav = wav.expand(*shape, channels, length)
158
+ elif src_channels >= channels:
159
+ # Case 3:
160
+ # The caller asked for multiple channels, and the input file have
161
+ # more channels than requested. In that case return the first channels.
162
+ wav = wav[..., :channels, :]
163
+ else:
164
+ # Case 4: What is a reasonable choice here?
165
+ raise ValueError('The audio file has less channels than requested but is not mono.')
166
+ return wav
167
+
168
+
169
+ def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
170
+ """Convert audio from a given samplerate to a target one and target number of channels."""
171
+ wav = convert_audio_channels(wav, channels)
172
+ return julius.resample_frac(wav, from_samplerate, to_samplerate)
173
+
174
+
175
+ def i16_pcm(wav):
176
+ """Convert audio to 16 bits integer PCM format."""
177
+ if wav.dtype.is_floating_point:
178
+ return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
179
+ else:
180
+ return wav
181
+
182
+
183
+ def f32_pcm(wav):
184
+ """Convert audio to float 32 bits PCM format."""
185
+ if wav.dtype.is_floating_point:
186
+ return wav
187
+ else:
188
+ return wav.float() / (2**15 - 1)
189
+
190
+
191
+ def as_dtype_pcm(wav, dtype):
192
+ """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
193
+ if wav.dtype.is_floating_point:
194
+ return f32_pcm(wav)
195
+ else:
196
+ return i16_pcm(wav)
197
+
198
+
199
+ def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
200
+ """Save given audio as mp3. This should work on all OSes."""
201
+ C, T = wav.shape
202
+ wav = i16_pcm(wav)
203
+ encoder = lameenc.Encoder()
204
+ encoder.set_bit_rate(bitrate)
205
+ encoder.set_in_sample_rate(samplerate)
206
+ encoder.set_channels(C)
207
+ encoder.set_quality(quality) # 2-highest, 7-fastest
208
+ if not verbose:
209
+ encoder.silence()
210
+ wav = wav.data.cpu()
211
+ wav = wav.transpose(0, 1).numpy()
212
+ mp3_data = encoder.encode(wav.tobytes())
213
+ mp3_data += encoder.flush()
214
+ with open(path, "wb") as f:
215
+ f.write(mp3_data)
216
+
217
+
218
+ def prevent_clip(wav, mode='rescale'):
219
+ """
220
+ different strategies for avoiding raw clipping.
221
+ """
222
+ if mode is None or mode == 'none':
223
+ return wav
224
+ assert wav.dtype.is_floating_point, "too late for clipping"
225
+ if mode == 'rescale':
226
+ wav = wav / max(1.01 * wav.abs().max(), 1)
227
+ elif mode == 'clamp':
228
+ wav = wav.clamp(-0.99, 0.99)
229
+ elif mode == 'tanh':
230
+ wav = torch.tanh(wav)
231
+ else:
232
+ raise ValueError(f"Invalid mode {mode}")
233
+ return wav
234
+
235
+
236
+ def save_audio(wav: torch.Tensor,
237
+ path: tp.Union[str, Path],
238
+ samplerate: int,
239
+ bitrate: int = 320,
240
+ clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
241
+ bits_per_sample: tp.Literal[16, 24, 32] = 16,
242
+ as_float: bool = False,
243
+ preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
244
+ """Save audio file, automatically preventing clipping if necessary
245
+ based on the given `clip` strategy. If the path ends in `.mp3`, this
246
+ will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality:
247
+ 2 for highest quality, 7 for fastest speed
248
+ """
249
+ wav = prevent_clip(wav, mode=clip)
250
+ path = Path(path)
251
+ suffix = path.suffix.lower()
252
+ if suffix == ".mp3":
253
+ encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
254
+ elif suffix == ".wav":
255
+ if as_float:
256
+ bits_per_sample = 32
257
+ encoding = 'PCM_F'
258
+ else:
259
+ encoding = 'PCM_S'
260
+ ta.save(str(path), wav, sample_rate=samplerate,
261
+ encoding=encoding, bits_per_sample=bits_per_sample)
262
+ elif suffix == ".flac":
263
+ ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
264
+ else:
265
+ raise ValueError(f"Invalid suffix for path: {suffix}")
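
A short sketch of the I/O helpers above. `AudioFile` shells out to ffmpeg/ffprobe, so both must be on the PATH; the file below is the test.mp3 from this upload:

from demucs.audio import AudioFile, convert_audio, save_audio

f = AudioFile("test.mp3")
print(f)  # AudioFile(path=test.mp3, samplerate=..., channels=..., streams=...)
wav = f.read(streams=0, samplerate=44100, channels=2)   # (2, T) float32 tensor
mono = convert_audio(wav, 44100, 16000, channels=1)     # downmix + resample on the fly
save_audio(mono, "test_mono.wav", samplerate=16000)
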
demucs/augment.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Data augmentations.
7
+ """
8
+
9
+ import random
10
+ import torch as th
11
+ from torch import nn
12
+
13
+
14
+ class Shift(nn.Module):
15
+ """
16
+ Randomly shift audio in time by up to `shift` samples.
17
+ """
18
+ def __init__(self, shift=8192, same=False):
19
+ super().__init__()
20
+ self.shift = shift
21
+ self.same = same
22
+
23
+ def forward(self, wav):
24
+ batch, sources, channels, time = wav.size()
25
+ length = time - self.shift
26
+ if self.shift > 0:
27
+ if not self.training:
28
+ wav = wav[..., :length]
29
+ else:
30
+ srcs = 1 if self.same else sources
31
+ offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
32
+ offsets = offsets.expand(-1, sources, channels, -1)
33
+ indexes = th.arange(length, device=wav.device)
34
+ wav = wav.gather(3, indexes + offsets)
35
+ return wav
36
+
37
+
38
+ class FlipChannels(nn.Module):
39
+ """
40
+ Flip left-right channels.
41
+ """
42
+ def forward(self, wav):
43
+ batch, sources, channels, time = wav.size()
44
+ if self.training and wav.size(2) == 2:
45
+ left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
46
+ left = left.expand(-1, -1, -1, time)
47
+ right = 1 - left
48
+ wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
49
+ return wav
50
+
51
+
52
+ class FlipSign(nn.Module):
53
+ """
54
+ Random sign flip.
55
+ """
56
+ def forward(self, wav):
57
+ batch, sources, channels, time = wav.size()
58
+ if self.training:
59
+ signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
60
+ wav = wav * (2 * signs - 1)
61
+ return wav
62
+
63
+
64
+ class Remix(nn.Module):
65
+ """
66
+ Shuffle sources to make new mixes.
67
+ """
68
+ def __init__(self, proba=1, group_size=4):
69
+ """
70
+ Shuffle sources within one batch.
71
+ Each batch is divided into groups of size `group_size` and shuffling is done within
72
+ each group separatly. This allow to keep the same probability distribution no matter
73
+ the number of GPUs. Without this grouping, using more GPUs would lead to a higher
74
+ probability of keeping two sources from the same track together which can impact
75
+ performance.
76
+ """
77
+ super().__init__()
78
+ self.proba = proba
79
+ self.group_size = group_size
80
+
81
+ def forward(self, wav):
82
+ batch, streams, channels, time = wav.size()
83
+ device = wav.device
84
+
85
+ if self.training and random.random() < self.proba:
86
+ group_size = self.group_size or batch
87
+ if batch % group_size != 0:
88
+ raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
89
+ groups = batch // group_size
90
+ wav = wav.view(groups, group_size, streams, channels, time)
91
+ permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
92
+ dim=1)
93
+ wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
94
+ wav = wav.view(batch, streams, channels, time)
95
+ return wav
96
+
97
+
98
+ class Scale(nn.Module):
99
+ def __init__(self, proba=1., min=0.25, max=1.25):
100
+ super().__init__()
101
+ self.proba = proba
102
+ self.min = min
103
+ self.max = max
104
+
105
+ def forward(self, wav):
106
+ batch, streams, channels, time = wav.size()
107
+ device = wav.device
108
+ if self.training and random.random() < self.proba:
109
+ scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
110
+ wav *= scales
111
+ return wav
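
These modules all expect waveforms of shape (batch, sources, channels, time) and are no-ops outside training mode. A minimal sketch of chaining them the way a training loop could (random data, group size matching the batch size):

import torch as th
from torch import nn
from demucs.augment import FlipChannels, FlipSign, Remix, Scale, Shift

augment = nn.Sequential(Shift(shift=8192), FlipSign(), FlipChannels(),
                        Scale(proba=1.0), Remix(proba=1.0, group_size=4))
augment.train()  # the augmentations only trigger in training mode

sources = th.randn(4, 4, 2, 6 * 44100)  # (batch, sources, channels, time)
sources = augment(sources)              # Shift drops `shift` samples from the length
mix = sources.sum(dim=1)                # the mixture actually fed to the model
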
demucs/demucs.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+ from .transformer import LayerScale
18
+
19
+
20
+ class BLSTM(nn.Module):
21
+ """
22
+ BiLSTM with same hidden units as input dim.
23
+ If `max_steps` is not None, input will be splitting in overlapping
24
+ chunks and the LSTM applied separately on each chunk.
25
+ """
26
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
27
+ super().__init__()
28
+ assert max_steps is None or max_steps % 4 == 0
29
+ self.max_steps = max_steps
30
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
31
+ self.linear = nn.Linear(2 * dim, dim)
32
+ self.skip = skip
33
+
34
+ def forward(self, x):
35
+ B, C, T = x.shape
36
+ y = x
37
+ framed = False
38
+ if self.max_steps is not None and T > self.max_steps:
39
+ width = self.max_steps
40
+ stride = width // 2
41
+ frames = unfold(x, width, stride)
42
+ nframes = frames.shape[2]
43
+ framed = True
44
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
45
+
46
+ x = x.permute(2, 0, 1)
47
+
48
+ x = self.lstm(x)[0]
49
+ x = self.linear(x)
50
+ x = x.permute(1, 2, 0)
51
+ if framed:
52
+ out = []
53
+ frames = x.reshape(B, -1, C, width)
54
+ limit = stride // 2
55
+ for k in range(nframes):
56
+ if k == 0:
57
+ out.append(frames[:, k, :, :-limit])
58
+ elif k == nframes - 1:
59
+ out.append(frames[:, k, :, limit:])
60
+ else:
61
+ out.append(frames[:, k, :, limit:-limit])
62
+ out = torch.cat(out, -1)
63
+ out = out[..., :T]
64
+ x = out
65
+ if self.skip:
66
+ x = x + y
67
+ return x
68
+
69
+
70
+ def rescale_conv(conv, reference):
71
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
72
+ """
73
+ std = conv.weight.std().detach()
74
+ scale = (std / reference)**0.5
75
+ conv.weight.data /= scale
76
+ if conv.bias is not None:
77
+ conv.bias.data /= scale
78
+
79
+
80
+ def rescale_module(module, reference):
81
+ for sub in module.modules():
82
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
83
+ rescale_conv(sub, reference)
84
+
85
+
86
+ class DConv(nn.Module):
87
+ """
88
+ New residual branches in each encoder layer.
89
+ This alternates dilated convolutions, potentially with LSTMs and attention.
90
+ Also before entering each residual branch, dimension is projected on a smaller subspace,
91
+ e.g. of dim `channels // compress`.
92
+ """
93
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
94
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
95
+ kernel=3, dilate=True):
96
+ """
97
+ Args:
98
+ channels: input/output channels for residual branch.
99
+ compress: amount of channel compression inside the branch.
100
+ depth: number of layers in the residual branch. Each layer has its own
101
+ projection, and potentially LSTM and attention.
102
+ init: initial scale for LayerNorm.
103
+ norm: use GroupNorm.
104
+ attn: use LocalAttention.
105
+ heads: number of heads for the LocalAttention.
106
+ ndecay: number of decay controls in the LocalAttention.
107
+ lstm: use LSTM.
108
+ gelu: Use GELU activation.
109
+ kernel: kernel size for the (dilated) convolutions.
110
+ dilate: if true, use dilation, increasing with the depth.
111
+ """
112
+
113
+ super().__init__()
114
+ assert kernel % 2 == 1
115
+ self.channels = channels
116
+ self.compress = compress
117
+ self.depth = abs(depth)
118
+ dilate = depth > 0
119
+
120
+ norm_fn: tp.Callable[[int], nn.Module]
121
+ norm_fn = lambda d: nn.Identity() # noqa
122
+ if norm:
123
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
124
+
125
+ hidden = int(channels / compress)
126
+
127
+ act: tp.Type[nn.Module]
128
+ if gelu:
129
+ act = nn.GELU
130
+ else:
131
+ act = nn.ReLU
132
+
133
+ self.layers = nn.ModuleList([])
134
+ for d in range(self.depth):
135
+ dilation = 2 ** d if dilate else 1
136
+ padding = dilation * (kernel // 2)
137
+ mods = [
138
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
139
+ norm_fn(hidden), act(),
140
+ nn.Conv1d(hidden, 2 * channels, 1),
141
+ norm_fn(2 * channels), nn.GLU(1),
142
+ LayerScale(channels, init),
143
+ ]
144
+ if attn:
145
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
146
+ if lstm:
147
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
148
+ layer = nn.Sequential(*mods)
149
+ self.layers.append(layer)
150
+
151
+ def forward(self, x):
152
+ for layer in self.layers:
153
+ x = x + layer(x)
154
+ return x
155
+
156
+
157
+ class LocalState(nn.Module):
158
+ """Local state allows to have attention based only on data (no positional embedding),
159
+ but while setting a constraint on the time window (e.g. decaying penalty term).
160
+
161
+ Also a failed experiments with trying to provide some frequency based attention.
162
+ """
163
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
164
+ super().__init__()
165
+ assert channels % heads == 0, (channels, heads)
166
+ self.heads = heads
167
+ self.nfreqs = nfreqs
168
+ self.ndecay = ndecay
169
+ self.content = nn.Conv1d(channels, channels, 1)
170
+ self.query = nn.Conv1d(channels, channels, 1)
171
+ self.key = nn.Conv1d(channels, channels, 1)
172
+ if nfreqs:
173
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
174
+ if ndecay:
175
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
176
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
177
+ self.query_decay.weight.data *= 0.01
178
+ assert self.query_decay.bias is not None # stupid type checker
179
+ self.query_decay.bias.data[:] = -2
180
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
181
+
182
+ def forward(self, x):
183
+ B, C, T = x.shape
184
+ heads = self.heads
185
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
186
+ # left index are keys, right index are queries
187
+ delta = indexes[:, None] - indexes[None, :]
188
+
189
+ queries = self.query(x).view(B, heads, -1, T)
190
+ keys = self.key(x).view(B, heads, -1, T)
191
+ # t are keys, s are queries
192
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
193
+ dots /= keys.shape[2]**0.5
194
+ if self.nfreqs:
195
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
196
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
197
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
198
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
199
+ if self.ndecay:
200
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
201
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
202
+ decay_q = torch.sigmoid(decay_q) / 2
203
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
204
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
205
+
206
+ # Kill self reference.
207
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
208
+ weights = torch.softmax(dots, dim=2)
209
+
210
+ content = self.content(x).view(B, heads, -1, T)
211
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
212
+ if self.nfreqs:
213
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
214
+ result = torch.cat([result, time_sig], 2)
215
+ result = result.reshape(B, -1, T)
216
+ return x + self.proj(result)
217
+
218
+
219
+ class Demucs(nn.Module):
220
+ @capture_init
221
+ def __init__(self,
222
+ sources,
223
+ # Channels
224
+ audio_channels=2,
225
+ channels=64,
226
+ growth=2.,
227
+ # Main structure
228
+ depth=6,
229
+ rewrite=True,
230
+ lstm_layers=0,
231
+ # Convolutions
232
+ kernel_size=8,
233
+ stride=4,
234
+ context=1,
235
+ # Activations
236
+ gelu=True,
237
+ glu=True,
238
+ # Normalization
239
+ norm_starts=4,
240
+ norm_groups=4,
241
+ # DConv residual branch
242
+ dconv_mode=1,
243
+ dconv_depth=2,
244
+ dconv_comp=4,
245
+ dconv_attn=4,
246
+ dconv_lstm=4,
247
+ dconv_init=1e-4,
248
+ # Pre/post processing
249
+ normalize=True,
250
+ resample=True,
251
+ # Weight init
252
+ rescale=0.1,
253
+ # Metadata
254
+ samplerate=44100,
255
+ segment=4 * 10):
256
+ """
257
+ Args:
258
+ sources (list[str]): list of source names
259
+ audio_channels (int): stereo or mono
260
+ channels (int): first convolution channels
261
+ depth (int): number of encoder/decoder layers
262
+ growth (float): multiply (resp divide) number of channels by that
263
+ for each layer of the encoder (resp decoder)
264
+ depth (int): number of layers in the encoder and in the decoder.
265
+ rewrite (bool): add 1x1 convolution to each layer.
266
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
267
+ by default, as this is now replaced by the smaller and faster small LSTMs
268
+ in the DConv branches.
269
+ kernel_size (int): kernel size for convolutions
270
+ stride (int): stride for convolutions
271
+ context (int): kernel size of the convolution in the
272
+ decoder before the transposed convolution. If > 1,
273
+ will provide some context from neighboring time steps.
274
+ gelu: use GELU activation function.
275
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
276
+ norm_starts: layer at which group norm starts being used.
277
+ decoder layers are numbered in reverse order.
278
+ norm_groups: number of groups for group norm.
279
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
280
+ dconv_depth: depth of residual DConv branch.
281
+ dconv_comp: compression of DConv branch.
282
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
283
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
284
+ dconv_init: initial scale for the DConv branch LayerScale.
285
+ normalize (bool): normalizes the input audio on the fly, and scales back
286
+ the output by the same amount.
287
+ resample (bool): upsample x2 the input and downsample /2 the output.
288
+ rescale (float): rescale initial weights of convolutions
289
+ to get their standard deviation closer to `rescale`.
290
+ samplerate (int): stored as meta information for easing
291
+ future evaluations of the model.
292
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
293
+ This is used by `demucs.apply.apply_model`.
294
+ """
295
+
296
+ super().__init__()
297
+ self.audio_channels = audio_channels
298
+ self.sources = sources
299
+ self.kernel_size = kernel_size
300
+ self.context = context
301
+ self.stride = stride
302
+ self.depth = depth
303
+ self.resample = resample
304
+ self.channels = channels
305
+ self.normalize = normalize
306
+ self.samplerate = samplerate
307
+ self.segment = segment
308
+ self.encoder = nn.ModuleList()
309
+ self.decoder = nn.ModuleList()
310
+ self.skip_scales = nn.ModuleList()
311
+
312
+ if glu:
313
+ activation = nn.GLU(dim=1)
314
+ ch_scale = 2
315
+ else:
316
+ activation = nn.ReLU()
317
+ ch_scale = 1
318
+ if gelu:
319
+ act2 = nn.GELU
320
+ else:
321
+ act2 = nn.ReLU
322
+
323
+ in_channels = audio_channels
324
+ padding = 0
325
+ for index in range(depth):
326
+ norm_fn = lambda d: nn.Identity() # noqa
327
+ if index >= norm_starts:
328
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
329
+
330
+ encode = []
331
+ encode += [
332
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
333
+ norm_fn(channels),
334
+ act2(),
335
+ ]
336
+ attn = index >= dconv_attn
337
+ lstm = index >= dconv_lstm
338
+ if dconv_mode & 1:
339
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
340
+ compress=dconv_comp, attn=attn, lstm=lstm)]
341
+ if rewrite:
342
+ encode += [
343
+ nn.Conv1d(channels, ch_scale * channels, 1),
344
+ norm_fn(ch_scale * channels), activation]
345
+ self.encoder.append(nn.Sequential(*encode))
346
+
347
+ decode = []
348
+ if index > 0:
349
+ out_channels = in_channels
350
+ else:
351
+ out_channels = len(self.sources) * audio_channels
352
+ if rewrite:
353
+ decode += [
354
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
355
+ norm_fn(ch_scale * channels), activation]
356
+ if dconv_mode & 2:
357
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
358
+ compress=dconv_comp, attn=attn, lstm=lstm)]
359
+ decode += [nn.ConvTranspose1d(channels, out_channels,
360
+ kernel_size, stride, padding=padding)]
361
+ if index > 0:
362
+ decode += [norm_fn(out_channels), act2()]
363
+ self.decoder.insert(0, nn.Sequential(*decode))
364
+ in_channels = channels
365
+ channels = int(growth * channels)
366
+
367
+ channels = in_channels
368
+ if lstm_layers:
369
+ self.lstm = BLSTM(channels, lstm_layers)
370
+ else:
371
+ self.lstm = None
372
+
373
+ if rescale:
374
+ rescale_module(self, reference=rescale)
375
+
376
+ def valid_length(self, length):
377
+ """
378
+ Return the nearest valid length to use with the model so that
379
+ there is no time steps left over in a convolution, e.g. for all
380
+ layers, size of the input - kernel_size % stride = 0.
381
+
382
+ Note that input are automatically padded if necessary to ensure that the output
383
+ has the same length as the input.
384
+ """
385
+ if self.resample:
386
+ length *= 2
387
+
388
+ for _ in range(self.depth):
389
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
390
+ length = max(1, length)
391
+
392
+ for idx in range(self.depth):
393
+ length = (length - 1) * self.stride + self.kernel_size
394
+
395
+ if self.resample:
396
+ length = math.ceil(length / 2)
397
+ return int(length)
398
+
399
+ def forward(self, mix):
400
+ x = mix
401
+ length = x.shape[-1]
402
+
403
+ if self.normalize:
404
+ mono = mix.mean(dim=1, keepdim=True)
405
+ mean = mono.mean(dim=-1, keepdim=True)
406
+ std = mono.std(dim=-1, keepdim=True)
407
+ x = (x - mean) / (1e-5 + std)
408
+ else:
409
+ mean = 0
410
+ std = 1
411
+
412
+ delta = self.valid_length(length) - length
413
+ x = F.pad(x, (delta // 2, delta - delta // 2))
414
+
415
+ if self.resample:
416
+ x = julius.resample_frac(x, 1, 2)
417
+
418
+ saved = []
419
+ for encode in self.encoder:
420
+ x = encode(x)
421
+ saved.append(x)
422
+
423
+ if self.lstm:
424
+ x = self.lstm(x)
425
+
426
+ for decode in self.decoder:
427
+ skip = saved.pop(-1)
428
+ skip = center_trim(skip, x)
429
+ x = decode(x + skip)
430
+
431
+ if self.resample:
432
+ x = julius.resample_frac(x, 2, 1)
433
+ x = x * std + mean
434
+ x = center_trim(x, length)
435
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
436
+ return x
437
+
438
+ def load_state_dict(self, state, strict=True):
439
+ # fix a mismatch with previous generation Demucs models.
440
+ for idx in range(self.depth):
441
+ for a in ['encoder', 'decoder']:
442
+ for b in ['bias', 'weight']:
443
+ new = f'{a}.{idx}.3.{b}'
444
+ old = f'{a}.{idx}.2.{b}'
445
+ if old in state and new not in state:
446
+ state[new] = state.pop(old)
447
+ super().load_state_dict(state, strict=strict)
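For context on how the padding in `valid_length` and the reshaping at the end of `forward` fit together, here is a minimal usage sketch. It assumes the class above is importable as `demucs.demucs.Demucs` and relies only on the constructor defaults; the tensor sizes are illustrative.

    import torch
    from demucs.demucs import Demucs

    model = Demucs(sources=['drums', 'bass', 'other', 'vocals'])
    length = 6 * model.samplerate                       # six seconds of audio
    print(model.valid_length(length))                   # length the input will be padded to
    mix = torch.randn(1, model.audio_channels, length)  # (batch, channels, time)
    with torch.no_grad():
        stems = model(mix)                              # forward() pads, then trims back
    print(stems.shape)                                  # (1, 4, audio_channels, length)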
demucs/distrib.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Distributed training utilities.
7
+ """
8
+ import logging
9
+ import pickle
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from torch.utils.data import DataLoader, Subset
15
+ from torch.nn.parallel.distributed import DistributedDataParallel
16
+
17
+ from dora import distrib as dora_distrib
18
+
19
+ logger = logging.getLogger(__name__)
20
+ rank = 0
21
+ world_size = 1
22
+
23
+
24
+ def init():
25
+ global rank, world_size
26
+ if not torch.distributed.is_initialized():
27
+ dora_distrib.init()
28
+ rank = dora_distrib.rank()
29
+ world_size = dora_distrib.world_size()
30
+
31
+
32
+ def average(metrics, count=1.):
33
+ if isinstance(metrics, dict):
34
+ keys, values = zip(*sorted(metrics.items()))
35
+ values = average(values, count)
36
+ return dict(zip(keys, values))
37
+ if world_size == 1:
38
+ return metrics
39
+ tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
40
+ tensor *= count
41
+ torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
42
+ return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
43
+
44
+
45
+ def wrap(model):
46
+ if world_size == 1:
47
+ return model
48
+ else:
49
+ return DistributedDataParallel(
50
+ model,
51
+ # find_unused_parameters=True,
52
+ device_ids=[torch.cuda.current_device()],
53
+ output_device=torch.cuda.current_device())
54
+
55
+
56
+ def barrier():
57
+ if world_size > 1:
58
+ torch.distributed.barrier()
59
+
60
+
61
+ def share(obj=None, src=0):
62
+ if world_size == 1:
63
+ return obj
64
+ size = torch.empty(1, device='cuda', dtype=torch.long)
65
+ if rank == src:
66
+ dump = pickle.dumps(obj)
67
+ size[0] = len(dump)
68
+ torch.distributed.broadcast(size, src=src)
69
+ # size variable is now set to the length of pickled obj in all processes
70
+
71
+ if rank == src:
72
+ buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
73
+ else:
74
+ buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
75
+ torch.distributed.broadcast(buffer, src=src)
76
+ # buffer variable is now set to pickled obj in all processes
77
+
78
+ if rank != src:
79
+ obj = pickle.loads(buffer.cpu().numpy().tobytes())
80
+ logger.debug(f"Shared object of size {len(buffer)}")
81
+ return obj
82
+
83
+
84
+ def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
85
+ """
86
+ Create a dataloader that behaves correctly under distributed training.
87
+ If a gradient is going to be computed, you must set `shuffle=True`.
88
+ """
89
+ if world_size == 1:
90
+ return klass(dataset, *args, shuffle=shuffle, **kwargs)
91
+
92
+ if shuffle:
93
+ # train means we will compute a backward pass, so we use DistributedSampler
94
+ sampler = DistributedSampler(dataset)
95
+ # We ignore shuffle, DistributedSampler already shuffles
96
+ return klass(dataset, *args, **kwargs, sampler=sampler)
97
+ else:
98
+ # We shard manually, as DistributedSampler would otherwise replicate some examples
99
+ dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
100
+ return klass(dataset, *args, shuffle=shuffle, **kwargs)
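As a rough sketch of how these helpers are meant to be combined in a training script (the dataset and model below are placeholders; `init()` is normally triggered by the Dora launcher, so with a single process the defaults `rank=0`, `world_size=1` make every helper a no-op):

    import torch
    from torch.utils.data import TensorDataset
    from demucs import distrib

    dataset = TensorDataset(torch.randn(64, 2, 44100))
    train_loader = distrib.loader(dataset, batch_size=4, shuffle=True)   # DistributedSampler when world_size > 1
    valid_loader = distrib.loader(dataset, batch_size=4, shuffle=False)  # manual shard, one slice per rank

    model = distrib.wrap(torch.nn.Conv1d(2, 2, 3))                       # DDP wrapper only when world_size > 1
    metrics = distrib.average({'loss': 0.5}, count=len(train_loader))    # weighted all-reduce across workers
    distrib.barrier()                                                     # no-op on a single process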
demucs/ema.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Inspired from https://github.com/rwightman/pytorch-image-models
8
+ from contextlib import contextmanager
9
+
10
+ import torch
11
+
12
+ from .states import swap_state
13
+
14
+
15
+ class ModelEMA:
16
+ """
17
+ Perform EMA on a model. You can switch to the EMA weights temporarily
18
+ with the `swap` method.
19
+
20
+ ema = ModelEMA(model)
21
+ with ema.swap():
22
+ # compute valid metrics with averaged model.
23
+ """
24
+ def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
25
+ self.decay = decay
26
+ self.model = model
27
+ self.state = {}
28
+ self.count = 0
29
+ self.device = device
30
+ self.unbias = unbias
31
+
32
+ self._init()
33
+
34
+ def _init(self):
35
+ for key, val in self.model.state_dict().items():
36
+ if val.dtype != torch.float32:
37
+ continue
38
+ device = self.device or val.device
39
+ if key not in self.state:
40
+ self.state[key] = val.detach().to(device, copy=True)
41
+
42
+ def update(self):
43
+ if self.unbias:
44
+ self.count = self.count * self.decay + 1
45
+ w = 1 / self.count
46
+ else:
47
+ w = 1 - self.decay
48
+ for key, val in self.model.state_dict().items():
49
+ if val.dtype != torch.float32:
50
+ continue
51
+ device = self.device or val.device
52
+ self.state[key].mul_(1 - w)
53
+ self.state[key].add_(val.detach().to(device), alpha=w)
54
+
55
+ @contextmanager
56
+ def swap(self):
57
+ with swap_state(self.model, self.state):
58
+ yield
59
+
60
+ def state_dict(self):
61
+ return {'state': self.state, 'count': self.count}
62
+
63
+ def load_state_dict(self, state):
64
+ self.count = state['count']
65
+ for k, v in state['state'].items():
66
+ self.state[k].copy_(v)
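To make the intended flow concrete, here is a small illustrative training loop (the model, optimizer and data are placeholders) showing where `update()` and `swap()` are typically called:

    import torch
    from demucs.ema import ModelEMA

    model = torch.nn.Linear(8, 8)
    ema = ModelEMA(model, decay=0.999)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(100):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update()                     # fold the new weights into the running average

    with ema.swap():                     # temporarily load the averaged weights
        val_loss = model(torch.randn(4, 8)).pow(2).mean()
    # outside the context manager the raw training weights are restored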
demucs/evaluate.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Test time evaluation, either using the original SDR from [Vincent et al. 2006]
8
+ or the newest SDR definition from the MDX 2021 competition (this one will
9
+ be reported as `nsdr` for `new sdr`).
10
+ """
11
+
12
+ from concurrent import futures
13
+ import logging
14
+
15
+ from dora.log import LogProgress
16
+ import numpy as np
17
+ import musdb
18
+ import museval
19
+ import torch as th
20
+
21
+ from .apply import apply_model
22
+ from .audio import convert_audio, save_audio
23
+ from . import distrib
24
+ from .utils import DummyPoolExecutor
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ def new_sdr(references, estimates):
31
+ """
32
+ Compute the SDR according to the MDX challenge definition.
33
+ Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
34
+ """
35
+ assert references.dim() == 4
36
+ assert estimates.dim() == 4
37
+ delta = 1e-7 # avoid numerical errors
38
+ num = th.sum(th.square(references), dim=(2, 3))
39
+ den = th.sum(th.square(references - estimates), dim=(2, 3))
40
+ num += delta
41
+ den += delta
42
+ scores = 10 * th.log10(num / den)
43
+ return scores
44
+
45
+
46
+ def eval_track(references, estimates, win, hop, compute_sdr=True):
47
+ references = references.transpose(1, 2).double()
48
+ estimates = estimates.transpose(1, 2).double()
49
+
50
+ new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]
51
+
52
+ if not compute_sdr:
53
+ return None, new_scores
54
+ else:
55
+ references = references.numpy()
56
+ estimates = estimates.numpy()
57
+ scores = museval.metrics.bss_eval(
58
+ references, estimates,
59
+ compute_permutation=False,
60
+ window=win,
61
+ hop=hop,
62
+ framewise_filters=False,
63
+ bsseval_sources_version=False)[:-1]
64
+ return scores, new_scores
65
+
66
+
67
+ def evaluate(solver, compute_sdr=False):
68
+ """
69
+ Evaluate model using museval.
70
+ compute_sdr=False means using only the MDX definition of the SDR, which
71
+ is much faster to evaluate.
72
+ """
73
+
74
+ args = solver.args
75
+
76
+ output_dir = solver.folder / "results"
77
+ output_dir.mkdir(exist_ok=True, parents=True)
78
+ json_folder = solver.folder / "results/test"
79
+ json_folder.mkdir(exist_ok=True, parents=True)
80
+
81
+ # we load tracks from the original musdb set
82
+ if args.test.nonhq is None:
83
+ test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
84
+ else:
85
+ test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
86
+ src_rate = args.dset.musdb_samplerate
87
+
88
+ eval_device = 'cpu'
89
+
90
+ model = solver.model
91
+ win = int(1. * model.samplerate)
92
+ hop = int(1. * model.samplerate)
93
+
94
+ indexes = range(distrib.rank, len(test_set), distrib.world_size)
95
+ indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
96
+ name='Eval')
97
+ pendings = []
98
+
99
+ pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
100
+ with pool(args.test.workers) as pool:
101
+ for index in indexes:
102
+ track = test_set.tracks[index]
103
+
104
+ mix = th.from_numpy(track.audio).t().float()
105
+ if mix.dim() == 1:
106
+ mix = mix[None]
107
+ mix = mix.to(solver.device)
108
+ ref = mix.mean(dim=0) # mono mixture
109
+ mix = (mix - ref.mean()) / ref.std()
110
+ mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
111
+ estimates = apply_model(model, mix[None],
112
+ shifts=args.test.shifts, split=args.test.split,
113
+ overlap=args.test.overlap)[0]
114
+ estimates = estimates * ref.std() + ref.mean()
115
+ estimates = estimates.to(eval_device)
116
+
117
+ references = th.stack(
118
+ [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
119
+ if references.dim() == 2:
120
+ references = references[:, None]
121
+ references = references.to(eval_device)
122
+ references = convert_audio(references, src_rate,
123
+ model.samplerate, model.audio_channels)
124
+ if args.test.save:
125
+ folder = solver.folder / "wav" / track.name
126
+ folder.mkdir(exist_ok=True, parents=True)
127
+ for name, estimate in zip(model.sources, estimates):
128
+ save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)
129
+
130
+ pendings.append((track.name, pool.submit(
131
+ eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))
132
+
133
+ pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
134
+ name='Eval (BSS)')
135
+ tracks = {}
136
+ for track_name, pending in pendings:
137
+ pending = pending.result()
138
+ scores, nsdrs = pending
139
+ tracks[track_name] = {}
140
+ for idx, target in enumerate(model.sources):
141
+ tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
142
+ if scores is not None:
143
+ (sdr, isr, sir, sar) = scores
144
+ for idx, target in enumerate(model.sources):
145
+ values = {
146
+ "SDR": sdr[idx].tolist(),
147
+ "SIR": sir[idx].tolist(),
148
+ "ISR": isr[idx].tolist(),
149
+ "SAR": sar[idx].tolist()
150
+ }
151
+ tracks[track_name][target].update(values)
152
+
153
+ all_tracks = {}
154
+ for src in range(distrib.world_size):
155
+ all_tracks.update(distrib.share(tracks, src))
156
+
157
+ result = {}
158
+ metric_names = next(iter(all_tracks.values()))[model.sources[0]]
159
+ for metric_name in metric_names:
160
+ avg = 0
161
+ avg_of_medians = 0
162
+ for source in model.sources:
163
+ medians = [
164
+ np.nanmedian(all_tracks[track][source][metric_name])
165
+ for track in all_tracks.keys()]
166
+ mean = np.mean(medians)
167
+ median = np.median(medians)
168
+ result[metric_name.lower() + "_" + source] = mean
169
+ result[metric_name.lower() + "_med" + "_" + source] = median
170
+ avg += mean / len(model.sources)
171
+ avg_of_medians += median / len(model.sources)
172
+ result[metric_name.lower()] = avg
173
+ result[metric_name.lower() + "_med"] = avg_of_medians
174
+ return result
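A quick, synthetic sanity check of the `new_sdr` convention (shapes are `(batch, sources, channels, time)`; the numbers are random and only illustrate the expected shapes and sign):

    import torch as th
    from demucs.evaluate import new_sdr

    references = th.randn(1, 4, 2, 44100)                      # (batch, sources, channels, time)
    estimates = references + 0.01 * th.randn_like(references)  # a nearly perfect separation
    scores = new_sdr(references, estimates)                     # dB, higher is better
    print(scores.shape)                                         # torch.Size([1, 4]), one score per source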
demucs/grids/__init__.py ADDED
File without changes
demucs/grids/_explorers.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from dora import Explorer
7
+ import treetable as tt
8
+
9
+
10
+ class MyExplorer(Explorer):
11
+ test_metrics = ['nsdr', 'sdr_med']
12
+
13
+ def get_grid_metrics(self):
14
+ """Return the metrics that should be displayed in the tracking table.
15
+ """
16
+ return [
17
+ tt.group("train", [
18
+ tt.leaf("epoch"),
19
+ tt.leaf("reco", ".3f"),
20
+ ], align=">"),
21
+ tt.group("valid", [
22
+ tt.leaf("penalty", ".1f"),
23
+ tt.leaf("ms", ".1f"),
24
+ tt.leaf("reco", ".2%"),
25
+ tt.leaf("breco", ".2%"),
26
+ tt.leaf("b_nsdr", ".2f"),
27
+ # tt.leaf("b_nsdr_drums", ".2f"),
28
+ # tt.leaf("b_nsdr_bass", ".2f"),
29
+ # tt.leaf("b_nsdr_other", ".2f"),
30
+ # tt.leaf("b_nsdr_vocals", ".2f"),
31
+ ], align=">"),
32
+ tt.group("test", [
33
+ tt.leaf(name, ".2f")
34
+ for name in self.test_metrics
35
+ ], align=">")
36
+ ]
37
+
38
+ def process_history(self, history):
39
+ train = {
40
+ 'epoch': len(history),
41
+ }
42
+ valid = {}
43
+ test = {}
44
+ best_v_main = float('inf')
45
+ breco = float('inf')
46
+ for metrics in history:
47
+ train.update(metrics['train'])
48
+ valid.update(metrics['valid'])
49
+ if 'main' in metrics['valid']:
50
+ best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
51
+ valid['bmain'] = best_v_main
52
+ valid['breco'] = min(breco, metrics['valid']['reco'])
53
+ breco = valid['breco']
54
+ if (metrics['valid']['loss'] == metrics['valid']['best'] or
55
+ metrics['valid'].get('nsdr') == metrics['valid']['best']):
56
+ for k, v in metrics['valid'].items():
57
+ if k.startswith('reco_'):
58
+ valid['b_' + k[len('reco_'):]] = v
59
+ if k.startswith('nsdr'):
60
+ valid[f'b_{k}'] = v
61
+ if 'test' in metrics:
62
+ test.update(metrics['test'])
63
+ metrics = history[-1]
64
+ return {"train": train, "valid": valid, "test": test}
demucs/grids/mdx.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Main training for the Track A MDX models.
8
+ """
9
+
10
+ from ._explorers import MyExplorer
11
+ from ..train import main
12
+
13
+
14
+ TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']
15
+
16
+
17
+ @MyExplorer
18
+ def explorer(launcher):
19
+ launcher.slurm_(
20
+ gpus=8,
21
+ time=3 * 24 * 60,
22
+ partition='learnlab')
23
+
24
+ # Reproduce results from MDX competition Track A
25
+ # This trains the first round of models. Once this is trained,
26
+ # you will need to schedule `mdx_refine`.
27
+ for sig in TRACK_A:
28
+ xp = main.get_xp_from_sig(sig)
29
+ parent = xp.cfg.continue_from
30
+ xp = main.get_xp_from_sig(parent)
31
+ launcher(xp.argv)
32
+ launcher(xp.argv, {'quant.diffq': 1e-4})
33
+ launcher(xp.argv, {'quant.diffq': 3e-4})
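For context, the entries of `TRACK_A` are Dora experiment signatures. A hedged sketch of how a signature resolves back to an experiment, mirroring what the grid does before scheduling (this assumes the corresponding experiments exist locally):

    from demucs.train import main

    sig = '0d19c1c6'                       # one of the TRACK_A signatures above
    xp = main.get_xp_from_sig(sig)         # resolve the Dora experiment from its signature
    print(xp.cfg.continue_from)            # each Track A model is fine-tuned from a parent experiment
    print(xp.argv)                         # the overrides that fully define the experiment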
demucs/grids/mdx_extra.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Main training for the Track B MDX models.
8
+ """
9
+
10
+ from ._explorers import MyExplorer
11
+ from ..train import main
12
+
13
+ TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']
14
+
15
+
16
+ @MyExplorer
17
+ def explorer(launcher):
18
+ launcher.slurm_(
19
+ gpus=8,
20
+ time=3 * 24 * 60,
21
+ partition='learnlab')
22
+
23
+ # Reproduce results from MDX competition Track B.
24
+ # Each signature is walked back through `continue_from` to its root
25
+ # experiment, which is then retrained on the extra datasets below.
26
+ for sig in TRACK_B:
27
+ while sig is not None:
28
+ xp = main.get_xp_from_sig(sig)
29
+ sig = xp.cfg.continue_from
30
+
31
+ for dset in ['extra44', 'extra_test']:
32
+ sub = launcher.bind(xp.argv, dset=dset)
33
+ sub()
34
+ if dset == 'extra_test':
35
+ sub({'quant.diffq': 1e-4})
36
+ sub({'quant.diffq': 3e-4})
demucs/grids/mdx_refine.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Fine-tuning for the Track A MDX models.
8
+ """
9
+
10
+ from ._explorers import MyExplorer
11
+ from .mdx import TRACK_A
12
+ from ..train import main
13
+
14
+
15
+ @MyExplorer
16
+ def explorer(launcher):
17
+ launcher.slurm_(
18
+ gpus=8,
19
+ time=3 * 24 * 60,
20
+ partition='learnlab')
21
+
22
+ # Reproduce results from MDX competition Track A
23
+ # WARNING: all the experiments in the `mdx` grid must have completed.
24
+ for sig in TRACK_A:
25
+ xp = main.get_xp_from_sig(sig)
26
+ launcher(xp.argv)
27
+ for diffq in [1e-4, 3e-4]:
28
+ xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
29
+ q_argv = [f'quant.diffq={diffq}']
30
+ actual_src = main.get_xp(xp_src.argv + q_argv)
31
+ actual_src.link.load()
32
+ assert len(actual_src.link.history) == actual_src.cfg.epochs
33
+ argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
34
+ launcher(argv)