Upload 107 files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
- .dockerignore +27 -0
- .gitattributes +2 -0
- .github/ISSUE_TEMPLATE/bug.md +33 -0
- .github/ISSUE_TEMPLATE/question.md +10 -0
- .github/workflows/linter.yml +36 -0
- .github/workflows/tests.yml +36 -0
- .gitignore +17 -0
- .vscode/launch.json +19 -0
- .vscode/tasks.json +33 -0
- CODE_OF_CONDUCT.md +76 -0
- CONTRIBUTING.md +23 -0
- Demucs.ipynb +153 -0
- Dockerfile +23 -0
- LICENSE +21 -0
- MANIFEST.in +13 -0
- Makefile +36 -0
- README.md +319 -14
- app.py +37 -37
- conf/config.yaml +304 -0
- conf/dset/aetl.yaml +19 -0
- conf/dset/auto_extra_test.yaml +18 -0
- conf/dset/auto_mus.yaml +20 -0
- conf/dset/extra44.yaml +8 -0
- conf/dset/extra_mmi_goodclean.yaml +12 -0
- conf/dset/extra_test.yaml +12 -0
- conf/dset/musdb44.yaml +5 -0
- conf/dset/sdx23_bleeding.yaml +10 -0
- conf/dset/sdx23_labelnoise.yaml +10 -0
- conf/svd/base.yaml +14 -0
- conf/svd/base2.yaml +14 -0
- conf/svd/default.yaml +1 -0
- conf/variant/default.yaml +1 -0
- conf/variant/example.yaml +5 -0
- conf/variant/finetune.yaml +19 -0
- demucs.png +3 -0
- demucs/__init__.py +7 -0
- demucs/__main__.py +10 -0
- demucs/api.py +392 -0
- demucs/apply.py +322 -0
- demucs/audio.py +265 -0
- demucs/augment.py +111 -0
- demucs/demucs.py +447 -0
- demucs/distrib.py +100 -0
- demucs/ema.py +66 -0
- demucs/evaluate.py +174 -0
- demucs/grids/__init__.py +0 -0
- demucs/grids/_explorers.py +64 -0
- demucs/grids/mdx.py +33 -0
- demucs/grids/mdx_extra.py +36 -0
- demucs/grids/mdx_refine.py +34 -0
.dockerignore
ADDED
@@ -0,0 +1,27 @@
+**/__pycache__
+**/.venv
+**/.classpath
+**/.dockerignore
+**/.env
+**/.git
+**/.gitignore
+**/.project
+**/.settings
+**/.toolstarget
+**/.vs
+**/.vscode
+**/*.*proj.user
+**/*.dbmdl
+**/*.jfm
+**/bin
+**/charts
+**/docker-compose*
+**/compose*
+**/Dockerfile*
+**/node_modules
+**/npm-debug.log
+**/obj
+**/secrets.dev.yaml
+**/values.dev.yaml
+LICENSE
+README.md
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demucs.png filter=lfs diff=lfs merge=lfs -text
+test.mp3 filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug.md
ADDED
@@ -0,0 +1,33 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug'
+---
+
+## 🐛 Bug Report
+
+(A clear and concise description of what the bug is)
+
+## To Reproduce
+
+(Write your steps here:)
+
+1. Step 1...
+1. Step 2...
+1. Step 3...
+
+## Expected behavior
+
+(Write what you thought would happen.)
+
+## Actual Behavior
+
+(Write what happened. Add screenshots, if applicable.)
+
+## Your Environment
+
+<!-- Include as many relevant details about the environment you experienced the bug in -->
+
+- Python and PyTorch version:
+- Operating system and version (desktop or mobile):
+- Hardware (gpu or cpu, amount of RAM etc.):
.github/ISSUE_TEMPLATE/question.md
ADDED
@@ -0,0 +1,10 @@
+---
+name: "❓Questions/Help/Support"
+about: If you have a question about the paper, code or algorithm, please ask here!
+labels: question
+
+---
+
+## ❓ Questions
+
+(Please ask your question here.)
.github/workflows/linter.yml
ADDED
@@ -0,0 +1,36 @@
+name: linter
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+    steps:
+    - uses: actions/checkout@v2
+    - uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+
+    - uses: actions/cache@v2
+      with:
+        path: env
+        key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+    - name: Install dependencies
+      run: |
+        python3 -m venv env
+        . env/bin/activate
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+        pip install '.[dev]'
+
+
+    - name: Run linter
+      run: |
+        . env/bin/activate
+        make linter
.github/workflows/tests.yml
ADDED
@@ -0,0 +1,36 @@
+name: tests
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
+    steps:
+    - uses: actions/checkout@v2
+    - uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+
+    - uses: actions/cache@v2
+      with:
+        path: env
+        key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}
+
+    - name: Install dependencies
+      run: |
+        sudo apt-get update
+        sudo apt-get install -y ffmpeg
+        python3 -m venv env
+        . env/bin/activate
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+
+    - name: Run separation test
+      run: |
+        . env/bin/activate
+        make test_eval
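For reference, the same check can be reproduced locally with the steps the workflow uses; a minimal sketch, assuming a Debian-like system with `make` and Python 3.8+ available:

```bash
# Mirror the 'tests' workflow: install ffmpeg, set up a venv, then run the separation test
sudo apt-get update && sudo apt-get install -y ffmpeg
python3 -m venv env && . env/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
make test_eval
```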
.gitignore
ADDED
@@ -0,0 +1,17 @@
+*.egg-info
+__pycache__
+Session.vim
+/build
+/dist
+/lab
+/metadata
+/notebooks
+/outputs
+/release
+/release_models
+/separated
+/tests
+/trash
+/misc
+/mdx
+.mypy_cache
.vscode/launch.json
ADDED
@@ -0,0 +1,19 @@
+{
+  "configurations": [
+    {
+      "name": "Containers: Python - Fastapi",
+      "type": "docker",
+      "request": "launch",
+      "preLaunchTask": "docker-run: debug",
+      "python": {
+        "pathMappings": [
+          {
+            "localRoot": "${workspaceFolder}",
+            "remoteRoot": "/app"
+          }
+        ],
+        "projectType": "fastapi"
+      }
+    }
+  ]
+}
.vscode/tasks.json
ADDED
@@ -0,0 +1,33 @@
+{
+  "version": "2.0.0",
+  "tasks": [
+    {
+      "type": "docker-build",
+      "label": "docker-build",
+      "platform": "python",
+      "dockerBuild": {
+        "tag": "demucs:latest",
+        "dockerfile": "${workspaceFolder}/Dockerfile",
+        "context": "${workspaceFolder}",
+        "pull": true
+      }
+    },
+    {
+      "type": "docker-run",
+      "label": "docker-run: debug",
+      "dependsOn": [
+        "docker-build"
+      ],
+      "python": {
+        "args": [
+          "predict:app",
+          "--host",
+          "0.0.0.0",
+          "--port",
+          "8000"
+        ],
+        "module": "uvicorn"
+      }
+    }
+  ]
+}
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,76 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <[email protected]>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,23 @@
+# Contributing to Demucs
+
+## Pull Requests
+
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+Demucs is the implementation of a research paper.
+Therefore, we do not plan on accepting many pull requests for new features.
+We certainly welcome them for bug fixes.
+
+
+## Issues
+
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+
+## License
+By contributing to this repository, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
Demucs.ipynb
ADDED
@@ -0,0 +1,153 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "Be9yoh-ILfRr"
+   },
+   "source": [
+    "# Hybrid Demucs\n",
+    "\n",
+    "Feel free to use the Colab version:\n",
+    "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 139
+    },
+    "colab_type": "code",
+    "executionInfo": {
+     "elapsed": 12277,
+     "status": "ok",
+     "timestamp": 1583778134659,
+     "user": {
+      "displayName": "Marllus Lustosa",
+      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64",
+      "userId": "14811735256675200480"
+     },
+     "user_tz": 180
+    },
+    "id": "kOjIPLlzhPfn",
+    "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install -U demucs\n",
+    "# or for local development, if you have a clone of Demucs\n",
+    "# pip install -e ."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "5lYOzKKCKAbJ"
+   },
+   "outputs": [],
+   "source": [
+    "# You can use the `demucs` command line to separate tracks\n",
+    "!demucs test.mp3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# You can also directly load the pretrained models,\n",
+    "# for instance the MDX 2021 winning model of Track A:\n",
+    "from demucs import pretrained\n",
+    "model = pretrained.get_model('mdx')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Because `model` is a bag of 4 models, you cannot directly call it on your data,\n",
+    "# but `apply_model` will know what to do with it.\n",
+    "import torch\n",
+    "from demucs.apply import apply_model\n",
+    "x = torch.randn(1, 2, 44100 * 10)  # ten seconds of white noise for the demo\n",
+    "out = apply_model(model, x)[0]  # shape is [S, C, T] with S the number of sources\n",
+    "\n",
+    "# So let's see where all the white noise content is going.\n",
+    "for name, source in zip(model.sources, out):\n",
+    "    print(name, source.std() / x.std())\n",
+    "# The outputs are quite weird to be fair, not what I would have expected."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Now let's take a single model from the bag, and test it on a pure cosine.\n",
+    "freq = 440  # in Hz\n",
+    "sr = model.samplerate\n",
+    "t = torch.arange(10 * sr).float() / sr\n",
+    "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n",
+    "sub_model = model.models[3]\n",
+    "out = sub_model(x)[0]\n",
+    "\n",
+    "# Same question: where does it go?\n",
+    "for name, source in zip(model.sources, out):\n",
+    "    print(name, source.std() / x.std())\n",
+    "\n",
+    "# Well now it makes much more sense, all the energy is going\n",
+    "# in the `other` source.\n",
+    "# Feel free to try a lower pitch (try 80 Hz) to see what happens!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# For training or more fun, refer to the Demucs README on our repo\n",
+    "# https://github.com/facebookresearch/demucs/tree/main/demucs"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx",
+   "collapsed_sections": [],
+   "name": "Demucs.ipynb",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
Dockerfile
ADDED
@@ -0,0 +1,23 @@
+# Use Python 3.9 slim base
+FROM python:3.9-slim
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y ffmpeg git && apt-get clean
+
+# Set work directory
+WORKDIR /app
+
+# Install Python packages
+RUN pip install --upgrade pip
+RUN pip install torch torchaudio
+RUN pip install fastapi uvicorn
+RUN pip install git+https://github.com/facebookresearch/demucs
+
+# Copy your inference script into the container
+COPY predict.py .
+
+# Expose port for FastAPI
+EXPOSE 8000
+
+# Run the FastAPI app
+CMD ["uvicorn", "predict:app", "--host", "0.0.0.0", "--port", "8000"]
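A quick sketch of how this image would be built and served; note that `predict.py` (the FastAPI entry point the Dockerfile copies in) is not among the files shown in this view, so this assumes it defines an `app` object:

```bash
# Build the image from the repository root (where the Dockerfile lives)
docker build -t demucs:latest .

# Run the container; uvicorn serves predict:app on the exposed port 8000
docker run --rm -p 8000:8000 demucs:latest
```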
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
MANIFEST.in
ADDED
@@ -0,0 +1,13 @@
+recursive-exclude env *
+recursive-include conf *.yaml
+include Makefile
+include LICENSE
+include demucs.png
+include outputs.tar.gz
+include test.mp3
+include requirements.txt
+include requirements_minimal.txt
+include mypy.ini
+include demucs/py.typed
+include demucs/remote/*.txt
+include demucs/remote/*.yaml
Makefile
ADDED
@@ -0,0 +1,36 @@
+all: linter tests
+
+linter:
+	flake8 demucs
+	mypy demucs
+
+tests: test_train test_eval
+
+test_train: tests/musdb
+	_DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \
+		dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \
+		demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \
+		test.shifts=0
+
+test_eval:
+	python3 -m demucs -n demucs_unittest test.mp3
+	python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3
+	python3 -m demucs -n demucs_unittest --mp3 test.mp3
+	python3 -m demucs -n demucs_unittest --flac --int24 test.mp3
+	python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3
+	python3 -m demucs -n demucs_unittest --segment 8 test.mp3
+	python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3
+	python3 -m demucs --list-models
+
+tests/musdb:
+	test -e tests || mkdir tests
+	python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)'
+	musdbconvert tests/tmp tests/musdb
+
+dist:
+	python3 setup.py sdist
+
+clean:
+	rm -r dist build *.egg-info
+
+.PHONY: linter dist test_train test_eval
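The `test_train` target above shows the general shape of a training invocation: Hydra-style `key=value` overrides passed to `dora run`. A hedged sketch of a custom run, using option names taken from `conf/config.yaml` further below:

```bash
# Train the time-domain Demucs model with a smaller batch and a custom learning rate
python3 -m dora run model=demucs batch_size=32 optim.lr=5e-4 dset.segment=11 epochs=100
```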
README.md
CHANGED
@@ -1,14 +1,319 @@
- (14 lines removed; their content is not shown in this view)
+# Demucs Music Source Separation
+
+[](https://opensource.fb.com/support-ukraine)
+
+
+
+
+**Important:** As I am no longer working at Meta, **this repository is not maintained anymore**.
+I've created a fork at [github.com/adefossez/demucs](https://github.com/adefossez/demucs). Note that this project is not actively maintained anymore,
+and only important bug fixes will be processed on the new repo. Please do not open issues for feature requests or if Demucs doesn't work perfectly for your use case :)
+
+This is the 4th release of Demucs (v4), featuring Hybrid Transformer based source separation.
+**For the classic Hybrid Demucs (v3):** [Go to this commit][demucs_v3].
+If you are experiencing issues and want the old Demucs back, please file an issue, and then you can get back to Demucs v3 with
+`git checkout v3`. You can also go back to [Demucs v2][demucs_v2].
+
+
+Demucs is a state-of-the-art music source separation model, currently capable of separating
+drums, bass, and vocals from the rest of the accompaniment.
+Demucs is based on a U-Net convolutional architecture inspired by [Wave-U-Net][waveunet].
+The v4 version features [Hybrid Transformer Demucs][htdemucs], a hybrid spectrogram/waveform separation model using Transformers.
+It is based on [Hybrid Demucs][hybrid_paper] (also provided in this repo), with the innermost layers
+replaced by a cross-domain Transformer Encoder. This Transformer uses self-attention within each domain,
+and cross-attention across domains.
+The model achieves an SDR of 9.00 dB on the MUSDB HQ test set. Moreover, when using sparse attention
+kernels to extend its receptive field and per-source fine-tuning, we achieve a state-of-the-art 9.20 dB of SDR.
+
+Samples are available [on our sample page](https://ai.honu.io/papers/htdemucs/index.html).
+Check out [our paper][htdemucs] for more information.
+It has been trained on the [MUSDB HQ][musdb] dataset + an extra training dataset of 800 songs.
+This model separates drums, bass, vocals, and other stems for any song.
+
+
+As Hybrid Transformer Demucs is brand new, it is not activated by default; you can activate it in the usual
+commands described hereafter with `-n htdemucs_ft`.
+The single, non fine-tuned model is provided as `-n htdemucs`, and the retrained baseline
+as `-n hdemucs_mmi`. The Sparse Hybrid Transformer model described in our paper is not provided, as it
+requires custom CUDA code that is not ready for release yet.
+We are also releasing an experimental 6-source model that adds `guitar` and `piano` sources.
+Quick testing seems to show okay quality for `guitar`, but a lot of bleeding and artifacts for the `piano` source.
+
+
+<p align="center">
+<img src="./demucs.png" alt="Schema representing the structure of Hybrid Transformer Demucs,
+with a dual U-Net structure, one branch for the temporal domain,
+and one branch for the spectral domain. There is a cross-domain Transformer between the Encoders and Decoders."
+width="800px"></p>
+
+
+
+## Important news if you are already using Demucs
+
+See the [release notes](./docs/release.md) for more details.
+
+- 22/02/2023: added support for the [SDX 2023 Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023),
+  see the dedicated [doc page](./docs/sdx23.md)
+- 07/12/2022: Demucs v4 now on PyPI. **htdemucs** model now used by default. Also releasing
+  a 6-source model (adding `guitar` and `piano`, although the latter doesn't work so well at the moment).
+- 16/11/2022: Added the new **Hybrid Transformer Demucs v4** models.
+  Adding support for the [torchaudio implementation of HDemucs](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html).
+- 30/08/2022: added reproducibility and ablation grids, along with an updated version of the paper.
+- 17/08/2022: Releasing v3.0.5: Set split segment length to reduce memory. Compatible with PyTorch 1.12.
+- 24/02/2022: Releasing v3.0.4: split into two stems (i.e. karaoke mode).
+  Export as float32 or int24.
+- 17/12/2021: Releasing v3.0.3: bug fixes (thanks @keunwoochoi), memory drastically
+  reduced on GPU (thanks @famzah) and new multi-core evaluation on CPU (`-j` flag).
+- 12/11/2021: Releasing **Demucs v3** with hybrid domain separation. Strong improvements
+  on all sources. This is the model that won the Sony MDX challenge.
+- 11/05/2021: Adding support for MusDB-HQ and arbitrary wav sets, for the MDX challenge. For more information
+  on joining the challenge with Demucs see [the Demucs MDX instructions](docs/mdx.md)
+
+
+## Comparison with other models
+
+We provide hereafter a summary of the different metrics presented in the paper.
+You can also compare Hybrid Demucs (v3), [KUIELAB-MDX-Net][kuielab], [Spleeter][spleeter], Open-Unmix, Demucs (v1), and Conv-Tasnet on one of my favorite
+songs on my [soundcloud playlist][soundcloud].
+
+### Comparison of accuracy
+
+`Overall SDR` is the mean of the SDR for each of the 4 sources, `MOS Quality` is a rating from 1 to 5
+of the naturalness and absence of artifacts given by human listeners (5 = no artifacts), `MOS Contamination`
+is a rating from 1 to 5 with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper]
+for more details.
+
+| Model                        | Domain      | Extra data?       | Overall SDR | MOS Quality | MOS Contamination |
+|------------------------------|-------------|-------------------|-------------|-------------|-------------------|
+| [Wave-U-Net][waveunet]       | waveform    | no                | 3.2         | -           | -                 |
+| [Open-Unmix][openunmix]      | spectrogram | no                | 5.3         | -           | -                 |
+| [D3Net][d3net]               | spectrogram | no                | 6.0         | -           | -                 |
+| [Conv-Tasnet][demucs_v2]     | waveform    | no                | 5.7         | -           | -                 |
+| [Demucs (v2)][demucs_v2]     | waveform    | no                | 6.3         | 2.37        | 2.36              |
+| [ResUNetDecouple+][decouple] | spectrogram | no                | 6.7         | -           | -                 |
+| [KUIELAB-MDX-Net][kuielab]   | hybrid      | no                | 7.5         | **2.86**    | 2.55              |
+| [Band-Split RNN][bandsplit]  | spectrogram | no                | **8.2**     | -           | -                 |
+| **Hybrid Demucs (v3)**       | hybrid      | no                | 7.7         | **2.83**    | **3.04**          |
+| [MMDenseLSTM][mmdenselstm]   | spectrogram | 804 songs         | 6.0         | -           | -                 |
+| [D3Net][d3net]               | spectrogram | 1.5k songs        | 6.7         | -           | -                 |
+| [Spleeter][spleeter]         | spectrogram | 25k songs         | 5.9         | -           | -                 |
+| [Band-Split RNN][bandsplit]  | spectrogram | 1.7k (mixes only) | **9.0**     | -           | -                 |
+| **HT Demucs f.t. (v4)**      | hybrid      | 800 songs         | **9.0**     | -           | -                 |
+
+
+
+## Requirements
+
+You will need at least Python 3.8. See `requirements_minimal.txt` for the requirements for separation only,
+and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model.
+
+### For Windows users
+
+Every time you see `python3`, replace it with `python.exe`. You should always run commands from the
+Anaconda console.
+
+### For musicians
+
+If you just want to use Demucs to separate tracks, you can install it with
+
+```bash
+python3 -m pip install -U demucs
+```
+
+For bleeding-edge versions, you can install directly from this repo using
+```bash
+python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs
+```
+
+Advanced OS support is provided on the following pages; **you must read the page for your OS before posting an issue**:
+- **If you are using Windows:** [Windows support](docs/windows.md).
+- **If you are using macOS:** [macOS support](docs/mac.md).
+- **If you are using Linux:** [Linux support](docs/linux.md).
+
+### For machine learning scientists
+
+If you have anaconda installed, you can run from the root of this repository:
+
+```bash
+conda env update -f environment-cpu.yml   # if you don't have GPUs
+conda env update -f environment-cuda.yml  # if you have GPUs
+conda activate demucs
+pip install -e .
+```
+
+This will create a `demucs` environment with all the dependencies installed.
+
+You will also need to install [soundstretch/soundtouch](https://www.surina.net/soundtouch/soundstretch.html): on macOS you can do `brew install sound-touch`,
+and on Ubuntu `sudo apt-get install soundstretch`. This is used for the
+pitch/tempo augmentation.
+
+
+### Running in Docker
+
+Thanks to @xserrat, there is now a Docker image definition ready for using Demucs. This can ensure all libraries are correctly installed without interfering with the host OS. See his repo [Docker Facebook Demucs](https://github.com/xserrat/docker-facebook-demucs) for more information.
+
+
+### Running from Colab
+
+I made a Colab to easily separate tracks with Demucs. Note that
+transfer speeds with Colab are a bit slow for large media files,
+but it will allow you to use Demucs without installing anything.
+
+[Demucs on Google Colab](https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing)
+
+### Web Demo
+
+Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See the demo: [](https://huggingface.co/spaces/akhaliq/demucs)
+
+### Graphical Interface
+
+@CarlGao4 has released a GUI for Demucs: [CarlGao4/Demucs-Gui](https://github.com/CarlGao4/Demucs-Gui). Downloads for Windows and macOS are available [here](https://github.com/CarlGao4/Demucs-Gui/releases). Use the [FossHub mirror](https://fosshub.com/Demucs-GUI.html) to speed up your download.
+
+@Anjok07 is providing a self-contained GUI in [UVR (Ultimate Vocal Remover)](https://github.com/facebookresearch/demucs/issues/334) that supports Demucs.
+
+### Other providers
+
+Audiostrip is providing free online separation with Demucs on their website [https://audiostrip.co.uk/](https://audiostrip.co.uk/).
+
+[MVSep](https://mvsep.com/) also provides free online separation; select `Demucs3 model B` for the best quality.
+
+[Neutone](https://neutone.space/) provides a realtime Demucs model in their free VST/AU plugin that can be used in your favorite DAW.
+
+
+## Separating tracks
+
+In order to try Demucs, you can just run from any folder (as long as you properly installed it)
+
+```bash
+demucs PATH_TO_AUDIO_FILE_1 [PATH_TO_AUDIO_FILE_2 ...]   # for Demucs
+# If you used `pip install --user` you might need to replace demucs with python3 -m demucs
+python3 -m demucs --mp3 --mp3-bitrate BITRATE PATH_TO_AUDIO_FILE_1  # output files saved as MP3
+# use --mp3-preset to change encoder preset, 2 for best quality, 7 for fastest
+# If your filename contains spaces, don't forget to quote it!
+demucs "my music/my favorite track.mp3"
+# You can select different models with `-n`. mdx_q is the quantized model, smaller but maybe a bit less accurate.
+demucs -n mdx_q myfile.mp3
+# If you only want to separate vocals out of an audio file, use `--two-stems=vocals` (you can also set it to drums or bass)
+demucs --two-stems=vocals myfile.mp3
+```
+
+
+If you have a GPU but you run out of memory, please use `--segment SEGMENT` to reduce the length of each split. `SEGMENT` should be an integer describing the length of each segment in seconds.
+A segment length of at least 10 is recommended (the bigger the number, the more memory is required, but quality may increase). Note that the Hybrid Transformer models only support a maximum segment length of 7.8 seconds.
+Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` is also helpful. If this still does not help, please add `-d cpu` to the command line. See the section hereafter for more details on the memory requirements for GPU acceleration.
+
+Separated tracks are stored in the `separated/MODEL_NAME/TRACK_NAME` folder. There you will find four stereo wav files sampled at 44.1 kHz: `drums.wav`, `bass.wav`,
+`other.wav`, `vocals.wav` (or `.mp3` if you used the `--mp3` option).
+
+All audio formats supported by `torchaudio` can be processed (i.e. wav, mp3, flac, ogg/vorbis on Linux/macOS, etc.). On Windows, `torchaudio` has limited support, so we rely on `ffmpeg`, which should support pretty much anything.
+Audio is resampled on the fly if necessary.
+The output will be a wav file encoded as int16.
+You can save as float32 wav files with `--float32`, or 24-bit integer wav with `--int24`.
+You can pass `--mp3` to save as mp3 instead, and set the bitrate (in kbps) with `--mp3-bitrate` (default is 320).
+
+It can happen that the output would need clipping, in particular due to some separation artifacts.
+Demucs will automatically rescale each output stem so as to avoid clipping. This can however break
+the relative volume between stems. If instead you prefer hard clipping, pass `--clip-mode clamp`.
+You can also try to reduce the volume of the input mixture before feeding it to Demucs.
+
+
+Other pre-trained models can be selected with the `-n` flag.
+The list of pre-trained models is:
+- `htdemucs`: first version of Hybrid Transformer Demucs. Trained on MusDB + 800 songs. Default model.
+- `htdemucs_ft`: fine-tuned version of `htdemucs`; separation will take 4 times longer
+  but might be a bit better. Same training set as `htdemucs`.
+- `htdemucs_6s`: 6-source version of `htdemucs`, with `piano` and `guitar` added as sources.
+  Note that the `piano` source is not working great at the moment.
+- `hdemucs_mmi`: Hybrid Demucs v3, retrained on MusDB + 800 songs.
+- `mdx`: trained only on MusDB HQ, winning model on track A at the [MDX][mdx] challenge.
+- `mdx_extra`: trained with extra training data (**including the MusDB test set**), ranked 2nd on track B
+  of the [MDX][mdx] challenge.
+- `mdx_q`, `mdx_extra_q`: quantized versions of the previous models. Smaller download and storage
+  but quality can be slightly worse.
+- `SIG`: where `SIG` is a single model from the [model zoo](docs/training.md#model-zoo).
+
+The `--two-stems=vocals` option allows separating vocals from the rest of the accompaniment (i.e., karaoke mode).
+`vocals` can be changed to any source in the selected model.
+This will mix the files after separating the mix fully, so this won't be faster or use less memory.
+
+The `--shifts=SHIFTS` option performs multiple predictions with random shifts (a.k.a. the *shift trick*) of the input and averages them. This makes prediction `SHIFTS` times
+slower. Don't use it unless you have a GPU.
+
+The `--overlap` option controls the amount of overlap between prediction windows. The default is 0.25 (i.e. 25%), which is probably fine.
+It can probably be reduced to 0.1 to speed things up a bit.
+
+
+The `-j` flag allows specifying a number of parallel jobs (e.g. `demucs -j 2 myfile.mp3`).
+This will multiply the RAM used by the same amount, so be careful!
+
+### Memory requirements for GPU acceleration
+
+If you want to use GPU acceleration, you will need at least 3GB of RAM on your GPU for `demucs`. However, about 7GB of RAM will be required if you use the default arguments. Add `--segment SEGMENT` to change the size of each split. If you only have 3GB memory, set SEGMENT to 8 (though quality may be worse if this argument is too small). Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` can help users with even smaller RAM such as 2GB (I separated a 4-minute track using only 1.5GB), but this would make the separation slower.
+
+If you do not have enough memory on your GPU, simply add `-d cpu` to the command line to use the CPU. With Demucs, processing time should be roughly equal to 1.5 times the duration of the track.
+
+## Calling from another Python program
+
+The main function provides an `opt` parameter as a simple API. You can just pass the parsed command line as this parameter:
+```python
+# Assume that your command is `demucs --mp3 --two-stems vocals -n mdx_extra "track with space.mp3"`
+# The following code is the same as the command above:
+import demucs.separate
+demucs.separate.main(["--mp3", "--two-stems", "vocals", "-n", "mdx_extra", "track with space.mp3"])
+
+# Or like this
+import demucs.separate
+import shlex
+demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"'))
+```
+
+To use more complicated APIs, see the [API docs](docs/api.md).
+
+## Training Demucs
+
+If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md).
+
+## MDX Challenge reproduction
+
+In order to reproduce the results from the Track A and Track B submissions, check out the [MDX Hybrid Demucs submission repo][mdx_submission].
+
+
+
+## How to cite
+
+```
+@inproceedings{rouard2022hybrid,
+  title={Hybrid Transformers for Music Source Separation},
+  author={Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
+  booktitle={ICASSP 23},
+  year={2023}
+}
+
+@inproceedings{defossez2021hybrid,
+  title={Hybrid Spectrogram and Waveform Source Separation},
+  author={D{\'e}fossez, Alexandre},
+  booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation},
+  year={2021}
+}
+```
+
+## License
+
+Demucs is released under the MIT license as found in the [LICENSE](LICENSE) file.
+
+[hybrid_paper]: https://arxiv.org/abs/2111.03600
+[waveunet]: https://github.com/f90/Wave-U-Net
+[musdb]: https://sigsep.github.io/datasets/musdb.html
+[openunmix]: https://github.com/sigsep/open-unmix-pytorch
+[mmdenselstm]: https://arxiv.org/abs/1805.02410
+[demucs_v2]: https://github.com/facebookresearch/demucs/tree/v2
+[demucs_v3]: https://github.com/facebookresearch/demucs/tree/v3
+[spleeter]: https://github.com/deezer/spleeter
+[soundcloud]: https://soundcloud.com/honualx/sets/source-separation-in-the-waveform-domain
+[d3net]: https://arxiv.org/abs/2010.01733
+[mdx]: https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021
+[kuielab]: https://github.com/kuielab/mdx-net-submission
+[decouple]: https://arxiv.org/abs/2109.05418
+[mdx_submission]: https://github.com/adefossez/mdx21_demucs
+[bandsplit]: https://arxiv.org/abs/2209.15174
+[htdemucs]: https://arxiv.org/abs/2211.08553
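Beyond `demucs.separate.main`, this commit also adds a `demucs/api.py` module (documented in `docs/api.md`, referenced above). A minimal sketch of what using it looks like, assuming the `Separator` class and `save_audio` helper keep the signatures from those docs:

```python
import demucs.api

# Load a pretrained model; `segment` bounds memory use, like the --segment CLI flag
separator = demucs.api.Separator(model="htdemucs", segment=7)

# Returns the original waveform and a dict mapping stem name -> tensor
origin, separated = separator.separate_audio_file("test.mp3")

for stem, source in separated.items():
    demucs.api.save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)
```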
app.py
CHANGED
@@ -1,37 +1,37 @@
+import os
+import shutil
+import gradio as gr
+from demucs.separate import main
+
+def separate_stems(audio_file):
+    input_path = "input.mp3"
+    shutil.copy(audio_file, input_path)
+
+    output_dir = "output"
+    if os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Run Demucs (pass --mp3 so the .mp3 stem paths built below actually exist)
+    main(["--mp3", "-n", "htdemucs", "-o", output_dir, input_path])
+
+    # Build list of stems to return
+    base = os.path.splitext(os.path.basename(input_path))[0]
+    stem_path = os.path.join(output_dir, "htdemucs", base)
+    stems = [os.path.join(stem_path, f"{stem}.mp3") for stem in ["vocals", "drums", "bass", "other"]]
+    return stems
+
+demo = gr.Interface(
+    fn=separate_stems,
+    inputs=gr.Audio(type="filepath", label="Upload Song"),
+    outputs=[
+        gr.Audio(label="Vocals"),
+        gr.Audio(label="Drums"),
+        gr.Audio(label="Bass"),
+        gr.Audio(label="Other"),
+    ],
+    title="Demucs v4 Stem Separator",
+    description="Upload a song to separate vocals, drums, bass, and other using Facebook's Demucs model.",
+)
+
+demo.launch()
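Running the demo above only takes the two dependencies it imports; a sketch, assuming a recent `gradio`:

```bash
pip install demucs gradio
python app.py  # Gradio serves the interface, by default on http://127.0.0.1:7860
```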
conf/config.yaml
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
defaults:
  - _self_
  - dset: musdb44
  - svd: default
  - variant: default
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog

dummy:
dset:
  musdb: /checkpoint/defossez/datasets/musdbhq
  musdb_samplerate: 44100
  use_musdb: true  # set to false to not use musdb as training data.
  wav:  # path to custom wav dataset
  wav2:  # second custom wav dataset
  segment: 11
  shift: 1
  train_valid: false
  full_cv: true
  samplerate: 44100
  channels: 2
  normalize: true
  metadata: ./metadata
  sources: ['drums', 'bass', 'other', 'vocals']
  valid_samples:  # valid dataset size
  backend: null  # if provided select torchaudio backend.

test:
  save: False
  best: True
  workers: 2
  every: 20
  split: true
  shifts: 1
  overlap: 0.25
  sdr: true
  metric: 'loss'  # metric used for best model selection on the valid set, can also be nsdr
  nonhq:  # path to non hq MusDB for evaluation

epochs: 360
batch_size: 64
max_batches:  # limit the number of batches per epoch, useful for debugging
              # or if your dataset is gigantic.
optim:
  lr: 3e-4
  momentum: 0.9
  beta2: 0.999
  loss: l1  # l1 or mse
  optim: adam
  weight_decay: 0
  clip_grad: 0

seed: 42
debug: false
valid_apply: true
flag:
save_every:
weights: [1., 1., 1., 1.]  # weights over each source for the training/valid loss.

augment:
  shift_same: false
  repitch:
    proba: 0.2
    max_tempo: 12
  remix:
    proba: 1
    group_size: 4
  scale:
    proba: 1
    min: 0.25
    max: 1.25
  flip: true

continue_from:  # continue from other XP, give the XP Dora signature.
continue_pretrained:  # signature of a pretrained XP, this cannot be a bag of models.
pretrained_repo:  # repo for pretrained model (default is official AWS)
continue_best: true
continue_opt: false

misc:
  num_workers: 10
  num_prints: 4
  show: false
  verbose: false

# List of decay for EMA at batch or epoch level, e.g. 0.999.
# Batch level EMA are kept on GPU for speed.
ema:
  epoch: []
  batch: []

use_train_segment: true  # to remove
model_segment:  # override the segment parameter for the model, usually 4 times the training segment.
model: demucs  # see demucs/train.py for the possibilities, and config for each model hereafter.
demucs:  # see demucs/demucs.py for a detailed description
  # Channels
  channels: 64
  growth: 2
  # Main structure
  depth: 6
  rewrite: true
  lstm_layers: 0
  # Convolutions
  kernel_size: 8
  stride: 4
  context: 1
  # Activations
  gelu: true
  glu: true
  # Normalization
  norm_groups: 4
  norm_starts: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_mode: 1  # 1 = branch in encoder, 2 = in decoder, 3 = in both.
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-4
  # Pre/post treatment
  resample: true
  normalize: false
  # Weight init
  rescale: 0.1

hdemucs:  # see demucs/hdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time:
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 6
  rewrite: true
  hybrid: true
  hybrid_old: false
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-3
  # Weight init
  rescale: 0.1

# Torchaudio implementation of HDemucs
torch_hdemucs:
  # Channels
  channels: 48
  growth: 2
  # STFT
  nfft: 4096
  # Main structure
  depth: 6
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-3

htdemucs:  # see demucs/htdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time:
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 4
  rewrite: true
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 8
  dconv_init: 1e-3
  # Before the Transformer
  bottom_channels: 0
  # CrossTransformer
  # ------ Common to all
  # Regular parameters
  t_layers: 5
  t_hidden_scale: 4.0
  t_heads: 8
  t_dropout: 0.0
  t_layer_scale: True
  t_gelu: True
  # ------------- Positional Embedding
  t_emb: sin
  t_max_positions: 10000  # for the scaled embedding
  t_max_period: 10000.0
  t_weight_pos_embed: 1.0
  t_cape_mean_normalize: True
  t_cape_augment: True
  t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
  t_sin_random_shift: 0
  # ------------- norm before a transformer encoder
  t_norm_in: True
  t_norm_in_group: False
  # ------------- norm inside the encoder
  t_group_norm: False
  t_norm_first: True
  t_norm_out: True
  # ------------- optim
  t_weight_decay: 0.0
  t_lr:
  # ------------- sparsity
  t_sparse_self_attn: False
  t_sparse_cross_attn: False
  t_mask_type: diag
  t_mask_random_seed: 42
  t_sparse_attn_window: 400
  t_global_window: 100
  t_sparsity: 0.95
  t_auto_sparsity: False
  # Cross Encoder First (False)
  t_cross_first: False
  # Weight init
  rescale: 0.1

svd:  # see svd.py for documentation
  penalty: 0
  min_size: 0.1
  dim: 1
  niters: 2
  powm: false
  proba: 1
  conv_only: false
  convtr: false
  bs: 1

quant:  # quantization hyper params
  diffq:  # diffq penalty, typically 1e-4 or 3e-4
  qat:  # use QAT with a fixed number of bits (not as good as diffq)
  min_size: 0.2
  group_size: 8

dora:
  dir: outputs
  exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend']

slurm:
  time: 4320
  constraint: volta32gb
  setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6']

# Hydra config
hydra:
  job_logging:
    formatters:
      colorlog:
        datefmt: "%m-%d %H:%M:%S"
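Every key above can be overridden per run with Hydra's dotted syntax, and the `dset`, `svd` and `variant` entries of the `defaults` list are swapped by naming one of the files under `conf/dset/`, `conf/svd/` or `conf/variant/` below. As a rough sketch only (the exact launcher depends on your setup; demucs training is driven through Dora), a run could look like `dora run -d dset=extra44 variant=finetune model=htdemucs optim.lr=1e-4`.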
conf/dset/aetl.yaml
ADDED
@@ -0,0 +1,19 @@
# @package _global_

# automix dataset with Musdb, extra training data and the test set of Musdb.
# This used even more remixes than auto_extra_test.
dset:
  wav: /checkpoint/defossez/datasets/aetl
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 500

augment:
  shift_same: true
  scale:
    proba: 0.
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/auto_extra_test.yaml
ADDED
@@ -0,0 +1,18 @@
# @package _global_

# automix dataset with Musdb, extra training data and the test set of Musdb.
dset:
  wav: /checkpoint/defossez/datasets/automix_extra_test2
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 500

augment:
  shift_same: true
  scale:
    proba: 0.
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/auto_mus.yaml
ADDED
@@ -0,0 +1,20 @@
# @package _global_

# Automix dataset based on musdb train set.
dset:
  wav: /checkpoint/defossez/datasets/automix_musdb
  samplerate: 44100
  channels: 2
epochs: 360
max_batches: 300
test:
  every: 4

augment:
  shift_same: true
  scale:
    proba: 0.5
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/extra44.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /checkpoint/defossez/datasets/allstems_44/
  samplerate: 44100
  channels: 2
epochs: 320
conf/dset/extra_mmi_goodclean.yaml
ADDED
@@ -0,0 +1,12 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /checkpoint/defossez/datasets/allstems_44/
  wav2: /checkpoint/defossez/datasets/mmi44_goodclean
  samplerate: 44100
  channels: 2
  wav2_weight: null
  wav2_valid: false
  valid_samples: 100
epochs: 1200
conf/dset/extra_test.yaml
ADDED
@@ -0,0 +1,12 @@
# @package _global_

# Musdb + extra tracks + test set from musdb.
dset:
  wav: /checkpoint/defossez/datasets/allstems_test_44/
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 700
test:
  sdr: false
  every: 500
conf/dset/musdb44.yaml
ADDED
@@ -0,0 +1,5 @@
# @package _global_

dset:
  samplerate: 44100
  channels: 2
conf/dset/sdx23_bleeding.yaml
ADDED
@@ -0,0 +1,10 @@
# @package _global_

# SDX23 "bleeding" dataset (MoisesDB); musdb is not used.
dset:
  wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/
  use_musdb: false
  samplerate: 44100
  channels: 2
  backend: soundfile  # must use soundfile, as some mixtures would clip with sox.
epochs: 320
conf/dset/sdx23_labelnoise.yaml
ADDED
@@ -0,0 +1,10 @@
# @package _global_

# SDX23 "label noise" dataset (MoisesDB); musdb is not used.
dset:
  wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0
  use_musdb: false
  samplerate: 44100
  channels: 2
  backend: soundfile  # must use soundfile, as some mixtures would clip with sox.
epochs: 320
conf/svd/base.yaml
ADDED
@@ -0,0 +1,14 @@
# @package _global_

svd:
  penalty: 0
  min_size: 1
  dim: 50
  niters: 4
  powm: false
  proba: 1
  conv_only: false
  convtr: false  # ideally this should be true, but some models were trained with this set to false.

optim:
  beta2: 0.9998
conf/svd/base2.yaml
ADDED
@@ -0,0 +1,14 @@
# @package _global_

svd:
  penalty: 0
  min_size: 1
  dim: 100
  niters: 4
  powm: false
  proba: 1
  conv_only: false
  convtr: true

optim:
  beta2: 0.9998
conf/svd/default.yaml
ADDED
@@ -0,0 +1 @@
# @package _global_
conf/variant/default.yaml
ADDED
@@ -0,0 +1 @@
# @package _global_
conf/variant/example.yaml
ADDED
@@ -0,0 +1,5 @@
# @package _global_

model: hdemucs
hdemucs:
  channels: 32
conf/variant/finetune.yaml
ADDED
@@ -0,0 +1,19 @@
# @package _global_

epochs: 4
batch_size: 16
optim:
  lr: 0.0006
test:
  every: 1
  sdr: false
dset:
  segment: 28
  shift: 2

augment:
  scale:
    proba: 0
  shift_same: true
  remix:
    proba: 0
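Selecting this variant (e.g. `variant=finetune` on the command line) overlays the keys above onto the base config: a handful of epochs at a longer 28-second segment with most augmentations disabled, which is the usual shape of a fine-tuning pass on a pretrained checkpoint, presumably combined with `continue_from` or `continue_pretrained` from the main config.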
demucs.png
ADDED
Git LFS Details
demucs/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "4.1.0a2"
demucs/__main__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .separate import main

if __name__ == '__main__':
    main()
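With this entry point, separation can be run as a module, e.g. `python -m demucs test.mp3`; the actual argument parsing lives in `demucs/separate.py`, outside the 50 files shown in this view.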
demucs/api.py
ADDED
@@ -0,0 +1,392 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""API methods for demucs

Classes
-------
`demucs.api.Separator`: The base separator class

Functions
---------
`demucs.api.save_audio`: Save audio to a file
`demucs.api.list_models`: List available models

Examples
--------
See the end of this module (if __name__ == "__main__")
"""

import subprocess

import torch as th
import torchaudio as ta

from dora.log import fatal
from pathlib import Path
from typing import Optional, Callable, Dict, Tuple, Union

from .apply import apply_model, _replace_dict
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo


class LoadAudioError(Exception):
    pass


class LoadModelError(Exception):
    pass


class _NotProvided:
    pass


NotProvided = _NotProvided()


class Separator:
    def __init__(
        self,
        model: str = "htdemucs",
        repo: Optional[Path] = None,
        device: str = "cuda" if th.cuda.is_available() else "cpu",
        shifts: int = 1,
        overlap: float = 0.25,
        split: bool = True,
        segment: Optional[int] = None,
        jobs: int = 0,
        progress: bool = False,
        callback: Optional[Callable[[dict], None]] = None,
        callback_arg: Optional[dict] = None,
    ):
        """
        `class Separator`
        =================

        Parameters
        ----------
        model: Pretrained model name or signature. Default is htdemucs.
        repo: Folder containing all pre-trained models for use.
        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information of the current separation progress. The
        progress information will override the values in `callback_arg` if the same key has been
        used. To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment. If the number is 441000, it
          doesn't mean the separation is at the 441000th second of the audio, but at the
          441000th "frame" of the tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
        """
        self._name = model
        self._repo = repo
        self._load_model()
        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
                              segment=segment, jobs=jobs, progress=progress, callback=callback,
                              callback_arg=callback_arg)

    def update_parameter(
        self,
        device: Union[str, _NotProvided] = NotProvided,
        shifts: Union[int, _NotProvided] = NotProvided,
        overlap: Union[float, _NotProvided] = NotProvided,
        split: Union[bool, _NotProvided] = NotProvided,
        segment: Optional[Union[int, _NotProvided]] = NotProvided,
        jobs: Union[int, _NotProvided] = NotProvided,
        progress: Union[bool, _NotProvided] = NotProvided,
        callback: Optional[
            Union[Callable[[dict], None], _NotProvided]
        ] = NotProvided,
        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
    ):
        """
        Update the parameters of separation.

        Parameters
        ----------
        segment: Length (in seconds) of each segment (only available if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information of the current separation progress. The
        progress information will override the values in `callback_arg` if the same key has been
        used. To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment. If the number is 441000, it
          doesn't mean the separation is at the 441000th second of the audio, but at the
          441000th "frame" of the tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
        """
        if not isinstance(device, _NotProvided):
            self._device = device
        if not isinstance(shifts, _NotProvided):
            self._shifts = shifts
        if not isinstance(overlap, _NotProvided):
            self._overlap = overlap
        if not isinstance(split, _NotProvided):
            self._split = split
        if not isinstance(segment, _NotProvided):
            self._segment = segment
        if not isinstance(jobs, _NotProvided):
            self._jobs = jobs
        if not isinstance(progress, _NotProvided):
            self._progress = progress
        if not isinstance(callback, _NotProvided):
            self._callback = callback
        if not isinstance(callback_arg, _NotProvided):
            self._callback_arg = callback_arg

    def _load_model(self):
        self._model = get_model(name=self._name, repo=self._repo)
        if self._model is None:
            raise LoadModelError("Failed to load model")
        self._audio_channels = self._model.audio_channels
        self._samplerate = self._model.samplerate

    def _load_audio(self, track: Path):
        errors = {}
        wav = None

        try:
            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
                                        channels=self._audio_channels)
        except FileNotFoundError:
            errors["ffmpeg"] = "FFmpeg is not installed."
        except subprocess.CalledProcessError:
            errors["ffmpeg"] = "FFmpeg could not read the file."

        if wav is None:
            try:
                wav, sr = ta.load(str(track))
            except RuntimeError as err:
                errors["torchaudio"] = err.args[0]
            else:
                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)

        if wav is None:
            raise LoadAudioError(
                "\n".join(
                    "When trying to load using {}, got the following error: {}".format(
                        backend, error
                    )
                    for backend, error in errors.items()
                )
            )
        return wav

    def separate_tensor(
        self, wav: th.Tensor, sr: Optional[int] = None
    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
        """
        Separate a loaded tensor.

        Parameters
        ----------
        wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \
            while the second is the waveform of each channel. Type should be float32. \
            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
        sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \
            model.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the name of stems and values are separated waves. The original wave will have already
        been resampled.

        Notes
        -----
        Use this function with caution. This function does not verify its input data.
        """
        if sr is not None and sr != self.samplerate:
            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
        ref = wav.mean(0)
        wav -= ref.mean()
        wav /= ref.std() + 1e-8
        out = apply_model(
            self._model,
            wav[None],
            segment=self._segment,
            shifts=self._shifts,
            split=self._split,
            overlap=self._overlap,
            device=self._device,
            num_workers=self._jobs,
            callback=self._callback,
            callback_arg=_replace_dict(
                self._callback_arg, ("audio_length", wav.shape[1])
            ),
            progress=self._progress,
        )
        if out is None:
            raise KeyboardInterrupt
        out *= ref.std() + 1e-8
        out += ref.mean()
        wav *= ref.std() + 1e-8
        wav += ref.mean()
        return (wav, dict(zip(self._model.sources, out[0])))

    def separate_audio_file(self, file: Path):
        """
        Separate an audio file. The method will automatically read the file.

        Parameters
        ----------
        file: Path of the file to be separated.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the name of stems and values are separated waves. The original wave will have already
        been resampled.
        """
        return self.separate_tensor(self._load_audio(file), self.samplerate)

    @property
    def samplerate(self):
        return self._samplerate

    @property
    def audio_channels(self):
        return self._audio_channels

    @property
    def model(self):
        return self._model


def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
    """
    List the available models. Please remember that not all the returned models can be
    successfully loaded.

    Parameters
    ----------
    repo: The repo whose models are to be listed.

    Returns
    -------
    A dict with two keys ("single" for single models and "bag" for bags of models). The values
    are lists whose components are strs.
    """
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}


if __name__ == "__main__":
    # Test API functions
    # two-stem not supported

    from .separate import get_parser

    args = get_parser().parse_args()
    separator = Separator(
        model=args.name,
        repo=args.repo,
        device=args.device,
        shifts=args.shifts,
        overlap=args.overlap,
        split=args.split,
        segment=args.segment,
        jobs=args.jobs,
        callback=print
    )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    for file in args.tracks:
        separated = separator.separate_audio_file(file)[1]
        if args.mp3:
            ext = "mp3"
        elif args.flac:
            ext = "flac"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": separator.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        for stem, source in separated.items():
            stem = out / args.filename.format(
                track=Path(file).name.rsplit(".", 1)[0],
                trackext=Path(file).name.rsplit(".", 1)[-1],
                stem=stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(source, str(stem), **kwargs)
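For reference, a minimal usage sketch of the API above (the input file is the repository's bundled `test.mp3`, and the output names are illustrative):

import torch  # not strictly needed here, but the stems are torch Tensors

from demucs.api import Separator, save_audio

# Load the default pretrained model and separate one track.
separator = Separator(model="htdemucs", progress=True)
# `separate_audio_file` returns (original_wave, {stem_name: stem_wave, ...}).
origin, separated = separator.separate_audio_file("test.mp3")
for stem, source in separated.items():
    # One file per source, e.g. vocals.wav, drums.wav, bass.wav, other.wav.
    save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)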
demucs/apply.py
ADDED
@@ -0,0 +1,322 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import copy
import random
from threading import Lock
import typing as tp

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm

from .demucs import Demucs
from .hdemucs import HDemucs
from .htdemucs import HTDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs, HTDemucs]


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                if not isinstance(other, HTDemucs) and segment > other.segment:
                    other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    @property
    def max_allowed_segment(self) -> float:
        max_allowed_segment = float('inf')
        for model in self.models:
            if isinstance(model, HTDemucs):
                max_allowed_segment = min(max_allowed_segment, float(model.segment))
        return max_allowed_segment

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")


class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
    if _dict is None:
        _dict = {}
    else:
        _dict = copy.copy(_dict)
    for key, value in subs:
        _dict[key] = value
    return _dict


def apply_model(model: tp.Union[BagOfModels, Model],
                mix: tp.Union[th.Tensor, TensorChunk],
                shifts: int = 1, split: bool = True,
                overlap: float = 0.25, transition_power: float = 1.,
                progress: bool = False, device=None,
                num_workers: int = 0, segment: tp.Optional[float] = None,
                pool=None, lock=None,
                callback: tp.Optional[tp.Callable[[dict], None]] = None,
                callback_arg: tp.Optional[dict] = None) -> th.Tensor:
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8-second extracts
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if non zero and device is 'cpu', how many threads to
            use in parallel.
        segment (float or None): override the model segment parameter.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    if lock is None:
        lock = Lock()
    callback_arg = _replace_dict(
        callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
    )
    kwargs: tp.Dict[str, tp.Any] = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'segment': segment,
        'lock': lock,
    }
    out: tp.Union[float, th.Tensor]
    res: tp.Union[float, th.Tensor]
    if isinstance(model, BagOfModels):
        # Special treatment for a bag of models.
        # We explicitly apply `apply_model` multiple times so that the random shifts
        # are different for each model.
        estimates: tp.Union[float, th.Tensor] = 0.
        totals = [0.] * len(model.sources)
        callback_arg["models"] = len(model.models)
        for sub_model, model_weights in zip(model.models, model.weights):
            kwargs["callback"] = ((
                lambda d, i=callback_arg["model_idx_in_bag"]: callback(
                    _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
            )
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
            out = res
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out
            callback_arg["model_idx_in_bag"] += 1

        assert isinstance(estimates, th.Tensor)
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    if "models" not in callback_arg:
        callback_arg["models"] = 1
    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0.
        for shift_idx in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            kwargs["callback"] = (
                (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
                 if callback else None)
            )
            res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
            shifted_out = res
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        assert isinstance(out, th.Tensor)
        return out
    elif split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        if segment is None:
            segment = model.segment
        assert segment is not None and segment > 0.
        segment_length: int = int(model.samplerate * segment)
        stride = int((1 - overlap) * segment_length)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
                         th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
        assert len(weight) == segment_length
        # If the overlap < 50%, this will translate to a linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment_length)
            future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
                                 callback=(lambda d, i=offset:
                                           callback(_replace_dict(d, ("segment_offset", i)))
                                           if callback else None))
            futures.append((future, offset))
            offset += segment_length
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            try:
                chunk_out = future.result()  # type: th.Tensor
            except Exception:
                pool.shutdown(wait=True, cancel_futures=True)
                raise
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment_length] += (
                weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        assert isinstance(out, th.Tensor)
        return out
    else:
        valid_length: int
        if isinstance(model, HTDemucs) and segment is not None:
            valid_length = int(segment * model.samplerate)
        elif hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)  # type: ignore
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(valid_length).to(device)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "start")))  # type: ignore
        with th.no_grad():
            out = model(padded_mix)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "end")))  # type: ignore
        assert isinstance(out, th.Tensor)
        return center_trim(out, length)
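And a hedged sketch of calling `apply_model` directly on a pretrained model, using `get_model` from `demucs.pretrained` (seen imported in `api.py` above); the ten-second input mix here is synthetic noise:

import torch as th

from demucs.pretrained import get_model
from demucs.apply import apply_model

model = get_model("htdemucs")  # may return a single model or a BagOfModels
# A batch of one fake stereo mix, 10 seconds long at the model sample rate.
mix = th.randn(1, model.audio_channels, 10 * model.samplerate)
# Output shape is [batch, len(model.sources), channels, time].
out = apply_model(model, mix, shifts=1, split=True, overlap=0.25, progress=True)
stems = dict(zip(model.sources, out[0]))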
demucs/audio.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import json
|
| 7 |
+
import subprocess as sp
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import lameenc
|
| 11 |
+
import julius
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio as ta
|
| 15 |
+
import typing as tp
|
| 16 |
+
|
| 17 |
+
from .utils import temp_filenames
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _read_info(path):
|
| 21 |
+
stdout_data = sp.check_output([
|
| 22 |
+
'ffprobe', "-loglevel", "panic",
|
| 23 |
+
str(path), '-print_format', 'json', '-show_format', '-show_streams'
|
| 24 |
+
])
|
| 25 |
+
return json.loads(stdout_data.decode('utf-8'))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AudioFile:
|
| 29 |
+
"""
|
| 30 |
+
Allows to read audio from any format supported by ffmpeg, as well as resampling or
|
| 31 |
+
converting to mono on the fly. See :method:`read` for more details.
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, path: Path):
|
| 34 |
+
self.path = Path(path)
|
| 35 |
+
self._info = None
|
| 36 |
+
|
| 37 |
+
def __repr__(self):
|
| 38 |
+
features = [("path", self.path)]
|
| 39 |
+
features.append(("samplerate", self.samplerate()))
|
| 40 |
+
features.append(("channels", self.channels()))
|
| 41 |
+
features.append(("streams", len(self)))
|
| 42 |
+
features_str = ", ".join(f"{name}={value}" for name, value in features)
|
| 43 |
+
return f"AudioFile({features_str})"
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def info(self):
|
| 47 |
+
if self._info is None:
|
| 48 |
+
self._info = _read_info(self.path)
|
| 49 |
+
return self._info
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def duration(self):
|
| 53 |
+
return float(self.info['format']['duration'])
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def _audio_streams(self):
|
| 57 |
+
return [
|
| 58 |
+
index for index, stream in enumerate(self.info["streams"])
|
| 59 |
+
if stream["codec_type"] == "audio"
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
    def __len__(self):
        return len(self._audio_streams)

    def channels(self, stream=0):
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg;
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float): seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. The original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scales by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
        """
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            target_size = int((samplerate or self.samplerate()) * duration)
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for filename in filenames:
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                wav = wav.view(-1, self.channels()).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav


def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels."""
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked for 1-channel audio, but the stream has multiple
        # channels: downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file has
        # a single channel: replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file has
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has fewer channels than requested but is not mono.')
    return wav


def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
    """Convert audio from a given samplerate to a target one and target number of channels."""
    wav = convert_audio_channels(wav, channels)
    return julius.resample_frac(wav, from_samplerate, to_samplerate)
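
As a quick sanity check of `convert_audio` (a minimal sketch, not part of the diff; it assumes `torch`, `julius` and the `demucs` package are importable, with random noise standing in for real audio):

    import torch
    from demucs.audio import convert_audio

    wav = torch.randn(2, 44100)                 # 1 second of stereo noise at 44.1 kHz
    out = convert_audio(wav, 44100, 32000, 1)   # downmix to mono, then resample to 32 kHz
    print(out.shape)                            # torch.Size([1, 32000])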

def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format."""
    if wav.dtype.is_floating_point:
        return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
    else:
        return wav


def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format."""
    if wav.dtype.is_floating_point:
        return wav
    else:
        return wav.float() / (2**15 - 1)


def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
    # Dispatch on the target dtype, as the docstring describes.
    if dtype.is_floating_point:
        return f32_pcm(wav)
    else:
        return i16_pcm(wav)


def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
    """Save given audio as mp3. This should work on all OSes."""
    C, T = wav.shape
    wav = i16_pcm(wav)
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(C)
    encoder.set_quality(quality)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()
    wav = wav.data.cpu()
    wav = wav.transpose(0, 1).numpy()
    mp3_data = encoder.encode(wav.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)


def prevent_clip(wav, mode='rescale'):
    """
    Different strategies for avoiding raw clipping.
    """
    if mode is None or mode == 'none':
        return wav
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        wav = wav / max(1.01 * wav.abs().max(), 1)
    elif mode == 'clamp':
        wav = wav.clamp(-0.99, 0.99)
    elif mode == 'tanh':
        wav = torch.tanh(wav)
    else:
        raise ValueError(f"Invalid mode {mode}")
    return wav


def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
               bits_per_sample: tp.Literal[16, 24, 32] = 16,
               as_float: bool = False,
               preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`. Use `preset` to set the mp3
    quality: 2 for highest quality, 7 for fastest speed.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
    elif suffix == ".wav":
        if as_float:
            bits_per_sample = 32
            encoding = 'PCM_F'
        else:
            encoding = 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    elif suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")
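
A short usage sketch for `save_audio` (not part of the diff; it assumes `torchaudio` is installed for wav output and `lameenc` for mp3, and the output paths are hypothetical):

    import torch
    from demucs.audio import save_audio

    wav = 1.5 * torch.randn(2, 44100)   # deliberately hot signal, will be rescaled
    save_audio(wav, "out.wav", samplerate=44100)                # 16-bit PCM wav
    save_audio(wav, "out.mp3", samplerate=44100, bitrate=192)   # mp3 via lameenc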
demucs/augment.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data augmentations.
"""

import random
import torch as th
from torch import nn


class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.
    """
    def __init__(self, shift=8192, same=False):
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift > 0:
            if not self.training:
                wav = wav[..., :length]
            else:
                srcs = 1 if self.same else sources
                offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
                offsets = offsets.expand(-1, sources, channels, -1)
                indexes = th.arange(length, device=wav.device)
                wav = wav.gather(3, indexes + offsets)
        return wav


class FlipChannels(nn.Module):
    """
    Flip left-right channels.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training and wav.size(2) == 2:
            left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
            left = left.expand(-1, -1, -1, time)
            right = 1 - left
            wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
        return wav


class FlipSign(nn.Module):
    """
    Random sign flip.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            wav = wav * (2 * signs - 1)
        return wav


class Remix(nn.Module):
    """
    Shuffle sources to make new mixes.
    """
    def __init__(self, proba=1, group_size=4):
        """
        Shuffle sources within one batch.
        Each batch is divided into groups of size `group_size` and shuffling is done within
        each group separately. This allows keeping the same probability distribution no matter
        the number of GPUs. Without this grouping, using more GPUs would lead to a higher
        probability of keeping two sources from the same track together, which can impact
        performance.
        """
        super().__init__()
        self.proba = proba
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        if self.training and random.random() < self.proba:
            group_size = self.group_size or batch
            if batch % group_size != 0:
                raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
            groups = batch // group_size
            wav = wav.view(groups, group_size, streams, channels, time)
            permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
                                      dim=1)
            wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
            wav = wav.view(batch, streams, channels, time)
        return wav


class Scale(nn.Module):
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            wav *= scales
        return wav
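
A sketch of how these modules might be chained during training (an illustration, not the trainer's actual pipeline, which lives elsewhere in the repo; the dummy batch stands in for real source stems):

    import torch as th
    from torch import nn
    from demucs.augment import Shift, FlipChannels, FlipSign, Remix, Scale

    augment = nn.Sequential(Shift(shift=8192), FlipChannels(), FlipSign(),
                            Remix(group_size=4), Scale(min=0.25, max=1.25))
    augment.train()   # these modules are (near) no-ops in eval mode

    # batch of 8 tracks, 4 sources, stereo, 65536 samples each
    sources = th.randn(8, 4, 2, 65536)
    out = augment(sources)
    mix = out.sum(dim=1)    # a training mixture can then be built by summing the stems
    print(out.shape)        # torch.Size([8, 4, 2, 57344]) after the 8192-sample Shift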
demucs/demucs.py
ADDED
@@ -0,0 +1,447 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
from torch import nn
from torch.nn import functional as F

from .states import capture_init
from .utils import center_trim, unfold
from .transformer import LayerScale


class BLSTM(nn.Module):
    """
    BiLSTM with the same number of hidden units as the input dim.
    If `max_steps` is not None, the input will be split in overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.
    """
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            rescale_conv(sub, reference)


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also, before entering each residual branch, the dimension is projected onto a smaller
    subspace, e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for the LayerScale.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x


class LocalState(nn.Module):
    """Local state allows attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiment with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)
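
For intuition, `LocalState` maps a (batch, channels, time) tensor to the same shape through one residual attention step; a minimal sketch with random input and default hyper-parameters (not part of the diff):

    import torch
    from demucs.demucs import LocalState

    attn = LocalState(channels=64, heads=4, ndecay=4)
    x = torch.randn(2, 64, 100)   # (batch, channels, time)
    y = attn(x)
    print(y.shape)                # torch.Size([2, 64, 100]), shape preserved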


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): stereo or mono.
            channels (int): first convolution channels.
            growth (float): multiply (resp. divide) the number of channels by that
                for each layer of the encoder (resp. decoder).
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions.
            stride (int): stride for convolutions.
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                Decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that the input is automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
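
A minimal smoke test of the full model (a sketch, assuming the package and its dependencies, julius in particular, are installed; one second of random stereo audio stands in for a real mix):

    import torch
    from demucs.demucs import Demucs

    model = Demucs(sources=['drums', 'bass', 'other', 'vocals'])
    model.eval()
    mix = torch.randn(1, 2, 44100)   # (batch, audio_channels, time)
    with torch.no_grad():
        out = model(mix)
    print(out.shape)   # torch.Size([1, 4, 2, 44100]): one stem per source, same length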
demucs/distrib.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities.
"""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def average(metrics, count=1.):
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    if world_size > 1:
        torch.distributed.barrier()


def share(obj=None, src=0):
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # size variable is now set to the length of pickled obj in all processes

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # buffer variable is now set to pickled obj in all processes

    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicates some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
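
In a single process (`world_size == 1`) `loader` degrades to a plain `DataLoader`; under `torch.distributed` it switches to a `DistributedSampler` for training and a manual shard for validation. A sketch of the single-process path (it assumes `dora` is installed, since the module imports it at load time):

    import torch
    from torch.utils.data import TensorDataset
    from demucs import distrib

    dataset = TensorDataset(torch.randn(32, 2, 1024))
    train_loader = distrib.loader(dataset, batch_size=4, shuffle=True)
    valid_loader = distrib.loader(dataset, batch_size=4, shuffle=False)
    print(len(train_loader), len(valid_loader))   # 8 8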
demucs/ema.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Inspired from https://github.com/rwightman/pytorch-image-models
from contextlib import contextmanager

import torch

from .states import swap_state


class ModelEMA:
    """
    Perform EMA on a model. You can switch to the EMA weights temporarily
    with the `swap` method.

        ema = ModelEMA(model)
        with ema.swap():
            # compute valid metrics with averaged model.
    """
    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
        self.decay = decay
        self.model = model
        self.state = {}
        self.count = 0
        self.device = device
        self.unbias = unbias

        self._init()

    def _init(self):
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            if key not in self.state:
                self.state[key] = val.detach().to(device, copy=True)

    def update(self):
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            self.state[key].mul_(1 - w)
            self.state[key].add_(val.detach().to(device), alpha=w)

    @contextmanager
    def swap(self):
        with swap_state(self.model, self.state):
            yield

    def state_dict(self):
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        self.count = state['count']
        for k, v in state['state'].items():
            self.state[k].copy_(v)
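
Typical usage, sketched with a toy model (not part of the diff): `ema.update()` is called after each optimizer step and `ema.swap()` wraps validation.

    import torch
    from torch import nn
    from demucs.ema import ModelEMA

    model = nn.Linear(4, 4)
    ema = ModelEMA(model, decay=0.999)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update()          # fold the new weights into the running average

    with ema.swap():          # temporarily evaluate with the averaged weights
        print(model(torch.zeros(1, 4)))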
demucs/evaluate.py
ADDED
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
or the newest SDR definition from the MDX 2021 competition (this one will
be reported as `nsdr` for `new sdr`).
"""

from concurrent import futures
import logging

from dora.log import LogProgress
import numpy as np
import musdb
import museval
import torch as th

from .apply import apply_model
from .audio import convert_audio, save_audio
from . import distrib
from .utils import DummyPoolExecutor


logger = logging.getLogger(__name__)


def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores


def eval_track(references, estimates, win, hop, compute_sdr=True):
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    else:
        references = references.numpy()
        estimates = estimates.numpy()
        scores = museval.metrics.bss_eval(
            references, estimates,
            compute_permutation=False,
            window=win,
            hop=hop,
            framewise_filters=False,
            bsseval_sources_version=False)[:-1]
        return scores, new_scores


def evaluate(solver, compute_sdr=False):
    """
    Evaluate model using museval.
    compute_sdr=False means using only the MDX definition of the SDR, which
    is much faster to evaluate.
    """

    args = solver.args

    output_dir = solver.folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = solver.folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    if args.test.nonhq is None:
        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
    else:
        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
    src_rate = args.dset.musdb_samplerate

    eval_device = 'cpu'

    model = solver.model
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    indexes = range(distrib.rank, len(test_set), distrib.world_size)
    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
                          name='Eval')
    pendings = []

    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
    with pool(args.test.workers) as pool:
        for index in indexes:
            track = test_set.tracks[index]

            mix = th.from_numpy(track.audio).t().float()
            if mix.dim() == 1:
                mix = mix[None]
            mix = mix.to(solver.device)
            ref = mix.mean(dim=0)  # mono mixture
            mix = (mix - ref.mean()) / ref.std()
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix[None],
                                    shifts=args.test.shifts, split=args.test.split,
                                    overlap=args.test.overlap)[0]
            estimates = estimates * ref.std() + ref.mean()
            estimates = estimates.to(eval_device)

            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            if references.dim() == 2:
                references = references[:, None]
            references = references.to(eval_device)
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            if args.test.save:
                folder = solver.folder / "wav" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)

            pendings.append((track.name, pool.submit(
                eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))

        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
                               name='Eval (BSS)')
        tracks = {}
        for track_name, pending in pendings:
            pending = pending.result()
            scores, nsdrs = pending
            tracks[track_name] = {}
            for idx, target in enumerate(model.sources):
                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
            if scores is not None:
                (sdr, isr, sir, sar) = scores
                for idx, target in enumerate(model.sources):
                    values = {
                        "SDR": sdr[idx].tolist(),
                        "SIR": sir[idx].tolist(),
                        "ISR": isr[idx].tolist(),
                        "SAR": sar[idx].tolist()
                    }
                    tracks[track_name][target].update(values)

    all_tracks = {}
    for src in range(distrib.world_size):
        all_tracks.update(distrib.share(tracks, src))

    result = {}
    metric_names = next(iter(all_tracks.values()))[model.sources[0]]
    for metric_name in metric_names:
        avg = 0
        avg_of_medians = 0
        for source in model.sources:
            medians = [
                np.nanmedian(all_tracks[track][source][metric_name])
                for track in all_tracks.keys()]
            mean = np.mean(medians)
            median = np.median(medians)
            result[metric_name.lower() + "_" + source] = mean
            result[metric_name.lower() + "_med" + "_" + source] = median
            avg += mean / len(model.sources)
            avg_of_medians += median / len(model.sources)
        result[metric_name.lower()] = avg
        result[metric_name.lower() + "_med"] = avg_of_medians
    return result
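
A quick check of `new_sdr` on synthetic data (a sketch; an estimate carrying roughly 10% additive noise should land around 20 dB):

    import torch as th
    from demucs.evaluate import new_sdr

    references = th.randn(1, 4, 2, 44100)    # (batch, sources, channels, time)
    estimates = references + 0.1 * th.randn_like(references)
    print(new_sdr(references, estimates))    # roughly 20 dB per source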
demucs/grids/__init__.py
ADDED
File without changes
demucs/grids/_explorers.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dora import Explorer
import treetable as tt


class MyExplorer(Explorer):
    test_metrics = ['nsdr', 'sdr_med']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        return [
            tt.group("train", [
                tt.leaf("epoch"),
                tt.leaf("reco", ".3f"),
            ], align=">"),
            tt.group("valid", [
                tt.leaf("penalty", ".1f"),
                tt.leaf("ms", ".1f"),
                tt.leaf("reco", ".2%"),
                tt.leaf("breco", ".2%"),
                tt.leaf("b_nsdr", ".2f"),
                # tt.leaf("b_nsdr_drums", ".2f"),
                # tt.leaf("b_nsdr_bass", ".2f"),
                # tt.leaf("b_nsdr_other", ".2f"),
                # tt.leaf("b_nsdr_vocals", ".2f"),
            ], align=">"),
            tt.group("test", [
                tt.leaf(name, ".2f")
                for name in self.test_metrics
            ], align=">")
        ]

    def process_history(self, history):
        train = {
            'epoch': len(history),
        }
        valid = {}
        test = {}
        best_v_main = float('inf')
        breco = float('inf')
        for metrics in history:
            train.update(metrics['train'])
            valid.update(metrics['valid'])
            if 'main' in metrics['valid']:
                best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
                valid['bmain'] = best_v_main
            valid['breco'] = min(breco, metrics['valid']['reco'])
            breco = valid['breco']
            if (metrics['valid']['loss'] == metrics['valid']['best'] or
                    metrics['valid'].get('nsdr') == metrics['valid']['best']):
                for k, v in metrics['valid'].items():
                    if k.startswith('reco_'):
                        valid['b_' + k[len('reco_'):]] = v
                    if k.startswith('nsdr'):
                        valid[f'b_{k}'] = v
            if 'test' in metrics:
                test.update(metrics['test'])
        metrics = history[-1]
        return {"train": train, "valid": valid, "test": test}
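
This Explorer subclass drives the grid files that follow. Assuming the usual dora-search CLI conventions (an assumption, since the launch commands are not part of this diff), a grid such as `mdx` would be scheduled with something like `dora grid mdx`, and the columns declared in `get_grid_metrics` are what dora displays while monitoring the runs.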
demucs/grids/mdx.py
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main training for the Track A MDX models.
"""

from ._explorers import MyExplorer
from ..train import main


TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track A
    # This trains the first round of models. Once this is trained,
    # you will need to schedule `mdx_refine`.
    for sig in TRACK_A:
        xp = main.get_xp_from_sig(sig)
        parent = xp.cfg.continue_from
        xp = main.get_xp_from_sig(parent)
        launcher(xp.argv)
        launcher(xp.argv, {'quant.diffq': 1e-4})
        launcher(xp.argv, {'quant.diffq': 3e-4})
demucs/grids/mdx_extra.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main training for the Track B MDX models.
"""

from ._explorers import MyExplorer
from ..train import main

TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track B
    # This trains the first round of models. Once this is trained,
    # you will need to schedule `mdx_refine`.
    for sig in TRACK_B:
        while sig is not None:
            xp = main.get_xp_from_sig(sig)
            sig = xp.cfg.continue_from

            for dset in ['extra44', 'extra_test']:
                sub = launcher.bind(xp.argv, dset=dset)
                sub()
                if dset == 'extra_test':
                    sub({'quant.diffq': 1e-4})
                    sub({'quant.diffq': 3e-4})
demucs/grids/mdx_refine.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main training for the Track A MDX models.
"""

from ._explorers import MyExplorer
from .mdx import TRACK_A
from ..train import main


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track A
    # WARNING: all the experiments in the `mdx` grid must have completed.
    for sig in TRACK_A:
        xp = main.get_xp_from_sig(sig)
        launcher(xp.argv)
        for diffq in [1e-4, 3e-4]:
            xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
            q_argv = [f'quant.diffq={diffq}']
            actual_src = main.get_xp(xp_src.argv + q_argv)
            actual_src.link.load()
            assert len(actual_src.link.history) == actual_src.cfg.epochs
            argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
            launcher(argv)