diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..0b1e1e7ef2cf9bb722ff3b63166b2ade42558dbb --- /dev/null +++ b/.dockerignore @@ -0,0 +1,27 @@ +**/__pycache__ +**/.venv +**/.classpath +**/.dockerignore +**/.env +**/.git +**/.gitignore +**/.project +**/.settings +**/.toolstarget +**/.vs +**/.vscode +**/*.*proj.user +**/*.dbmdl +**/*.jfm +**/bin +**/charts +**/docker-compose* +**/compose* +**/Dockerfile* +**/node_modules +**/npm-debug.log +**/obj +**/secrets.dev.yaml +**/values.dev.yaml +LICENSE +README.md diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f45caefe27c170bbf699acaf48f24795d7554ef8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +demucs.png filter=lfs diff=lfs merge=lfs -text +test.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md new file mode 100644 index 0000000000000000000000000000000000000000..4c2d4ee177efd12fbb7e966b4fe318ad7797af4c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -0,0 +1,33 @@ +--- +name: 🐛 Bug Report +about: Submit a bug report to help us improve +labels: 'bug' +--- + +## 🐛 Bug Report + +(A clear and concise description of what the bug is) + +## To Reproduce + +(Write your steps here:) + +1. Step 1... +1. Step 2... +1. Step 3... + +## Expected behavior + +(Write what you thought would happen.) + +## Actual Behavior + +(Write what happened. Add screenshots, if applicable.) + +## Your Environment + + + +- Python and PyTorch version: +- Operating system and version (desktop or mobile): +- Hardware (gpu or cpu, amount of RAM etc.): diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000000000000000000000000000000000000..eeb9ad17c62505917009e94a0d467adaa9c85994 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,10 @@ +--- +name: "❓Questions/Help/Support" +about: If you have a question about the paper, code or algorithm, please ask here! +labels: question + +--- + +## ❓ Questions + +(Please ask your question here.) diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml new file mode 100644 index 0000000000000000000000000000000000000000..8f7ecded14c25f89a57e77d676887b610522b053 --- /dev/null +++ b/.github/workflows/linter.yml @@ -0,0 +1,36 @@ +name: linter +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - uses: actions/cache@v2 + with: + path: env + key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }} + + - name: Install dependencies + run: | + python3 -m venv env + . env/bin/activate + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install '.[dev]' + + + - name: Run linter + run: | + . 
env/bin/activate + make linter diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..f9b4cf17e6df8cfef422abdd43d27caf1dc5ea94 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,36 @@ +name: tests +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - uses: actions/cache@v2 + with: + path: env + key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y ffmpeg + python3 -m venv env + . env/bin/activate + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run separation test + run: | + . env/bin/activate + make test_eval diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0201b148fb9ce43befc3bee12ff63852f6890e7e --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +*.egg-info +__pycache__ +Session.vim +/build +/dist +/lab +/metadata +/notebooks +/outputs +/release +/release_models +/separated +/tests +/trash +/misc +/mdx +.mypy_cache diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..e41222e1ef31d172278435da1b96ee63d5b81717 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,19 @@ +{ + "configurations": [ + { + "name": "Containers: Python - Fastapi", + "type": "docker", + "request": "launch", + "preLaunchTask": "docker-run: debug", + "python": { + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/app" + } + ], + "projectType": "fastapi" + } + } + ] +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000000000000000000000000000000000000..2ee261ed2380395295a6fca4877a86ea692a435b --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,33 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "type": "docker-build", + "label": "docker-build", + "platform": "python", + "dockerBuild": { + "tag": "demucs:latest", + "dockerfile": "${workspaceFolder}/Dockerfile", + "context": "${workspaceFolder}", + "pull": true + } + }, + { + "type": "docker-run", + "label": "docker-run: debug", + "dependsOn": [ + "docker-build" + ], + "python": { + "args": [ + "predict:app", + "--host", + "0.0.0.0", + "--port", + "8000" + ], + "module": "uvicorn" + } + } + ] +} \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..45fbb84900f04ba7cc7f4a004f0be6870b914617 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..7a1521894978366fdc31e89947f16c36b6fbf754 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# Contributing to Demucs + +## Pull Requests + +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +Demucs is the implementation of a research paper. +Therefore, we do not plan on accepting many pull requests for new features. +We certainly welcome them for bug fixes. 
+ + +## Issues + +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + + +## License +By contributing to this repository, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/Demucs.ipynb b/Demucs.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..69d8dfff4eab161f5115329e419b523d2ac67fae --- /dev/null +++ b/Demucs.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Be9yoh-ILfRr" + }, + "source": [ + "# Hybrid Demucs\n", + "\n", + "Feel free to use the Colab version:\n", + "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 139 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 12277, + "status": "ok", + "timestamp": 1583778134659, + "user": { + "displayName": "Marllus Lustosa", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64", + "userId": "14811735256675200480" + }, + "user_tz": 180 + }, + "id": "kOjIPLlzhPfn", + "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3" + }, + "outputs": [], + "source": [ + "!pip install -U demucs\n", + "# or for local development, if you have a clone of Demucs\n", + "# pip install -e ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5lYOzKKCKAbJ" + }, + "outputs": [], + "source": [ + "# You can use the `demucs` command line to separate tracks\n", + "!demucs test.mp3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You can also load directly the pretrained models,\n", + "# for instance for the MDX 2021 winning model of Track A:\n", + "from demucs import pretrained\n", + "model = pretrained.get_model('mdx')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Because `model` is a bag of 4 models, you cannot directly call it on your data,\n", + "# but the `apply_model` will know what to do of it.\n", + "import torch\n", + "from demucs.apply import apply_model\n", + "x = torch.randn(1, 2, 44100 * 10) # ten seconds of white noise for the demo\n", + "out = apply_model(model, x)[0] # shape is [S, C, T] with S the number of sources\n", + "\n", + "# So let see, where is all the white noise content is going ?\n", + "for name, source in zip(model.sources, out):\n", + " print(name, source.std() / x.std())\n", + "# The outputs are quite weird to be fair, not what I would have expected." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# now let's take a single model from the bag, and let's test it on a pure cosine\n", + "freq = 440 # in Hz\n", + "sr = model.samplerate\n", + "t = torch.arange(10 * sr).float() / sr\n", + "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n", + "sub_model = model.models[3]\n", + "out = sub_model(x)[0]\n", + "\n", + "# Same question where does it go?\n", + "for name, source in zip(model.sources, out):\n", + " print(name, source.std() / x.std())\n", + " \n", + "# Well now it makes much more sense, all the energy is going\n", + "# in the `other` source.\n", + "# Feel free to try lower pitch (try 80 Hz) to see what happens !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For training or more fun, refer to the Demucs README on our repo\n", + "# https://github.com/facebookresearch/demucs/tree/main/demucs" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx", + "collapsed_sections": [], + "name": "Demucs.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f3cdb6e638704afedc94585411ce8303407448b8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +# Use Python 3.9 slim base +FROM python:3.9-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y ffmpeg git && apt-get clean + +# Set work directory +WORKDIR /app + +# Install Python packages +RUN pip install --upgrade pip +RUN pip install torch torchaudio +RUN pip install fastapi uvicorn +RUN pip install git+https://github.com/facebookresearch/demucs + +# Copy your inference script into the container +COPY predict.py . + +# Expose port for FastAPI +EXPOSE 8000 + +# Run the FastAPI app +CMD ["uvicorn", "predict:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..880b6cf9bd18215c23845a1feb603060c08a6e3b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Meta Platforms, Inc. and affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..318e315281731fb1e79e5ee949a82a6bf5e3fad7 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,13 @@ +recursive-exclude env * +recursive-include conf *.yaml +include Makefile +include LICENSE +include demucs.png +include outputs.tar.gz +include test.mp3 +include requirements.txt +include requirements_minimal.txt +include mypy.ini +include demucs/py.typed +include demucs/remote/*.txt +include demucs/remote/*.yaml diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..3c54c6236238314744d291d3557e3e7912c2b543 --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +all: linter tests + +linter: + flake8 demucs + mypy demucs + +tests: test_train test_eval + +test_train: tests/musdb + _DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \ + dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \ + demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \ + test.shifts=0 + +test_eval: + python3 -m demucs -n demucs_unittest test.mp3 + python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3 + python3 -m demucs -n demucs_unittest --mp3 test.mp3 + python3 -m demucs -n demucs_unittest --flac --int24 test.mp3 + python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3 + python3 -m demucs -n demucs_unittest --segment 8 test.mp3 + python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3 + python3 -m demucs --list-models + +tests/musdb: + test -e tests || mkdir tests + python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)' + musdbconvert tests/tmp tests/musdb + +dist: + python3 setup.py sdist + +clean: + rm -r dist build *.egg-info + +.PHONY: linter dist test_train test_eval diff --git a/README.md b/README.md index 5d049a835d3b811986245e473f30fd957a26a55b..e4db1c591bb6e2b7e5e955cd48660e4a6e30ff25 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,319 @@ ---- -title: Audio -emoji: 📈 -colorFrom: pink -colorTo: blue -sdk: gradio -sdk_version: 5.35.0 -app_file: app.py -pinned: false -license: unknown -short_description: audio processor ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Demucs Music Source Separation + +[![Support Ukraine](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://opensource.fb.com/support-ukraine) +![tests badge](https://github.com/facebookresearch/demucs/workflows/tests/badge.svg) +![linter badge](https://github.com/facebookresearch/demucs/workflows/linter/badge.svg) + + +**Important:** As I am no longer working at Meta, **this repository is not maintained anymore**. +I've created a fork at [github.com/adefossez/demucs](https://github.com/adefossez/demucs). Note that this project is not actively maintained anymore +and only important bug fixes will be processed on the new repo. Please do not open issues for feature request or if Demucs doesn't work perfectly for your use case :) + +This is the 4th release of Demucs (v4), featuring Hybrid Transformer based source separation. +**For the classic Hybrid Demucs (v3):** [Go this commit][demucs_v3]. 
+If you are experiencing issues and want the old Demucs back, please file an issue, and then you can get back to Demucs v3 with
+`git checkout v3`. You can also go back to [Demucs v2][demucs_v2].
+
+
+Demucs is a state-of-the-art music source separation model, currently capable of separating
+drums, bass, and vocals from the rest of the accompaniment.
+Demucs is based on a U-Net convolutional architecture inspired by [Wave-U-Net][waveunet].
+The v4 version features [Hybrid Transformer Demucs][htdemucs], a hybrid spectrogram/waveform separation model using Transformers.
+It is based on [Hybrid Demucs][hybrid_paper] (also provided in this repo), with the innermost layers
+replaced by a cross-domain Transformer Encoder. This Transformer uses self-attention within each domain,
+and cross-attention across domains.
+The model achieves an SDR of 9.00 dB on the MUSDB HQ test set. Moreover, when using sparse attention
+kernels to extend its receptive field and per-source fine-tuning, we achieve a state-of-the-art 9.20 dB SDR.
+
+Samples are available [on our sample page](https://ai.honu.io/papers/htdemucs/index.html).
+Check out [our paper][htdemucs] for more information.
+It has been trained on the [MUSDB HQ][musdb] dataset plus an extra training dataset of 800 songs.
+This model separates drums, bass, vocals, and other stems for any song.
+
+
+As of Demucs v4, Hybrid Transformer Demucs is the default model. The fine-tuned version can be selected in the usual
+commands described hereafter with `-n htdemucs_ft`.
+The single, non fine-tuned model is provided as `-n htdemucs`, and the retrained baseline
+as `-n hdemucs_mmi`. The Sparse Hybrid Transformer model described in our paper is not provided, as it
+requires custom CUDA code that is not ready for release yet.
+We are also releasing an experimental 6-source model that adds `guitar` and `piano` sources.
+Quick testing shows okay quality for `guitar`, but a lot of bleeding and artifacts for the `piano` source.
+
+
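+The model names above are the same ones accepted by the CLI `-n` flag and by the `pretrained` helper used in the
+notebook bundled with this diff. A minimal, hedged sketch of loading them from Python:
+
+```python
+from demucs import pretrained
+
+# Same names as the CLI `-n` flag; weights are downloaded on first use.
+model = pretrained.get_model("htdemucs")       # default Hybrid Transformer Demucs
+# model = pretrained.get_model("htdemucs_ft")  # fine-tuned bag of models, slower but slightly better
+# model = pretrained.get_model("hdemucs_mmi")  # retrained Hybrid Demucs v3 baseline
+print(model.sources)      # ['drums', 'bass', 'other', 'vocals']
+print(model.samplerate)   # 44100
+```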

+<p align="center">
+<img src="./demucs.png" alt="Schema representing the structure of Hybrid Transformer Demucs,
+    with a dual U-Net structure, one branch for the temporal domain,
+    and one branch for the spectral domain. There is a cross-domain Transformer between the Encoders and Decoders."
+width="800px"></p>
+ + + +## Important news if you are already using Demucs + +See the [release notes](./docs/release.md) for more details. + +- 22/02/2023: added support for the [SDX 2023 Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023), + see the dedicated [doc page](./docs/sdx23.md) +- 07/12/2022: Demucs v4 now on PyPI. **htdemucs** model now used by default. Also releasing + a 6 sources models (adding `guitar` and `piano`, although the latter doesn't work so well at the moment). +- 16/11/2022: Added the new **Hybrid Transformer Demucs v4** models. + Adding support for the [torchaudio implementation of HDemucs](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html). +- 30/08/2022: added reproducibility and ablation grids, along with an updated version of the paper. +- 17/08/2022: Releasing v3.0.5: Set split segment length to reduce memory. Compatible with pyTorch 1.12. +- 24/02/2022: Releasing v3.0.4: split into two stems (i.e. karaoke mode). + Export as float32 or int24. +- 17/12/2021: Releasing v3.0.3: bug fixes (thanks @keunwoochoi), memory drastically + reduced on GPU (thanks @famzah) and new multi-core evaluation on CPU (`-j` flag). +- 12/11/2021: Releasing **Demucs v3** with hybrid domain separation. Strong improvements + on all sources. This is the model that won Sony MDX challenge. +- 11/05/2021: Adding support for MusDB-HQ and arbitrary wav set, for the MDX challenge. For more information +on joining the challenge with Demucs see [the Demucs MDX instructions](docs/mdx.md) + + +## Comparison with other models + +We provide hereafter a summary of the different metrics presented in the paper. +You can also compare Hybrid Demucs (v3), [KUIELAB-MDX-Net][kuielab], [Spleeter][spleeter], Open-Unmix, Demucs (v1), and Conv-Tasnet on one of my favorite +songs on my [soundcloud playlist][soundcloud]. + +### Comparison of accuracy + +`Overall SDR` is the mean of the SDR for each of the 4 sources, `MOS Quality` is a rating from 1 to 5 +of the naturalness and absence of artifacts given by human listeners (5 = no artifacts), `MOS Contamination` +is a rating from 1 to 5 with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper], +for more details. + +| Model | Domain | Extra data? | Overall SDR | MOS Quality | MOS Contamination | +|------------------------------|-------------|-------------------|-------------|-------------|-------------------| +| [Wave-U-Net][waveunet] | waveform | no | 3.2 | - | - | +| [Open-Unmix][openunmix] | spectrogram | no | 5.3 | - | - | +| [D3Net][d3net] | spectrogram | no | 6.0 | - | - | +| [Conv-Tasnet][demucs_v2] | waveform | no | 5.7 | - | | +| [Demucs (v2)][demucs_v2] | waveform | no | 6.3 | 2.37 | 2.36 | +| [ResUNetDecouple+][decouple] | spectrogram | no | 6.7 | - | - | +| [KUIELAB-MDX-Net][kuielab] | hybrid | no | 7.5 | **2.86** | 2.55 | +| [Band-Spit RNN][bandsplit] | spectrogram | no | **8.2** | - | - | +| **Hybrid Demucs (v3)** | hybrid | no | 7.7 | **2.83** | **3.04** | +| [MMDenseLSTM][mmdenselstm] | spectrogram | 804 songs | 6.0 | - | - | +| [D3Net][d3net] | spectrogram | 1.5k songs | 6.7 | - | - | +| [Spleeter][spleeter] | spectrogram | 25k songs | 5.9 | - | - | +| [Band-Spit RNN][bandsplit] | spectrogram | 1.7k (mixes only) | **9.0** | - | - | +| **HT Demucs f.t. (v4)** | hybrid | 800 songs | **9.0** | - | - | + + + +## Requirements + +You will need at least Python 3.8. 
See `requirements_minimal.txt` for requirements for separation only, +and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model. + +### For Windows users + +Everytime you see `python3`, replace it with `python.exe`. You should always run commands from the +Anaconda console. + +### For musicians + +If you just want to use Demucs to separate tracks, you can install it with + +```bash +python3 -m pip install -U demucs +``` + +For bleeding edge versions, you can install directly from this repo using +```bash +python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs +``` + +Advanced OS support are provided on the following page, **you must read the page for your OS before posting an issues**: +- **If you are using Windows:** [Windows support](docs/windows.md). +- **If you are using macOS:** [macOS support](docs/mac.md). +- **If you are using Linux:** [Linux support](docs/linux.md). + +### For machine learning scientists + +If you have anaconda installed, you can run from the root of this repository: + +```bash +conda env update -f environment-cpu.yml # if you don't have GPUs +conda env update -f environment-cuda.yml # if you have GPUs +conda activate demucs +pip install -e . +``` + +This will create a `demucs` environment with all the dependencies installed. + +You will also need to install [soundstretch/soundtouch](https://www.surina.net/soundtouch/soundstretch.html): on macOS you can do `brew install sound-touch`, +and on Ubuntu `sudo apt-get install soundstretch`. This is used for the +pitch/tempo augmentation. + + +### Running in Docker + +Thanks to @xserrat, there is now a Docker image definition ready for using Demucs. This can ensure all libraries are correctly installed without interfering with the host OS. See his repo [Docker Facebook Demucs](https://github.com/xserrat/docker-facebook-demucs) for more information. + + +### Running from Colab + +I made a Colab to easily separate track with Demucs. Note that +transfer speeds with Colab are a bit slow for large media files, +but it will allow you to use Demucs without installing anything. + +[Demucs on Google Colab](https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing) + +### Web Demo + +Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/demucs) + +### Graphical Interface + +@CarlGao4 has released a GUI for Demucs: [CarlGao4/Demucs-Gui](https://github.com/CarlGao4/Demucs-Gui). Downloads for Windows and macOS is available [here](https://github.com/CarlGao4/Demucs-Gui/releases). Use [FossHub mirror](https://fosshub.com/Demucs-GUI.html) to speed up your download. + +@Anjok07 is providing a self contained GUI in [UVR (Ultimate Vocal Remover)](https://github.com/facebookresearch/demucs/issues/334) that supports Demucs. + +### Other providers + +Audiostrip is providing free online separation with Demucs on their website [https://audiostrip.co.uk/](https://audiostrip.co.uk/). + +[MVSep](https://mvsep.com/) also provides free online separation, select `Demucs3 model B` for the best quality. + +[Neutone](https://neutone.space/) provides a realtime Demucs model in their free VST/AU plugin that can be used in your favorite DAW. 
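+The Dockerfile added in this diff starts the container with `uvicorn predict:app`, but `predict.py` itself is not part of
+the diff. The following is a hypothetical minimal sketch of such a file (the `/separate` endpoint, temporary paths, and the
+choice of returning only the vocals stem are illustrative assumptions, not the actual implementation):
+
+```python
+import shutil
+import tempfile
+from pathlib import Path
+
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import FileResponse
+
+import demucs.separate
+
+app = FastAPI()
+
+
+@app.post("/separate")
+async def separate(file: UploadFile = File(...)):
+    workdir = Path(tempfile.mkdtemp())
+    input_path = workdir / file.filename
+    with input_path.open("wb") as f:
+        shutil.copyfileobj(file.file, f)
+
+    # Same list-based entry point used by app.py in this diff.
+    demucs.separate.main(["-n", "htdemucs", "-o", str(workdir / "out"), str(input_path)])
+
+    # Demucs writes stems to OUT/MODEL_NAME/TRACK_NAME/, wav by default.
+    vocals = workdir / "out" / "htdemucs" / input_path.stem / "vocals.wav"
+    return FileResponse(vocals, media_type="audio/wav", filename="vocals.wav")
+```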
+ + +## Separating tracks + +In order to try Demucs, you can just run from any folder (as long as you properly installed it) + +```bash +demucs PATH_TO_AUDIO_FILE_1 [PATH_TO_AUDIO_FILE_2 ...] # for Demucs +# If you used `pip install --user` you might need to replace demucs with python3 -m demucs +python3 -m demucs --mp3 --mp3-bitrate BITRATE PATH_TO_AUDIO_FILE_1 # output files saved as MP3 + # use --mp3-preset to change encoder preset, 2 for best quality, 7 for fastest +# If your filename contain spaces don't forget to quote it !!! +demucs "my music/my favorite track.mp3" +# You can select different models with `-n` mdx_q is the quantized model, smaller but maybe a bit less accurate. +demucs -n mdx_q myfile.mp3 +# If you only want to separate vocals out of an audio, use `--two-stems=vocals` (You can also set to drums or bass) +demucs --two-stems=vocals myfile.mp3 +``` + + +If you have a GPU, but you run out of memory, please use `--segment SEGMENT` to reduce length of each split. `SEGMENT` should be changed to a integer describing the length of each segment in seconds. +A segment length of at least 10 is recommended (the bigger the number is, the more memory is required, but quality may increase). Note that the Hybrid Transformer models only support a maximum segment length of 7.8 seconds. +Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` is also helpful. If this still does not help, please add `-d cpu` to the command line. See the section hereafter for more details on the memory requirements for GPU acceleration. + +Separated tracks are stored in the `separated/MODEL_NAME/TRACK_NAME` folder. There you will find four stereo wav files sampled at 44.1 kHz: `drums.wav`, `bass.wav`, +`other.wav`, `vocals.wav` (or `.mp3` if you used the `--mp3` option). + +All audio formats supported by `torchaudio` can be processed (i.e. wav, mp3, flac, ogg/vorbis on Linux/macOS, etc.). On Windows, `torchaudio` has limited support, so we rely on `ffmpeg`, which should support pretty much anything. +Audio is resampled on the fly if necessary. +The output will be a wav file encoded as int16. +You can save as float32 wav files with `--float32`, or 24 bits integer wav with `--int24`. +You can pass `--mp3` to save as mp3 instead, and set the bitrate (in kbps) with `--mp3-bitrate` (default is 320). + +It can happen that the output would need clipping, in particular due to some separation artifacts. +Demucs will automatically rescale each output stem so as to avoid clipping. This can however break +the relative volume between stems. If instead you prefer hard clipping, pass `--clip-mode clamp`. +You can also try to reduce the volume of the input mixture before feeding it to Demucs. + + +Other pre-trained models can be selected with the `-n` flag. +The list of pre-trained models is: +- `htdemucs`: first version of Hybrid Transformer Demucs. Trained on MusDB + 800 songs. Default model. +- `htdemucs_ft`: fine-tuned version of `htdemucs`, separation will take 4 times more time + but might be a bit better. Same training set as `htdemucs`. +- `htdemucs_6s`: 6 sources version of `htdemucs`, with `piano` and `guitar` being added as sources. + Note that the `piano` source is not working great at the moment. +- `hdemucs_mmi`: Hybrid Demucs v3, retrained on MusDB + 800 songs. +- `mdx`: trained only on MusDB HQ, winning model on track A at the [MDX][mdx] challenge. +- `mdx_extra`: trained with extra training data (**including MusDB test set**), ranked 2nd on the track B + of the [MDX][mdx] challenge. 
+- `mdx_q`, `mdx_extra_q`: quantized version of the previous models. Smaller download and storage + but quality can be slightly worse. +- `SIG`: where `SIG` is a single model from the [model zoo](docs/training.md#model-zoo). + +The `--two-stems=vocals` option allows separating vocals from the rest of the accompaniment (i.e., karaoke mode). +`vocals` can be changed to any source in the selected model. +This will mix the files after separating the mix fully, so this won't be faster or use less memory. + +The `--shifts=SHIFTS` performs multiple predictions with random shifts (a.k.a the *shift trick*) of the input and average them. This makes prediction `SHIFTS` times +slower. Don't use it unless you have a GPU. + +The `--overlap` option controls the amount of overlap between prediction windows. Default is 0.25 (i.e. 25%) which is probably fine. +It can probably be reduced to 0.1 to improve a bit speed. + + +The `-j` flag allow to specify a number of parallel jobs (e.g. `demucs -j 2 myfile.mp3`). +This will multiply by the same amount the RAM used so be careful! + +### Memory requirements for GPU acceleration + +If you want to use GPU acceleration, you will need at least 3GB of RAM on your GPU for `demucs`. However, about 7GB of RAM will be required if you use the default arguments. Add `--segment SEGMENT` to change size of each split. If you only have 3GB memory, set SEGMENT to 8 (though quality may be worse if this argument is too small). Creating an environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` can help users with even smaller RAM such as 2GB (I separated a track that is 4 minutes but only 1.5GB is used), but this would make the separation slower. + +If you do not have enough memory on your GPU, simply add `-d cpu` to the command line to use the CPU. With Demucs, processing time should be roughly equal to 1.5 times the duration of the track. + +## Calling from another Python program + +The main function provides an `opt` parameter as a simple API. You can just pass the parsed command line as this parameter: +```python +# Assume that your command is `demucs --mp3 --two-stems vocals -n mdx_extra "track with space.mp3"` +# The following codes are same as the command above: +import demucs.separate +demucs.separate.main(["--mp3", "--two-stems", "vocals", "-n", "mdx_extra", "track with space.mp3"]) + +# Or like this +import demucs.separate +import shlex +demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"')) +``` + +To use more complicated APIs, see [API docs](docs/api.md) + +## Training Demucs + +If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md). + +## MDX Challenge reproduction + +In order to reproduce the results from the Track A and Track B submissions, checkout the [MDX Hybrid Demucs submission repo][mdx_submission]. + + + +## How to cite + +``` +@inproceedings{rouard2022hybrid, + title={Hybrid Transformers for Music Source Separation}, + author={Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre}, + booktitle={ICASSP 23}, + year={2023} +} + +@inproceedings{defossez2021hybrid, + title={Hybrid Spectrogram and Waveform Source Separation}, + author={D{\'e}fossez, Alexandre}, + booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation}, + year={2021} +} +``` + +## License + +Demucs is released under the MIT license as found in the [LICENSE](LICENSE) file. 
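+Returning to the "Calling from another Python program" section above: the `demucs.api` module added in this diff exposes a
+`Separator` class. A hedged sketch of its use (the `separate_audio_file` helper and `samplerate` attribute belong to the
+full module, which is only partially shown in this diff):
+
+```python
+from demucs.api import Separator, save_audio
+
+# Parameters mirror the CLI flags documented above.
+separator = Separator(model="htdemucs", shifts=1, overlap=0.25, progress=True)
+
+origin, separated = separator.separate_audio_file("track with space.mp3")
+for stem, source in separated.items():
+    # Each source is a [channels, samples] tensor at the model's sample rate.
+    save_audio(source, f"{stem}.wav", samplerate=separator.samplerate)
+```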
+ +[hybrid_paper]: https://arxiv.org/abs/2111.03600 +[waveunet]: https://github.com/f90/Wave-U-Net +[musdb]: https://sigsep.github.io/datasets/musdb.html +[openunmix]: https://github.com/sigsep/open-unmix-pytorch +[mmdenselstm]: https://arxiv.org/abs/1805.02410 +[demucs_v2]: https://github.com/facebookresearch/demucs/tree/v2 +[demucs_v3]: https://github.com/facebookresearch/demucs/tree/v3 +[spleeter]: https://github.com/deezer/spleeter +[soundcloud]: https://soundcloud.com/honualx/sets/source-separation-in-the-waveform-domain +[d3net]: https://arxiv.org/abs/2010.01733 +[mdx]: https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021 +[kuielab]: https://github.com/kuielab/mdx-net-submission +[decouple]: https://arxiv.org/abs/2109.05418 +[mdx_submission]: https://github.com/adefossez/mdx21_demucs +[bandsplit]: https://arxiv.org/abs/2209.15174 +[htdemucs]: https://arxiv.org/abs/2211.08553 diff --git a/app.py b/app.py index d271244bad614271ca447f1f78f781c6a86195ac..7044f4095fff086cfbff969ab13f5a68b94973ab 100644 --- a/app.py +++ b/app.py @@ -1,37 +1,37 @@ -import os -import shutil -import gradio as gr -from demucs.separate import main - -def separate_stems(audio_file): - input_path = "input.mp3" - shutil.copy(audio_file, input_path) - - output_dir = "output" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir, exist_ok=True) - - # Run Demucs - main(["-n", "htdemucs", "-o", output_dir, input_path]) - - # Build list of stems to return - base = os.path.splitext(os.path.basename(input_path))[0] - stem_path = os.path.join(output_dir, "htdemucs", base) - stems = [os.path.join(stem_path, f"{stem}.mp3") for stem in ["vocals", "drums", "bass", "other"]] - return stems - -demo = gr.Interface( - fn=separate_stems, - inputs=gr.Audio(type="filepath", label="Upload Song"), - outputs=[ - gr.Audio(label="Vocals"), - gr.Audio(label="Drums"), - gr.Audio(label="Bass"), - gr.Audio(label="Other"), - ], - title="Demucs v4 Stem Separator", - description="Upload a song to separate vocals, drums, bass, and other using Facebook's Demucs model.", -) - -demo.launch() +import os +import shutil +import gradio as gr +from demucs.separate import main + +def separate_stems(audio_file): + input_path = "input.mp3" + shutil.copy(audio_file, input_path) + + output_dir = "output" + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir, exist_ok=True) + + # Run Demucs + main(["-n", "htdemucs", "-o", output_dir, input_path]) + + # Build list of stems to return + base = os.path.splitext(os.path.basename(input_path))[0] + stem_path = os.path.join(output_dir, "htdemucs", base) + stems = [os.path.join(stem_path, f"{stem}.mp3") for stem in ["vocals", "drums", "bass", "other"]] + return stems + +demo = gr.Interface( + fn=separate_stems, + inputs=gr.Audio(type="filepath", label="Upload Song"), + outputs=[ + gr.Audio(label="Vocals"), + gr.Audio(label="Drums"), + gr.Audio(label="Bass"), + gr.Audio(label="Other"), + ], + title="Demucs v4 Stem Separator", + description="Upload a song to separate vocals, drums, bass, and other using Facebook's Demucs model.", +) + +demo.launch() diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b76722d4aac077b2afa3110a59a247b5b48f6d96 --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,304 @@ +defaults: + - _self_ + - dset: musdb44 + - svd: default + - variant: default + - override hydra/hydra_logging: colorlog + - override hydra/job_logging: colorlog + +dummy: +dset: + 
musdb: /checkpoint/defossez/datasets/musdbhq + musdb_samplerate: 44100 + use_musdb: true # set to false to not use musdb as training data. + wav: # path to custom wav dataset + wav2: # second custom wav dataset + segment: 11 + shift: 1 + train_valid: false + full_cv: true + samplerate: 44100 + channels: 2 + normalize: true + metadata: ./metadata + sources: ['drums', 'bass', 'other', 'vocals'] + valid_samples: # valid dataset size + backend: null # if provided select torchaudio backend. + +test: + save: False + best: True + workers: 2 + every: 20 + split: true + shifts: 1 + overlap: 0.25 + sdr: true + metric: 'loss' # metric used for best model selection on the valid set, can also be nsdr + nonhq: # path to non hq MusDB for evaluation + +epochs: 360 +batch_size: 64 +max_batches: # limit the number of batches per epoch, useful for debugging + # or if your dataset is gigantic. +optim: + lr: 3e-4 + momentum: 0.9 + beta2: 0.999 + loss: l1 # l1 or mse + optim: adam + weight_decay: 0 + clip_grad: 0 + +seed: 42 +debug: false +valid_apply: true +flag: +save_every: +weights: [1., 1., 1., 1.] # weights over each source for the training/valid loss. + +augment: + shift_same: false + repitch: + proba: 0.2 + max_tempo: 12 + remix: + proba: 1 + group_size: 4 + scale: + proba: 1 + min: 0.25 + max: 1.25 + flip: true + +continue_from: # continue from other XP, give the XP Dora signature. +continue_pretrained: # signature of a pretrained XP, this cannot be a bag of models. +pretrained_repo: # repo for pretrained model (default is official AWS) +continue_best: true +continue_opt: false + +misc: + num_workers: 10 + num_prints: 4 + show: false + verbose: false + +# List of decay for EMA at batch or epoch level, e.g. 0.999. +# Batch level EMA are kept on GPU for speed. +ema: + epoch: [] + batch: [] + +use_train_segment: true # to remove +model_segment: # override the segment parameter for the model, usually 4 times the training segment. +model: demucs # see demucs/train.py for the possibilities, and config for each model hereafter. +demucs: # see demucs/demucs.py for a detailed description + # Channels + channels: 64 + growth: 2 + # Main structure + depth: 6 + rewrite: true + lstm_layers: 0 + # Convolutions + kernel_size: 8 + stride: 4 + context: 1 + # Activations + gelu: true + glu: true + # Normalization + norm_groups: 4 + norm_starts: 4 + # DConv residual branch + dconv_depth: 2 + dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both. 
+ dconv_comp: 4 + dconv_attn: 4 + dconv_lstm: 4 + dconv_init: 1e-4 + # Pre/post treatment + resample: true + normalize: false + # Weight init + rescale: 0.1 + +hdemucs: # see demucs/hdemucs.py for a detailed description + # Channels + channels: 48 + channels_time: + growth: 2 + # STFT + nfft: 4096 + wiener_iters: 0 + end_iters: 0 + wiener_residual: false + cac: true + # Main structure + depth: 6 + rewrite: true + hybrid: true + hybrid_old: false + # Frequency Branch + multi_freqs: [] + multi_freqs_depth: 3 + freq_emb: 0.2 + emb_scale: 10 + emb_smooth: true + # Convolutions + kernel_size: 8 + stride: 4 + time_stride: 2 + context: 1 + context_enc: 0 + # normalization + norm_starts: 4 + norm_groups: 4 + # DConv residual branch + dconv_mode: 1 + dconv_depth: 2 + dconv_comp: 4 + dconv_attn: 4 + dconv_lstm: 4 + dconv_init: 1e-3 + # Weight init + rescale: 0.1 + +# Torchaudio implementation of HDemucs +torch_hdemucs: +# Channels + channels: 48 + growth: 2 + # STFT + nfft: 4096 + # Main structure + depth: 6 + freq_emb: 0.2 + emb_scale: 10 + emb_smooth: true + # Convolutions + kernel_size: 8 + stride: 4 + time_stride: 2 + context: 1 + context_enc: 0 + # normalization + norm_starts: 4 + norm_groups: 4 + # DConv residual branch + dconv_depth: 2 + dconv_comp: 4 + dconv_attn: 4 + dconv_lstm: 4 + dconv_init: 1e-3 + +htdemucs: # see demucs/htdemucs.py for a detailed description + # Channels + channels: 48 + channels_time: + growth: 2 + # STFT + nfft: 4096 + wiener_iters: 0 + end_iters: 0 + wiener_residual: false + cac: true + # Main structure + depth: 4 + rewrite: true + # Frequency Branch + multi_freqs: [] + multi_freqs_depth: 3 + freq_emb: 0.2 + emb_scale: 10 + emb_smooth: true + # Convolutions + kernel_size: 8 + stride: 4 + time_stride: 2 + context: 1 + context_enc: 0 + # normalization + norm_starts: 4 + norm_groups: 4 + # DConv residual branch + dconv_mode: 1 + dconv_depth: 2 + dconv_comp: 8 + dconv_init: 1e-3 + # Before the Transformer + bottom_channels: 0 + # CrossTransformer + # ------ Common to all + # Regular parameters + t_layers: 5 + t_hidden_scale: 4.0 + t_heads: 8 + t_dropout: 0.0 + t_layer_scale: True + t_gelu: True + # ------------- Positional Embedding + t_emb: sin + t_max_positions: 10000 # for the scaled embedding + t_max_period: 10000.0 + t_weight_pos_embed: 1.0 + t_cape_mean_normalize: True + t_cape_augment: True + t_cape_glob_loc_scale: [5000.0, 1.0, 1.4] + t_sin_random_shift: 0 + # ------------- norm before a transformer encoder + t_norm_in: True + t_norm_in_group: False + # ------------- norm inside the encoder + t_group_norm: False + t_norm_first: True + t_norm_out: True + # ------------- optim + t_weight_decay: 0.0 + t_lr: + # ------------- sparsity + t_sparse_self_attn: False + t_sparse_cross_attn: False + t_mask_type: diag + t_mask_random_seed: 42 + t_sparse_attn_window: 400 + t_global_window: 100 + t_sparsity: 0.95 + t_auto_sparsity: False + # Cross Encoder First (False) + t_cross_first: False + # Weight init + rescale: 0.1 + +svd: # see svd.py for documentation + penalty: 0 + min_size: 0.1 + dim: 1 + niters: 2 + powm: false + proba: 1 + conv_only: false + convtr: false + bs: 1 + +quant: # quantization hyper params + diffq: # diffq penalty, typically 1e-4 or 3e-4 + qat: # use QAT with a fixed number of bits (not as good as diffq) + min_size: 0.2 + group_size: 8 + +dora: + dir: outputs + exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend'] + +slurm: + time: 4320 + constraint: volta32gb + setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 
cuda/11.6'] + +# Hydra config +hydra: + job_logging: + formatters: + colorlog: + datefmt: "%m-%d %H:%M:%S" diff --git a/conf/dset/aetl.yaml b/conf/dset/aetl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2560a1e6f978e6a8ad3ca1ba67c95fa23f060f8 --- /dev/null +++ b/conf/dset/aetl.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +# automix dataset with Musdb, extra training data and the test set of Musdb. +# This used even more remixes than auto_extra_test. +dset: + wav: /checkpoint/defossez/datasets/aetl + samplerate: 44100 + channels: 2 +epochs: 320 +max_batches: 500 + +augment: + shift_same: true + scale: + proba: 0. + remix: + proba: 0 + repitch: + proba: 0 diff --git a/conf/dset/auto_extra_test.yaml b/conf/dset/auto_extra_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5455907232703ec8b3af498384e514eee5dce26c --- /dev/null +++ b/conf/dset/auto_extra_test.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +# automix dataset with Musdb, extra training data and the test set of Musdb. +dset: + wav: /checkpoint/defossez/datasets/automix_extra_test2 + samplerate: 44100 + channels: 2 +epochs: 320 +max_batches: 500 + +augment: + shift_same: true + scale: + proba: 0. + remix: + proba: 0 + repitch: + proba: 0 diff --git a/conf/dset/auto_mus.yaml b/conf/dset/auto_mus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..407b5a8e8939c4b914636a0f6f002bc0e8897388 --- /dev/null +++ b/conf/dset/auto_mus.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +# Automix dataset based on musdb train set. +dset: + wav: /checkpoint/defossez/datasets/automix_musdb + samplerate: 44100 + channels: 2 +epochs: 360 +max_batches: 300 +test: + every: 4 + +augment: + shift_same: true + scale: + proba: 0.5 + remix: + proba: 0 + repitch: + proba: 0 diff --git a/conf/dset/extra44.yaml b/conf/dset/extra44.yaml new file mode 100644 index 0000000000000000000000000000000000000000..705dd209cc2b2e29b3f2b5fce93bd98c91cd2b8f --- /dev/null +++ b/conf/dset/extra44.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +# Musdb + extra tracks +dset: + wav: /checkpoint/defossez/datasets/allstems_44/ + samplerate: 44100 + channels: 2 +epochs: 320 diff --git a/conf/dset/extra_mmi_goodclean.yaml b/conf/dset/extra_mmi_goodclean.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f0582f9a0e09f19e1b85082bbe8e9daa62b2c6e --- /dev/null +++ b/conf/dset/extra_mmi_goodclean.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# Musdb + extra tracks +dset: + wav: /checkpoint/defossez/datasets/allstems_44/ + wav2: /checkpoint/defossez/datasets/mmi44_goodclean + samplerate: 44100 + channels: 2 + wav2_weight: null + wav2_valid: false + valid_samples: 100 +epochs: 1200 diff --git a/conf/dset/extra_test.yaml b/conf/dset/extra_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..158f9b6370e614641453faac188f13a8f8dc659b --- /dev/null +++ b/conf/dset/extra_test.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# Musdb + extra tracks + test set from musdb. 
+dset: + wav: /checkpoint/defossez/datasets/allstems_test_44/ + samplerate: 44100 + channels: 2 +epochs: 320 +max_batches: 700 +test: + sdr: false + every: 500 diff --git a/conf/dset/musdb44.yaml b/conf/dset/musdb44.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caa82dd8a6c66d6d47c3e1342d5fab1e0f68a862 --- /dev/null +++ b/conf/dset/musdb44.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +dset: + samplerate: 44100 + channels: 2 \ No newline at end of file diff --git a/conf/dset/sdx23_bleeding.yaml b/conf/dset/sdx23_bleeding.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0d37a8eefa536479ec5dd569bc6d696c0f4339b --- /dev/null +++ b/conf/dset/sdx23_bleeding.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +# Musdb + extra tracks +dset: + wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/ + use_musdb: false + samplerate: 44100 + channels: 2 + backend: soundfile # must use soundfile as some mixture would clip with sox. +epochs: 320 diff --git a/conf/dset/sdx23_labelnoise.yaml b/conf/dset/sdx23_labelnoise.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2411ad99d13efdc516e231893d63579c441a74c --- /dev/null +++ b/conf/dset/sdx23_labelnoise.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +# Musdb + extra tracks +dset: + wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0 + use_musdb: false + samplerate: 44100 + channels: 2 + backend: soundfile # must use soundfile as some mixture would clip with sox. +epochs: 320 diff --git a/conf/svd/base.yaml b/conf/svd/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb6c6a4f0f3c2ae0b4d32b6acc4d16cebc01ced8 --- /dev/null +++ b/conf/svd/base.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +svd: + penalty: 0 + min_size: 1 + dim: 50 + niters: 4 + powm: false + proba: 1 + conv_only: false + convtr: false # ideally this should be true, but some models were trained with this to false. 
+ +optim: + beta2: 0.9998 \ No newline at end of file diff --git a/conf/svd/base2.yaml b/conf/svd/base2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1be73c82f8ea148a60bee80d5b35bc414b870a34 --- /dev/null +++ b/conf/svd/base2.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +svd: + penalty: 0 + min_size: 1 + dim: 100 + niters: 4 + powm: false + proba: 1 + conv_only: false + convtr: true + +optim: + beta2: 0.9998 \ No newline at end of file diff --git a/conf/svd/default.yaml b/conf/svd/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c3b344f5605dd646eaa2df1c682b63e7501536a --- /dev/null +++ b/conf/svd/default.yaml @@ -0,0 +1 @@ +# @package _global_ diff --git a/conf/variant/default.yaml b/conf/variant/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c3b344f5605dd646eaa2df1c682b63e7501536a --- /dev/null +++ b/conf/variant/default.yaml @@ -0,0 +1 @@ +# @package _global_ diff --git a/conf/variant/example.yaml b/conf/variant/example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96ca521fa7bf5fa5297d7ea67ac58c5394d37f4c --- /dev/null +++ b/conf/variant/example.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +model: hdemucs +hdemucs: + channels: 32 \ No newline at end of file diff --git a/conf/variant/finetune.yaml b/conf/variant/finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f121c878e3501dedfbfef8bd50b121b5da5912c0 --- /dev/null +++ b/conf/variant/finetune.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +epochs: 4 +batch_size: 16 +optim: + lr: 0.0006 +test: + every: 1 + sdr: false +dset: + segment: 28 + shift: 2 + +augment: + scale: + proba: 0 + shift_same: true + remix: + proba: 0 diff --git a/demucs.png b/demucs.png new file mode 100644 index 0000000000000000000000000000000000000000..40e700de7b7bbd0fa0cf386a46d017564ae0802b --- /dev/null +++ b/demucs.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f8a53c1bbaa6c0268d358cd4cb9c2f1128907758aeb10a79789f7bbf61ded95 +size 339294 diff --git a/demucs/__init__.py b/demucs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df73c42395d5b5f6ee275baf57d908055021047e --- /dev/null +++ b/demucs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +__version__ = "4.1.0a2" diff --git a/demucs/__main__.py b/demucs/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff67b147498655ac349203b63d3750b32aa2686d --- /dev/null +++ b/demucs/__main__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .separate import main + +if __name__ == '__main__': + main() diff --git a/demucs/api.py b/demucs/api.py new file mode 100644 index 0000000000000000000000000000000000000000..987a8699541d40e06f50b5a66eead96d1cb2fb4e --- /dev/null +++ b/demucs/api.py @@ -0,0 +1,392 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""API methods for demucs + +Classes +------- +`demucs.api.Separator`: The base separator class + +Functions +--------- +`demucs.api.save_audio`: Save an audio +`demucs.api.list_models`: Get models list + +Examples +-------- +See the end of this module (if __name__ == "__main__") +""" + +import subprocess + +import torch as th +import torchaudio as ta + +from dora.log import fatal +from pathlib import Path +from typing import Optional, Callable, Dict, Tuple, Union + +from .apply import apply_model, _replace_dict +from .audio import AudioFile, convert_audio, save_audio +from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo + + +class LoadAudioError(Exception): + pass + + +class LoadModelError(Exception): + pass + + +class _NotProvided: + pass + + +NotProvided = _NotProvided() + + +class Separator: + def __init__( + self, + model: str = "htdemucs", + repo: Optional[Path] = None, + device: str = "cuda" if th.cuda.is_available() else "cpu", + shifts: int = 1, + overlap: float = 0.25, + split: bool = True, + segment: Optional[int] = None, + jobs: int = 0, + progress: bool = False, + callback: Optional[Callable[[dict], None]] = None, + callback_arg: Optional[dict] = None, + ): + """ + `class Separator` + ================= + + Parameters + ---------- + model: Pretrained model name or signature. Default is htdemucs. + repo: Folder containing all pre-trained models for use. + segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ + not specified, will use the command line option. + shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ + apply the oppositve shift to the output. This is repeated `shifts` time and all \ + predictions are averaged. This effectively makes the model time equivariant and \ + improves SDR by up to 0.2 points. If not specified, will use the command line option. + split: If True, the input will be broken down into small chunks (length set by `segment`) \ + and predictions will be performed individually on each and concatenated. Useful for \ + model with large memory footprint like Tasnet. If not specified, will use the command \ + line option. + overlap: The overlap between the splits. If not specified, will use the command line \ + option. + device (torch.device, str, or None): If provided, device on which to execute the \ + computation, otherwise `wav.device` is assumed. When `device` is different from \ + `wav.device`, only local computations will be on `device`, while the entire tracks \ + will be stored on `wav.device`. If not specified, will use the command line option. + jobs: Number of jobs. This can increase memory usage but will be much faster when \ + multiple cores are available. If not specified, will use the command line option. + callback: A function will be called when the separation of a chunk starts or finished. \ + The argument passed to the function will be a dict. For more information, please see \ + the Callback section. + callback_arg: A dict containing private parameters to be passed to callback function. For \ + more information, please see the Callback section. + progress: If true, show a progress bar. + + Callback + -------- + The function will be called with only one positional parameter whose type is `dict`. The + `callback_arg` will be combined with information of current separation progress. The + progress information will override the values in `callback_arg` if same key has been used. 
+ To abort the separation, raise `KeyboardInterrupt`. + + Progress information contains several keys (These keys will always exist): + - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. + - `shift_idx`: The index of shifts. Starts from 0. + - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't + mean that it is at the 441000 second of the audio, but the "frame" of the tensor. + - `state`: Could be `"start"` or `"end"`. + - `audio_length`: Length of the audio (in "frame" of the tensor). + - `models`: Count of submodels in the model. + """ + self._name = model + self._repo = repo + self._load_model() + self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split, + segment=segment, jobs=jobs, progress=progress, callback=callback, + callback_arg=callback_arg) + + def update_parameter( + self, + device: Union[str, _NotProvided] = NotProvided, + shifts: Union[int, _NotProvided] = NotProvided, + overlap: Union[float, _NotProvided] = NotProvided, + split: Union[bool, _NotProvided] = NotProvided, + segment: Optional[Union[int, _NotProvided]] = NotProvided, + jobs: Union[int, _NotProvided] = NotProvided, + progress: Union[bool, _NotProvided] = NotProvided, + callback: Optional[ + Union[Callable[[dict], None], _NotProvided] + ] = NotProvided, + callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided, + ): + """ + Update the parameters of separation. + + Parameters + ---------- + segment: Length (in seconds) of each segment (only available if `split` is `True`). If \ + not specified, will use the command line option. + shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \ + apply the oppositve shift to the output. This is repeated `shifts` time and all \ + predictions are averaged. This effectively makes the model time equivariant and \ + improves SDR by up to 0.2 points. If not specified, will use the command line option. + split: If True, the input will be broken down into small chunks (length set by `segment`) \ + and predictions will be performed individually on each and concatenated. Useful for \ + model with large memory footprint like Tasnet. If not specified, will use the command \ + line option. + overlap: The overlap between the splits. If not specified, will use the command line \ + option. + device (torch.device, str, or None): If provided, device on which to execute the \ + computation, otherwise `wav.device` is assumed. When `device` is different from \ + `wav.device`, only local computations will be on `device`, while the entire tracks \ + will be stored on `wav.device`. If not specified, will use the command line option. + jobs: Number of jobs. This can increase memory usage but will be much faster when \ + multiple cores are available. If not specified, will use the command line option. + callback: A function will be called when the separation of a chunk starts or finished. \ + The argument passed to the function will be a dict. For more information, please see \ + the Callback section. + callback_arg: A dict containing private parameters to be passed to callback function. For \ + more information, please see the Callback section. + progress: If true, show a progress bar. + + Callback + -------- + The function will be called with only one positional parameter whose type is `dict`. The + `callback_arg` will be combined with information of current separation progress. The + progress information will override the values in `callback_arg` if same key has been used. 
+ To abort the separation, raise `KeyboardInterrupt`. + + Progress information contains several keys (These keys will always exist): + - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. + - `shift_idx`: The index of shifts. Starts from 0. + - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't + mean that it is at the 441000 second of the audio, but the "frame" of the tensor. + - `state`: Could be `"start"` or `"end"`. + - `audio_length`: Length of the audio (in "frame" of the tensor). + - `models`: Count of submodels in the model. + """ + if not isinstance(device, _NotProvided): + self._device = device + if not isinstance(shifts, _NotProvided): + self._shifts = shifts + if not isinstance(overlap, _NotProvided): + self._overlap = overlap + if not isinstance(split, _NotProvided): + self._split = split + if not isinstance(segment, _NotProvided): + self._segment = segment + if not isinstance(jobs, _NotProvided): + self._jobs = jobs + if not isinstance(progress, _NotProvided): + self._progress = progress + if not isinstance(callback, _NotProvided): + self._callback = callback + if not isinstance(callback_arg, _NotProvided): + self._callback_arg = callback_arg + + def _load_model(self): + self._model = get_model(name=self._name, repo=self._repo) + if self._model is None: + raise LoadModelError("Failed to load model") + self._audio_channels = self._model.audio_channels + self._samplerate = self._model.samplerate + + def _load_audio(self, track: Path): + errors = {} + wav = None + + try: + wav = AudioFile(track).read(streams=0, samplerate=self._samplerate, + channels=self._audio_channels) + except FileNotFoundError: + errors["ffmpeg"] = "FFmpeg is not installed." + except subprocess.CalledProcessError: + errors["ffmpeg"] = "FFmpeg could not read the file." + + if wav is None: + try: + wav, sr = ta.load(str(track)) + except RuntimeError as err: + errors["torchaudio"] = err.args[0] + else: + wav = convert_audio(wav, sr, self._samplerate, self._audio_channels) + + if wav is None: + raise LoadAudioError( + "\n".join( + "When trying to load using {}, got the following error: {}".format( + backend, error + ) + for backend, error in errors.items() + ) + ) + return wav + + def separate_tensor( + self, wav: th.Tensor, sr: Optional[int] = None + ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]: + """ + Separate a loaded tensor. + + Parameters + ---------- + wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \ + while the second is the waveform of each channel. Type should be float32. \ + e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels. + sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the \ + model. + + Returns + ------- + A tuple, whose first element is the original wave and second element is a dict, whose keys + are the name of stems and values are separated waves. The original wave will have already + been resampled. + + Notes + ----- + Use this function with cautiousness. This function does not provide data verifying. 
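        Examples
        --------
        A short sketch, assuming `separator` is an already constructed `Separator` and
        `mix.wav` is a hypothetical stereo file:

            import torchaudio

            wav, sr = torchaudio.load("mix.wav")             # (channels, samples), float32
            origin, separated = separator.separate_tensor(wav, sr)
            print(list(separated))                           # stem names, e.g. 'vocals'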
+ """ + if sr is not None and sr != self.samplerate: + wav = convert_audio(wav, sr, self._samplerate, self._audio_channels) + ref = wav.mean(0) + wav -= ref.mean() + wav /= ref.std() + 1e-8 + out = apply_model( + self._model, + wav[None], + segment=self._segment, + shifts=self._shifts, + split=self._split, + overlap=self._overlap, + device=self._device, + num_workers=self._jobs, + callback=self._callback, + callback_arg=_replace_dict( + self._callback_arg, ("audio_length", wav.shape[1]) + ), + progress=self._progress, + ) + if out is None: + raise KeyboardInterrupt + out *= ref.std() + 1e-8 + out += ref.mean() + wav *= ref.std() + 1e-8 + wav += ref.mean() + return (wav, dict(zip(self._model.sources, out[0]))) + + def separate_audio_file(self, file: Path): + """ + Separate an audio file. The method will automatically read the file. + + Parameters + ---------- + wav: Path of the file to be separated. + + Returns + ------- + A tuple, whose first element is the original wave and second element is a dict, whose keys + are the name of stems and values are separated waves. The original wave will have already + been resampled. + """ + return self.separate_tensor(self._load_audio(file), self.samplerate) + + @property + def samplerate(self): + return self._samplerate + + @property + def audio_channels(self): + return self._audio_channels + + @property + def model(self): + return self._model + + +def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]: + """ + List the available models. Please remember that not all the returned models can be + successfully loaded. + + Parameters + ---------- + repo: The repo whose models are to be listed. + + Returns + ------- + A dict with two keys ("single" for single models and "bag" for bag of models). The values are + lists whose components are strs. 
+ """ + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + return {"single": model_repo.list_model(), "bag": bag_repo.list_model()} + + +if __name__ == "__main__": + # Test API functions + # two-stem not supported + + from .separate import get_parser + + args = get_parser().parse_args() + separator = Separator( + model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + overlap=args.overlap, + split=args.split, + segment=args.segment, + jobs=args.jobs, + callback=print + ) + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + for file in args.tracks: + separated = separator.separate_audio_file(file)[1] + if args.mp3: + ext = "mp3" + elif args.flac: + ext = "flac" + else: + ext = "wav" + kwargs = { + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, + } + for stem, source in separated.items(): + stem = out / args.filename.format( + track=Path(file).name.rsplit(".", 1)[0], + trackext=Path(file).name.rsplit(".", 1)[-1], + stem=stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(source, str(stem), **kwargs) diff --git a/demucs/apply.py b/demucs/apply.py new file mode 100644 index 0000000000000000000000000000000000000000..de01b1f56864a92377394bf433420998ff5fb24e --- /dev/null +++ b/demucs/apply.py @@ -0,0 +1,322 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Code to apply a model to a mix. It will handle chunking with overlaps and +inteprolation between chunks, as well as the "shift trick". +""" +from concurrent.futures import ThreadPoolExecutor +import copy +import random +from threading import Lock +import typing as tp + +import torch as th +from torch import nn +from torch.nn import functional as F +import tqdm + +from .demucs import Demucs +from .hdemucs import HDemucs +from .htdemucs import HTDemucs +from .utils import center_trim, DummyPoolExecutor + +Model = tp.Union[Demucs, HDemucs, HTDemucs] + + +class BagOfModels(nn.Module): + def __init__(self, models: tp.List[Model], + weights: tp.Optional[tp.List[tp.List[float]]] = None, + segment: tp.Optional[float] = None): + """ + Represents a bag of models with specific weights. + You should call `apply_model` rather than calling directly the forward here for + optimal performance. + + Args: + models (list[nn.Module]): list of Demucs/HDemucs models. + weights (list[list[float]]): list of weights. If None, assumed to + be all ones, otherwise it should be a list of N list (N number of models), + each containing S floats (S number of sources). + segment (None or float): overrides the `segment` attribute of each model + (this is performed inplace, be careful is you reuse the models passed). 
+ """ + super().__init__() + assert len(models) > 0 + first = models[0] + for other in models: + assert other.sources == first.sources + assert other.samplerate == first.samplerate + assert other.audio_channels == first.audio_channels + if segment is not None: + if not isinstance(other, HTDemucs) and segment > other.segment: + other.segment = segment + + self.audio_channels = first.audio_channels + self.samplerate = first.samplerate + self.sources = first.sources + self.models = nn.ModuleList(models) + + if weights is None: + weights = [[1. for _ in first.sources] for _ in models] + else: + assert len(weights) == len(models) + for weight in weights: + assert len(weight) == len(first.sources) + self.weights = weights + + @property + def max_allowed_segment(self) -> float: + max_allowed_segment = float('inf') + for model in self.models: + if isinstance(model, HTDemucs): + max_allowed_segment = min(max_allowed_segment, float(model.segment)) + return max_allowed_segment + + def forward(self, x): + raise NotImplementedError("Call `apply_model` on this.") + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + if isinstance(tensor, TensorChunk): + self.tensor = tensor.tensor + self.offset = offset + tensor.offset + else: + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict: + if _dict is None: + _dict = {} + else: + _dict = copy.copy(_dict) + for key, value in subs: + _dict[key] = value + return _dict + + +def apply_model(model: tp.Union[BagOfModels, Model], + mix: tp.Union[th.Tensor, TensorChunk], + shifts: int = 1, split: bool = True, + overlap: float = 0.25, transition_power: float = 1., + progress: bool = False, device=None, + num_workers: int = 0, segment: tp.Optional[float] = None, + pool=None, lock=None, + callback: tp.Optional[tp.Callable[[dict], None]] = None, + callback_arg: tp.Optional[dict] = None) -> th.Tensor: + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. 
+ Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + device (torch.device, str, or None): if provided, device on which to + execute the computation, otherwise `mix.device` is assumed. + When `device` is different from `mix.device`, only local computations will + be on `device`, while the entire tracks will be stored on `mix.device`. + num_workers (int): if non zero, device is 'cpu', how many threads to + use in parallel. + segment (float or None): override the model segment parameter. + """ + if device is None: + device = mix.device + else: + device = th.device(device) + if pool is None: + if num_workers > 0 and device.type == 'cpu': + pool = ThreadPoolExecutor(num_workers) + else: + pool = DummyPoolExecutor() + if lock is None: + lock = Lock() + callback_arg = _replace_dict( + callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items() + ) + kwargs: tp.Dict[str, tp.Any] = { + 'shifts': shifts, + 'split': split, + 'overlap': overlap, + 'transition_power': transition_power, + 'progress': progress, + 'device': device, + 'pool': pool, + 'segment': segment, + 'lock': lock, + } + out: tp.Union[float, th.Tensor] + res: tp.Union[float, th.Tensor] + if isinstance(model, BagOfModels): + # Special treatment for bag of model. + # We explicitely apply multiple times `apply_model` so that the random shifts + # are different for each model. + estimates: tp.Union[float, th.Tensor] = 0. + totals = [0.] * len(model.sources) + callback_arg["models"] = len(model.models) + for sub_model, model_weights in zip(model.models, model.weights): + kwargs["callback"] = (( + lambda d, i=callback_arg["model_idx_in_bag"]: callback( + _replace_dict(d, ("model_idx_in_bag", i))) if callback else None) + ) + original_model_device = next(iter(sub_model.parameters())).device + sub_model.to(device) + + res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg) + out = res + sub_model.to(original_model_device) + for k, inst_weight in enumerate(model_weights): + out[:, k, :, :] *= inst_weight + totals[k] += inst_weight + estimates += out + del out + callback_arg["model_idx_in_bag"] += 1 + + assert isinstance(estimates, th.Tensor) + for k in range(estimates.shape[1]): + estimates[:, k, :, :] /= totals[k] + return estimates + + if "models" not in callback_arg: + callback_arg["models"] = 1 + model.to(device) + model.eval() + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + batch, channels, length = mix.shape + if shifts: + kwargs['shifts'] = 0 + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + assert isinstance(mix, TensorChunk) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0. + for shift_idx in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + kwargs["callback"] = ( + (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i))) + if callback else None) + ) + res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg) + shifted_out = res + out += shifted_out[..., max_shift - offset:] + out /= shifts + assert isinstance(out, th.Tensor) + return out + elif split: + kwargs['split'] = False + out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) + sum_weight = th.zeros(length, device=mix.device) + if segment is None: + segment = model.segment + assert segment is not None and segment > 0. 
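        # Illustrative numbers (assumptions, not taken from the caller): with the
        # default overlap=0.25 and a 10 second segment at 44100 Hz, segment_length is
        # 441000 frames and the stride is int(0.75 * 441000) = 330750 frames, so
        # consecutive chunks share 110250 frames, cross-faded by the triangular
        # weight computed below.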
+ segment_length: int = int(model.samplerate * segment) + stride = int((1 - overlap) * segment_length) + offsets = range(0, length, stride) + scale = float(format(stride / model.samplerate, ".2f")) + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device), + th.arange(segment_length - segment_length // 2, 0, -1, device=device)]) + assert len(weight) == segment_length + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + futures = [] + for offset in offsets: + chunk = TensorChunk(mix, offset, segment_length) + future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg, + callback=(lambda d, i=offset: + callback(_replace_dict(d, ("segment_offset", i))) + if callback else None)) + futures.append((future, offset)) + offset += segment_length + if progress: + futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + for future, offset in futures: + try: + chunk_out = future.result() # type: th.Tensor + except Exception: + pool.shutdown(wait=True, cancel_futures=True) + raise + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment_length] += ( + weight[:chunk_length] * chunk_out).to(mix.device) + sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device) + assert sum_weight.min() > 0 + out /= sum_weight + assert isinstance(out, th.Tensor) + return out + else: + valid_length: int + if isinstance(model, HTDemucs) and segment is not None: + valid_length = int(segment * model.samplerate) + elif hasattr(model, 'valid_length'): + valid_length = model.valid_length(length) # type: ignore + else: + valid_length = length + mix = tensor_chunk(mix) + assert isinstance(mix, TensorChunk) + padded_mix = mix.padded(valid_length).to(device) + with lock: + if callback is not None: + callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore + with th.no_grad(): + out = model(padded_mix) + with lock: + if callback is not None: + callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore + assert isinstance(out, th.Tensor) + return center_trim(out, length) diff --git a/demucs/audio.py b/demucs/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d50f5fb538bee91baa913cc22210bb48cc8792 --- /dev/null +++ b/demucs/audio.py @@ -0,0 +1,265 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import json +import subprocess as sp +from pathlib import Path + +import lameenc +import julius +import numpy as np +import torch +import torchaudio as ta +import typing as tp + +from .utils import temp_filenames + + +def _read_info(path): + stdout_data = sp.check_output([ + 'ffprobe', "-loglevel", "panic", + str(path), '-print_format', 'json', '-show_format', '-show_streams' + ]) + return json.loads(stdout_data.decode('utf-8')) + + +class AudioFile: + """ + Allows to read audio from any format supported by ffmpeg, as well as resampling or + converting to mono on the fly. See :method:`read` for more details. 
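    Example (a small sketch: the path is hypothetical and ffmpeg/ffprobe must be
    installed and on the PATH):

        af = AudioFile("mix.mp3")                                  # hypothetical file
        print(af.duration, af.channels(), af.samplerate())
        wav = af.read(streams=0, samplerate=44100, channels=2)     # tensor of shape (2, T)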
+ """ + def __init__(self, path: Path): + self.path = Path(path) + self._info = None + + def __repr__(self): + features = [("path", self.path)] + features.append(("samplerate", self.samplerate())) + features.append(("channels", self.channels())) + features.append(("streams", len(self))) + features_str = ", ".join(f"{name}={value}" for name, value in features) + return f"AudioFile({features_str})" + + @property + def info(self): + if self._info is None: + self._info = _read_info(self.path) + return self._info + + @property + def duration(self): + return float(self.info['format']['duration']) + + @property + def _audio_streams(self): + return [ + index for index, stream in enumerate(self.info["streams"]) + if stream["codec_type"] == "audio" + ] + + def __len__(self): + return len(self._audio_streams) + + def channels(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['channels']) + + def samplerate(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) + + def read(self, + seek_time=None, + duration=None, + streams=slice(None), + samplerate=None, + channels=None): + """ + Slightly more efficient implementation than stempeg, + in particular, this will extract all stems at once + rather than having to loop over one file multiple times + for each stream. + + Args: + seek_time (float): seek time in seconds or None if no seeking is needed. + duration (float): duration in seconds to extract or None to extract until the end. + streams (slice, int or list): streams to extract, can be a single int, a list or + a slice. If it is a slice or list, the output will be of size [S, C, T] + with S the number of streams, C the number of channels and T the number of samples. + If it is an int, the output will be [C, T]. + samplerate (int): if provided, will resample on the fly. If None, no resampling will + be done. Original sampling rate can be obtained with :method:`samplerate`. + channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that + as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. + See https://sound.stackexchange.com/a/42710. + Our definition of mono is simply the average of the two channels. Any other + value will be ignored. 
+ """ + streams = np.array(range(len(self)))[streams] + single = not isinstance(streams, np.ndarray) + if single: + streams = [streams] + + if duration is None: + target_size = None + query_duration = None + else: + target_size = int((samplerate or self.samplerate()) * duration) + query_duration = float((target_size + 1) / (samplerate or self.samplerate())) + + with temp_filenames(len(streams)) as filenames: + command = ['ffmpeg', '-y'] + command += ['-loglevel', 'panic'] + if seek_time: + command += ['-ss', str(seek_time)] + command += ['-i', str(self.path)] + for stream, filename in zip(streams, filenames): + command += ['-map', f'0:{self._audio_streams[stream]}'] + if query_duration is not None: + command += ['-t', str(query_duration)] + command += ['-threads', '1'] + command += ['-f', 'f32le'] + if samplerate is not None: + command += ['-ar', str(samplerate)] + command += [filename] + + sp.run(command, check=True) + wavs = [] + for filename in filenames: + wav = np.fromfile(filename, dtype=np.float32) + wav = torch.from_numpy(wav) + wav = wav.view(-1, self.channels()).t() + if channels is not None: + wav = convert_audio_channels(wav, channels) + if target_size is not None: + wav = wav[..., :target_size] + wavs.append(wav) + wav = torch.stack(wavs, dim=0) + if single: + wav = wav[0] + return wav + + +def convert_audio_channels(wav, channels=2): + """Convert audio to the given number of channels.""" + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, but the stream have multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file have + # one single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file have + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor: + """Convert audio from a given samplerate to a target one and target number of channels.""" + wav = convert_audio_channels(wav, channels) + return julius.resample_frac(wav, from_samplerate, to_samplerate) + + +def i16_pcm(wav): + """Convert audio to 16 bits integer PCM format.""" + if wav.dtype.is_floating_point: + return (wav.clamp_(-1, 1) * (2**15 - 1)).short() + else: + return wav + + +def f32_pcm(wav): + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + else: + return wav.float() / (2**15 - 1) + + +def as_dtype_pcm(wav, dtype): + """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" + if wav.dtype.is_floating_point: + return f32_pcm(wav) + else: + return i16_pcm(wav) + + +def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False): + """Save given audio as mp3. 
This should work on all OSes.""" + C, T = wav.shape + wav = i16_pcm(wav) + encoder = lameenc.Encoder() + encoder.set_bit_rate(bitrate) + encoder.set_in_sample_rate(samplerate) + encoder.set_channels(C) + encoder.set_quality(quality) # 2-highest, 7-fastest + if not verbose: + encoder.silence() + wav = wav.data.cpu() + wav = wav.transpose(0, 1).numpy() + mp3_data = encoder.encode(wav.tobytes()) + mp3_data += encoder.flush() + with open(path, "wb") as f: + f.write(mp3_data) + + +def prevent_clip(wav, mode='rescale'): + """ + different strategies for avoiding raw clipping. + """ + if mode is None or mode == 'none': + return wav + assert wav.dtype.is_floating_point, "too late for clipping" + if mode == 'rescale': + wav = wav / max(1.01 * wav.abs().max(), 1) + elif mode == 'clamp': + wav = wav.clamp(-0.99, 0.99) + elif mode == 'tanh': + wav = torch.tanh(wav) + else: + raise ValueError(f"Invalid mode {mode}") + return wav + + +def save_audio(wav: torch.Tensor, + path: tp.Union[str, Path], + samplerate: int, + bitrate: int = 320, + clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale', + bits_per_sample: tp.Literal[16, 24, 32] = 16, + as_float: bool = False, + preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2): + """Save audio file, automatically preventing clipping if necessary + based on the given `clip` strategy. If the path ends in `.mp3`, this + will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality: + 2 for highest quality, 7 for fastest speed + """ + wav = prevent_clip(wav, mode=clip) + path = Path(path) + suffix = path.suffix.lower() + if suffix == ".mp3": + encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True) + elif suffix == ".wav": + if as_float: + bits_per_sample = 32 + encoding = 'PCM_F' + else: + encoding = 'PCM_S' + ta.save(str(path), wav, sample_rate=samplerate, + encoding=encoding, bits_per_sample=bits_per_sample) + elif suffix == ".flac": + ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample) + else: + raise ValueError(f"Invalid suffix for path: {suffix}") diff --git a/demucs/augment.py b/demucs/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..94ecdfd6f6c937133dec779564db89aec7430193 --- /dev/null +++ b/demucs/augment.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Data augmentations. +""" + +import random +import torch as th +from torch import nn + + +class Shift(nn.Module): + """ + Randomly shift audio in time by up to `shift` samples. + """ + def __init__(self, shift=8192, same=False): + super().__init__() + self.shift = shift + self.same = same + + def forward(self, wav): + batch, sources, channels, time = wav.size() + length = time - self.shift + if self.shift > 0: + if not self.training: + wav = wav[..., :length] + else: + srcs = 1 if self.same else sources + offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) + offsets = offsets.expand(-1, sources, channels, -1) + indexes = th.arange(length, device=wav.device) + wav = wav.gather(3, indexes + offsets) + return wav + + +class FlipChannels(nn.Module): + """ + Flip left-right channels. 
+ """ + def forward(self, wav): + batch, sources, channels, time = wav.size() + if self.training and wav.size(2) == 2: + left = th.randint(2, (batch, sources, 1, 1), device=wav.device) + left = left.expand(-1, -1, -1, time) + right = 1 - left + wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) + return wav + + +class FlipSign(nn.Module): + """ + Random sign flip. + """ + def forward(self, wav): + batch, sources, channels, time = wav.size() + if self.training: + signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32) + wav = wav * (2 * signs - 1) + return wav + + +class Remix(nn.Module): + """ + Shuffle sources to make new mixes. + """ + def __init__(self, proba=1, group_size=4): + """ + Shuffle sources within one batch. + Each batch is divided into groups of size `group_size` and shuffling is done within + each group separatly. This allow to keep the same probability distribution no matter + the number of GPUs. Without this grouping, using more GPUs would lead to a higher + probability of keeping two sources from the same track together which can impact + performance. + """ + super().__init__() + self.proba = proba + self.group_size = group_size + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + + if self.training and random.random() < self.proba: + group_size = self.group_size or batch + if batch % group_size != 0: + raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}") + groups = batch // group_size + wav = wav.view(groups, group_size, streams, channels, time) + permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device), + dim=1) + wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time)) + wav = wav.view(batch, streams, channels, time) + return wav + + +class Scale(nn.Module): + def __init__(self, proba=1., min=0.25, max=1.25): + super().__init__() + self.proba = proba + self.min = min + self.max = max + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + if self.training and random.random() < self.proba: + scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max) + wav *= scales + return wav diff --git a/demucs/demucs.py b/demucs/demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..bfebccdd68b3116c0d77bf6935564ded42cf2072 --- /dev/null +++ b/demucs/demucs.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +from torch import nn +from torch.nn import functional as F + +from .states import capture_init +from .utils import center_trim, unfold +from .transformer import LayerScale + + +class BLSTM(nn.Module): + """ + BiLSTM with same hidden units as input dim. + If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. 
+ """ + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +def rescale_conv(conv, reference): + """Rescale initial weight scale. It is unclear why it helps but it certainly does. + """ + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + rescale_conv(sub, reference) + + +class DConv(nn.Module): + """ + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + """ + def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, + norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, + kernel=3, dilate=True): + """ + Args: + channels: input/output channels for residual branch. + compress: amount of channel compression inside the branch. + depth: number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention. + init: initial scale for LayerNorm. + norm: use GroupNorm. + attn: use LocalAttention. + heads: number of heads for the LocalAttention. + ndecay: number of decay controls in the LocalAttention. + lstm: use LSTM. + gelu: Use GELU activation. + kernel: kernel size for the (dilated) convolutions. + dilate: if true, use dilation, increasing with the depth. 
+ """ + + super().__init__() + assert kernel % 2 == 1 + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act: tp.Type[nn.Module] + if gelu: + act = nn.GELU + else: + act = nn.ReLU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = 2 ** d if dilate else 1 + padding = dilation * (kernel // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), + norm_fn(hidden), act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), nn.GLU(1), + LayerScale(channels, init), + ] + if attn: + mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + + Also a failed experiments with trying to provide some frequency based attention. + """ + def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): + super().__init__() + assert channels % heads == 0, (channels, heads) + self.heads = heads + self.nfreqs = nfreqs + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + if nfreqs: + self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) + if ndecay: + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + assert self.query_decay.bias is not None # stupid type checker + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) + + def forward(self, x): + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= keys.shape[2]**0.5 + if self.nfreqs: + periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) + freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) + freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 + dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. 
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + if self.nfreqs: + time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) + result = torch.cat([result, time_sig], 2) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=64, + growth=2., + # Main structure + depth=6, + rewrite=True, + lstm_layers=0, + # Convolutions + kernel_size=8, + stride=4, + context=1, + # Activations + gelu=True, + glu=True, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Pre/post processing + normalize=True, + resample=True, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated + by default, as this is now replaced by the smaller and faster small LSTMs + in the DConv branches. + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time steps. + gelu: use GELU activation function. + glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + normalize (bool): normalizes the input audio on the fly, and scales back + the output by the same amount. + resample (bool): upsample x2 the input and downsample /2 the output. + rescale (float): rescale initial weights of convolutions + to get their standard deviation closer to `rescale`. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment (float): duration of the chunks of audio to ideally evaluate the model on. + This is used by `demucs.apply.apply_model`. 
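        Example (a minimal sketch: the source names and the two second input are
        illustrative):

            import torch

            model = Demucs(sources=["drums", "bass", "other", "vocals"])
            mix = torch.randn(1, 2, 2 * 44100)    # (batch, audio_channels, time)
            out = model(mix)                      # (batch, sources, audio_channels, time)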
+ """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment = segment + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + self.skip_scales = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + if gelu: + act2 = nn.GELU + else: + act2 = nn.ReLU + + in_channels = audio_channels + padding = 0 + for index in range(depth): + norm_fn = lambda d: nn.Identity() # noqa + if index >= norm_starts: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + + encode = [] + encode += [ + nn.Conv1d(in_channels, channels, kernel_size, stride), + norm_fn(channels), + act2(), + ] + attn = index >= dconv_attn + lstm = index >= dconv_lstm + if dconv_mode & 1: + encode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + if rewrite: + encode += [ + nn.Conv1d(channels, ch_scale * channels, 1), + norm_fn(ch_scale * channels), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [ + nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), + norm_fn(ch_scale * channels), activation] + if dconv_mode & 2: + decode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + decode += [nn.ConvTranspose1d(channels, out_channels, + kernel_size, stride, padding=padding)] + if index > 0: + decode += [norm_fn(out_channels), act2()] + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolution, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + Note that input are automatically padded if necessary to ensure that the output + has the same length as the input. 
+ """ + if self.resample: + length *= 2 + + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + + for idx in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + x = (x - mean) / (1e-5 + std) + else: + mean = 0 + std = 1 + + delta = self.valid_length(length) - length + x = F.pad(x, (delta // 2, delta - delta // 2)) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + + if self.lstm: + x = self.lstm(x) + + for decode in self.decoder: + skip = saved.pop(-1) + skip = center_trim(skip, x) + x = decode(x + skip) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = center_trim(x, length) + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x + + def load_state_dict(self, state, strict=True): + # fix a mismatch with previous generation Demucs models. + for idx in range(self.depth): + for a in ['encoder', 'decoder']: + for b in ['bias', 'weight']: + new = f'{a}.{idx}.3.{b}' + old = f'{a}.{idx}.2.{b}' + if old in state and new not in state: + state[new] = state.pop(old) + super().load_state_dict(state, strict=strict) diff --git a/demucs/distrib.py b/demucs/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..305fa4611a5062a518121acc2ef6e49eaae9b95f --- /dev/null +++ b/demucs/distrib.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Distributed training utilities. 
+""" +import logging +import pickle + +import numpy as np +import torch +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader, Subset +from torch.nn.parallel.distributed import DistributedDataParallel + +from dora import distrib as dora_distrib + +logger = logging.getLogger(__name__) +rank = 0 +world_size = 1 + + +def init(): + global rank, world_size + if not torch.distributed.is_initialized(): + dora_distrib.init() + rank = dora_distrib.rank() + world_size = dora_distrib.world_size() + + +def average(metrics, count=1.): + if isinstance(metrics, dict): + keys, values = zip(*sorted(metrics.items())) + values = average(values, count) + return dict(zip(keys, values)) + if world_size == 1: + return metrics + tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32) + tensor *= count + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() + + +def wrap(model): + if world_size == 1: + return model + else: + return DistributedDataParallel( + model, + # find_unused_parameters=True, + device_ids=[torch.cuda.current_device()], + output_device=torch.cuda.current_device()) + + +def barrier(): + if world_size > 1: + torch.distributed.barrier() + + +def share(obj=None, src=0): + if world_size == 1: + return obj + size = torch.empty(1, device='cuda', dtype=torch.long) + if rank == src: + dump = pickle.dumps(obj) + size[0] = len(dump) + torch.distributed.broadcast(size, src=src) + # size variable is now set to the length of pickled obj in all processes + + if rank == src: + buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda() + else: + buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8) + torch.distributed.broadcast(buffer, src=src) + # buffer variable is now set to pickled obj in all processes + + if rank != src: + obj = pickle.loads(buffer.cpu().numpy().tobytes()) + logger.debug(f"Shared object of size {len(buffer)}") + return obj + + +def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): + """ + Create a dataloader properly in case of distributed training. + If a gradient is going to be computed you must set `shuffle=True`. + """ + if world_size == 1: + return klass(dataset, *args, shuffle=shuffle, **kwargs) + + if shuffle: + # train means we will compute backward, we use DistributedSampler + sampler = DistributedSampler(dataset) + # We ignore shuffle, DistributedSampler already shuffles + return klass(dataset, *args, **kwargs, sampler=sampler) + else: + # We make a manual shard, as DistributedSampler otherwise replicate some examples + dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) + return klass(dataset, *args, shuffle=shuffle, **kwargs) diff --git a/demucs/ema.py b/demucs/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..62bd219db23246a2c6c49bad2a70bec765d47cf6 --- /dev/null +++ b/demucs/ema.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Inspired from https://github.com/rwightman/pytorch-image-models +from contextlib import contextmanager + +import torch + +from .states import swap_state + + +class ModelEMA: + """ + Perform EMA on a model. You can switch to the EMA weights temporarily + with the `swap` method. 
+ + ema = ModelEMA(model) + with ema.swap(): + # compute valid metrics with averaged model. + """ + def __init__(self, model, decay=0.9999, unbias=True, device='cpu'): + self.decay = decay + self.model = model + self.state = {} + self.count = 0 + self.device = device + self.unbias = unbias + + self._init() + + def _init(self): + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + if key not in self.state: + self.state[key] = val.detach().to(device, copy=True) + + def update(self): + if self.unbias: + self.count = self.count * self.decay + 1 + w = 1 / self.count + else: + w = 1 - self.decay + for key, val in self.model.state_dict().items(): + if val.dtype != torch.float32: + continue + device = self.device or val.device + self.state[key].mul_(1 - w) + self.state[key].add_(val.detach().to(device), alpha=w) + + @contextmanager + def swap(self): + with swap_state(self.model, self.state): + yield + + def state_dict(self): + return {'state': self.state, 'count': self.count} + + def load_state_dict(self, state): + self.count = state['count'] + for k, v in state['state'].items(): + self.state[k].copy_(v) diff --git a/demucs/evaluate.py b/demucs/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcb11c65ad35755e19cda842192a83b0ccae093 --- /dev/null +++ b/demucs/evaluate.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Test time evaluation, either using the original SDR from [Vincent et al. 2006] +or the newest SDR definition from the MDX 2021 competition (this one will +be reported as `nsdr` for `new sdr`). +""" + +from concurrent import futures +import logging + +from dora.log import LogProgress +import numpy as np +import musdb +import museval +import torch as th + +from .apply import apply_model +from .audio import convert_audio, save_audio +from . import distrib +from .utils import DummyPoolExecutor + + +logger = logging.getLogger(__name__) + + +def new_sdr(references, estimates): + """ + Compute the SDR according to the MDX challenge definition. + Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license) + """ + assert references.dim() == 4 + assert estimates.dim() == 4 + delta = 1e-7 # avoid numerical errors + num = th.sum(th.square(references), dim=(2, 3)) + den = th.sum(th.square(references - estimates), dim=(2, 3)) + num += delta + den += delta + scores = 10 * th.log10(num / den) + return scores + + +def eval_track(references, estimates, win, hop, compute_sdr=True): + references = references.transpose(1, 2).double() + estimates = estimates.transpose(1, 2).double() + + new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0] + + if not compute_sdr: + return None, new_scores + else: + references = references.numpy() + estimates = estimates.numpy() + scores = museval.metrics.bss_eval( + references, estimates, + compute_permutation=False, + window=win, + hop=hop, + framewise_filters=False, + bsseval_sources_version=False)[:-1] + return scores, new_scores + + +def evaluate(solver, compute_sdr=False): + """ + Evaluate model using museval. + compute_sdr=False means using only the MDX definition of the SDR, which + is much faster to evaluate. 
+ """ + + args = solver.args + + output_dir = solver.folder / "results" + output_dir.mkdir(exist_ok=True, parents=True) + json_folder = solver.folder / "results/test" + json_folder.mkdir(exist_ok=True, parents=True) + + # we load tracks from the original musdb set + if args.test.nonhq is None: + test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) + else: + test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) + src_rate = args.dset.musdb_samplerate + + eval_device = 'cpu' + + model = solver.model + win = int(1. * model.samplerate) + hop = int(1. * model.samplerate) + + indexes = range(distrib.rank, len(test_set), distrib.world_size) + indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, + name='Eval') + pendings = [] + + pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor + with pool(args.test.workers) as pool: + for index in indexes: + track = test_set.tracks[index] + + mix = th.from_numpy(track.audio).t().float() + if mix.dim() == 1: + mix = mix[None] + mix = mix.to(solver.device) + ref = mix.mean(dim=0) # mono mixture + mix = (mix - ref.mean()) / ref.std() + mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) + estimates = apply_model(model, mix[None], + shifts=args.test.shifts, split=args.test.split, + overlap=args.test.overlap)[0] + estimates = estimates * ref.std() + ref.mean() + estimates = estimates.to(eval_device) + + references = th.stack( + [th.from_numpy(track.targets[name].audio).t() for name in model.sources]) + if references.dim() == 2: + references = references[:, None] + references = references.to(eval_device) + references = convert_audio(references, src_rate, + model.samplerate, model.audio_channels) + if args.test.save: + folder = solver.folder / "wav" / track.name + folder.mkdir(exist_ok=True, parents=True) + for name, estimate in zip(model.sources, estimates): + save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate) + + pendings.append((track.name, pool.submit( + eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr))) + + pendings = LogProgress(logger, pendings, updates=args.misc.num_prints, + name='Eval (BSS)') + tracks = {} + for track_name, pending in pendings: + pending = pending.result() + scores, nsdrs = pending + tracks[track_name] = {} + for idx, target in enumerate(model.sources): + tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]} + if scores is not None: + (sdr, isr, sir, sar) = scores + for idx, target in enumerate(model.sources): + values = { + "SDR": sdr[idx].tolist(), + "SIR": sir[idx].tolist(), + "ISR": isr[idx].tolist(), + "SAR": sar[idx].tolist() + } + tracks[track_name][target].update(values) + + all_tracks = {} + for src in range(distrib.world_size): + all_tracks.update(distrib.share(tracks, src)) + + result = {} + metric_names = next(iter(all_tracks.values()))[model.sources[0]] + for metric_name in metric_names: + avg = 0 + avg_of_medians = 0 + for source in model.sources: + medians = [ + np.nanmedian(all_tracks[track][source][metric_name]) + for track in all_tracks.keys()] + mean = np.mean(medians) + median = np.median(medians) + result[metric_name.lower() + "_" + source] = mean + result[metric_name.lower() + "_med" + "_" + source] = median + avg += mean / len(model.sources) + avg_of_medians += median / len(model.sources) + result[metric_name.lower()] = avg + result[metric_name.lower() + "_med"] = avg_of_medians + return result diff --git a/demucs/grids/__init__.py b/demucs/grids/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/demucs/grids/_explorers.py b/demucs/grids/_explorers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1772958328b76d2b73c64452c2bc3d671896f6 --- /dev/null +++ b/demucs/grids/_explorers.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dora import Explorer +import treetable as tt + + +class MyExplorer(Explorer): + test_metrics = ['nsdr', 'sdr_med'] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table. + """ + return [ + tt.group("train", [ + tt.leaf("epoch"), + tt.leaf("reco", ".3f"), + ], align=">"), + tt.group("valid", [ + tt.leaf("penalty", ".1f"), + tt.leaf("ms", ".1f"), + tt.leaf("reco", ".2%"), + tt.leaf("breco", ".2%"), + tt.leaf("b_nsdr", ".2f"), + # tt.leaf("b_nsdr_drums", ".2f"), + # tt.leaf("b_nsdr_bass", ".2f"), + # tt.leaf("b_nsdr_other", ".2f"), + # tt.leaf("b_nsdr_vocals", ".2f"), + ], align=">"), + tt.group("test", [ + tt.leaf(name, ".2f") + for name in self.test_metrics + ], align=">") + ] + + def process_history(self, history): + train = { + 'epoch': len(history), + } + valid = {} + test = {} + best_v_main = float('inf') + breco = float('inf') + for metrics in history: + train.update(metrics['train']) + valid.update(metrics['valid']) + if 'main' in metrics['valid']: + best_v_main = min(best_v_main, metrics['valid']['main']['loss']) + valid['bmain'] = best_v_main + valid['breco'] = min(breco, metrics['valid']['reco']) + breco = valid['breco'] + if (metrics['valid']['loss'] == metrics['valid']['best'] or + metrics['valid'].get('nsdr') == metrics['valid']['best']): + for k, v in metrics['valid'].items(): + if k.startswith('reco_'): + valid['b_' + k[len('reco_'):]] = v + if k.startswith('nsdr'): + valid[f'b_{k}'] = v + if 'test' in metrics: + test.update(metrics['test']) + metrics = history[-1] + return {"train": train, "valid": valid, "test": test} diff --git a/demucs/grids/mdx.py b/demucs/grids/mdx.py new file mode 100644 index 0000000000000000000000000000000000000000..11963aa90aec036f12fdb09e1905a86ba00a4147 --- /dev/null +++ b/demucs/grids/mdx.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from ..train import main + + +TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # This trains the first round of models. Once this is trained, + # you will need to schedule `mdx_refine`. + for sig in TRACK_A: + xp = main.get_xp_from_sig(sig) + parent = xp.cfg.continue_from + xp = main.get_xp_from_sig(parent) + launcher(xp.argv) + launcher(xp.argv, {'quant.diffq': 1e-4}) + launcher(xp.argv, {'quant.diffq': 3e-4}) diff --git a/demucs/grids/mdx_extra.py b/demucs/grids/mdx_extra.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c241632434c85469831a438e0ed8b3fdc4403b --- /dev/null +++ b/demucs/grids/mdx_extra.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from ..train import main + +TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # This trains the first round of models. Once this is trained, + # you will need to schedule `mdx_refine`. + for sig in TRACK_B: + while sig is not None: + xp = main.get_xp_from_sig(sig) + sig = xp.cfg.continue_from + + for dset in ['extra44', 'extra_test']: + sub = launcher.bind(xp.argv, dset=dset) + sub() + if dset == 'extra_test': + sub({'quant.diffq': 1e-4}) + sub({'quant.diffq': 3e-4}) diff --git a/demucs/grids/mdx_refine.py b/demucs/grids/mdx_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea2443f1b458f613c9458dff0491bd2c6b3a4c1 --- /dev/null +++ b/demucs/grids/mdx_refine.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Main training for the Track A MDX models. +""" + +from ._explorers import MyExplorer +from .mdx import TRACK_A +from ..train import main + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='learnlab') + + # Reproduce results from MDX competition Track A + # WARNING: all the experiments in the `mdx` grid must have completed. + for sig in TRACK_A: + xp = main.get_xp_from_sig(sig) + launcher(xp.argv) + for diffq in [1e-4, 3e-4]: + xp_src = main.get_xp_from_sig(xp.cfg.continue_from) + q_argv = [f'quant.diffq={diffq}'] + actual_src = main.get_xp(xp_src.argv + q_argv) + actual_src.link.load() + assert len(actual_src.link.history) == actual_src.cfg.epochs + argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"'] + launcher(argv) diff --git a/demucs/grids/mmi.py b/demucs/grids/mmi.py new file mode 100644 index 0000000000000000000000000000000000000000..315b624418702599f877b5135b3817d66f4a085c --- /dev/null +++ b/demucs/grids/mmi.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
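# Illustrative sketch of the Dora "explorer" pattern shared by the grid files in this
# folder, written as if it were another file under demucs/grids/: bind Slurm resources
# once, then each call to the launcher (or to a bound `sub`) schedules one experiment
# whose config is a base argv plus the given overrides. The signature 'deadbeef' and
# the overrides are made up; only the API calls (`slurm_`, `bind`, `job_array`,
# `get_xp_from_sig`) come from the files above.
from ._explorers import MyExplorer
from ..train import main


@MyExplorer
def explorer(launcher):
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition='learnlab')

    xp = main.get_xp_from_sig('deadbeef')   # hypothetical base experiment
    sub = launcher.bind(xp.argv)            # reuse its argv as the base config
    with launcher.job_array():
        sub()                               # the base run itself
        sub({'optim.lr': 1e-4})             # one scheduled job per override
        sub({'quant.diffq': 1e-4})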
+ +from ._explorers import MyExplorer +from dora import Launcher + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days + + sub = launcher.bind_( + { + "dset": "extra_mmi_goodclean", + "test.shifts": 0, + "model": "htdemucs", + "htdemucs.dconv_mode": 3, + "htdemucs.depth": 4, + "htdemucs.t_dropout": 0.02, + "htdemucs.t_layers": 5, + "max_batches": 800, + "ema.epoch": [0.9, 0.95], + "ema.batch": [0.9995, 0.9999], + "dset.segment": 10, + "batch_size": 32, + } + ) + sub({"model": "hdemucs"}) + sub({"model": "hdemucs", "dset": "extra44"}) + sub({"model": "hdemucs", "dset": "musdb44"}) + + sparse = { + 'batch_size': 3 * 8, + 'augment.remix.group_size': 3, + 'htdemucs.t_auto_sparsity': True, + 'htdemucs.t_sparse_self_attn': True, + 'htdemucs.t_sparse_cross_attn': True, + 'htdemucs.t_sparsity': 0.9, + "htdemucs.t_layers": 7 + } + + with launcher.job_array(): + for transf_layers in [5, 7]: + for bottom_channels in [0, 512]: + sub = launcher.bind({ + "htdemucs.t_layers": transf_layers, + "htdemucs.bottom_channels": bottom_channels, + }) + if bottom_channels == 0 and transf_layers == 5: + sub({"augment.remix.proba": 0.0}) + sub({ + "augment.repitch.proba": 0.0, + # when doing repitching, we trim the outut to align on the + # highest change of BPM. When removing repitching, + # we simulate it here to ensure the training context is the same. + # Another second is lost for all experiments due to the random + # shift augmentation. + "dset.segment": 10 * 0.88}) + elif bottom_channels == 512 and transf_layers == 5: + sub(dset="musdb44") + sub(dset="extra44") + # Sparse kernel XP, currently not released as kernels are still experimental. + sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7}) + + for duration in [5, 10, 15]: + sub({"dset.segment": duration}) diff --git a/demucs/grids/mmi_ft.py b/demucs/grids/mmi_ft.py new file mode 100644 index 0000000000000000000000000000000000000000..31f81e9c2266ff45d5d5b51a44559f7fb15402f5 --- /dev/null +++ b/demucs/grids/mmi_ft.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from ._explorers import MyExplorer +from dora import Launcher +from demucs import train + + +def get_sub(launcher, sig): + xp = train.main.get_xp_from_sig(sig) + sub = launcher.bind(xp.argv) + sub() + sub.bind_({ + 'continue_from': sig, + 'continue_best': True}) + return sub + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days + ft = { + 'optim.lr': 1e-4, + 'augment.remix.proba': 0, + 'augment.scale.proba': 0, + 'augment.shift_same': True, + 'htdemucs.t_weight_decay': 0.05, + 'batch_size': 8, + 'optim.clip_grad': 5, + 'optim.optim': 'adamw', + 'epochs': 50, + 'dset.wav2_valid': True, + 'ema.epoch': [], # let's make valid a bit faster + } + with launcher.job_array(): + for sig in ['2899e11a']: + sub = get_sub(launcher, sig) + sub.bind_(ft) + for segment in [15, 18]: + for source in range(4): + w = [0] * 4 + w[source] = 1 + sub({'weights': w, 'dset.segment': segment}) + + for sig in ['955717e8']: + sub = get_sub(launcher, sig) + sub.bind_(ft) + for segment in [10, 15]: + for source in range(4): + w = [0] * 4 + w[source] = 1 + sub({'weights': w, 'dset.segment': segment}) diff --git a/demucs/grids/repro.py b/demucs/grids/repro.py new file mode 100644 index 0000000000000000000000000000000000000000..b26525cbcdc92c4418e0b3fa8453437528b3cfdd --- /dev/null +++ b/demucs/grids/repro.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Easier training for reproducibility +""" + +from ._explorers import MyExplorer + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=3 * 24 * 60, + partition='devlab,learnlab') + + launcher.bind_({'ema.epoch': [0.9, 0.95]}) + launcher.bind_({'ema.batch': [0.9995, 0.9999]}) + launcher.bind_({'epochs': 600}) + + base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False, + 'demucs.lstm_layers': 2} + newt = {'model': 'demucs', 'demucs.normalize': True} + hdem = {'model': 'hdemucs'} + svd = {'svd.penalty': 1e-5, 'svd': 'base2'} + + with launcher.job_array(): + for model in [base, newt, hdem]: + sub = launcher.bind(model) + if model is base: + # Training the v2 Demucs on MusDB HQ + sub(epochs=360) + continue + + # those two will be used in the repro_mdx_a bag of models. + sub(svd) + sub(svd, seed=43) + if model == newt: + # Ablation study + sub() + abl = sub.bind(svd) + abl({'ema.epoch': [], 'ema.batch': []}) + abl({'demucs.dconv_lstm': 10}) + abl({'demucs.dconv_attn': 10}) + abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2}) + abl({'demucs.dconv_mode': 0}) + abl({'demucs.gelu': False}) diff --git a/demucs/grids/repro_ft.py b/demucs/grids/repro_ft.py new file mode 100644 index 0000000000000000000000000000000000000000..e43af5d1ef394db8f17cff0ea9e145ab07e92543 --- /dev/null +++ b/demucs/grids/repro_ft.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
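# Toy sketch (ours): the fine-tuning grids above pass a one-hot `weights` vector so
# that only one source contributes to the training loss, which is how one fine-tuned
# model per source is obtained. The L1 loss below only illustrates that masking; it
# is not the trainer's exact objective.
import torch

estimate = torch.randn(2, 4, 2, 1000)      # (batch, sources, channels, time)
target = torch.randn(2, 4, 2, 1000)
weights = torch.tensor([0., 0., 0., 1.])   # e.g. fine-tune on the fourth source only
per_source = (estimate - target).abs().mean(dim=(0, 2, 3))
loss = (weights * per_source).sum() / weights.sum()
print(float(loss))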
+""" +Fine tuning experiments +""" + +from ._explorers import MyExplorer +from ..train import main + + +@MyExplorer +def explorer(launcher): + launcher.slurm_( + gpus=8, + time=300, + partition='devlab,learnlab') + + # Mus + launcher.slurm_(constraint='volta32gb') + + grid = "repro" + folder = main.dora.dir / "grids" / grid + + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = main.get_xp_from_sig(sig) + xp.link.load() + if len(xp.link.history) != xp.cfg.epochs: + continue + sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"']) + sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]}) + sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4}) + sub.bind_({'dset.segment': 28, 'dset.shift': 2}) + sub.bind_({'batch_size': 32}) + auto = {'dset': 'auto_mus'} + auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0, + 'augment.shift_same': True}) + sub.bind_(auto) + sub.bind_({'batch_size': 16}) + sub.bind_({'optim.lr': 1e-4}) + sub.bind_({'model_segment': 44}) + sub() diff --git a/demucs/grids/sdx23.py b/demucs/grids/sdx23.py new file mode 100644 index 0000000000000000000000000000000000000000..30bb15d5edf1c14bffbe74619ae879675275d6cc --- /dev/null +++ b/demucs/grids/sdx23.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import MyExplorer +from dora import Launcher + + +@MyExplorer +def explorer(launcher: Launcher): + launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair", + mem_per_gpu=None, constraint='') + launcher.bind_({"dset.use_musdb": False}) + + with launcher.job_array(): + launcher(dset='sdx23_bleeding') + launcher(dset='sdx23_labelnoise') diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..fb156880ec9d015265fffaf61e7ee6f06b32fce5 --- /dev/null +++ b/demucs/hdemucs.py @@ -0,0 +1,794 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +from copy import deepcopy +import math +import typing as tp + +from openunmix.filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F + +from .demucs import DConv, rescale_module +from .states import capture_init +from .spec import spectro, ispectro + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.""" + x0 = x + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + if length <= max_pad: + extra_pad = max_pad - length + 1 + extra_pad_right = min(padding_right, extra_pad) + extra_pad_left = extra_pad - extra_pad_right + paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right) + x = F.pad(x, (extra_pad_left, extra_pad_right)) + out = F.pad(x, paddings, mode, value) + assert out.shape[-1] == length + padding_left + padding_right + assert (out[..., padding_left: padding_left + length] == x0).all() + return out + + +class ScaledEmbedding(nn.Module): + """ + Boost learning rate for embeddings (with `scale`). + Also, can make embeddings continuous with `smooth`. + """ + def __init__(self, num_embeddings: int, embedding_dim: int, + scale: float = 10., smooth=False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self): + return self.embedding.weight * self.scale + + def forward(self, x): + out = self.embedding(x) * self.scale + return out + + +class HEncLayer(nn.Module): + def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, + rewrite=True): + """Encoder layer. This used both by the time and the frequency branch. + + Args: + chin: number of input channels. + chout: number of output channels. + norm_groups: number of groups for group norm. + empty: used to make a layer with just the first conv. this is used + before merging the time and freq. branches. + freq: this is acting on frequencies. + dconv: insert DConv residual branches. + norm: use GroupNorm. + context: context size for the 1x1 conv. + dconv_kw: list of kwargs for the DConv class. + pad: pad the input. Padding is done so that the output size is + always the input size / stride. + rewrite: add 1x1 conv at the end of the layer. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.norm = norm + self.pad = pad + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad = [pad, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad) + if self.empty: + return + self.norm1 = norm_fn(chout) + self.rewrite = None + if rewrite: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + + self.dconv = None + if dconv: + self.dconv = DConv(chout, **dconv_kw) + + def forward(self, x, inject=None): + """ + `inject` is used to inject the result from the time branch into the frequency branch, + when both have the same stride. 
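# Sketch (assuming the demucs package from this diff and its dependencies are
# installed): the `pad1d` helper defined earlier in this file exists because
# F.pad(..., mode='reflect') refuses paddings larger than the signal length;
# pad1d first zero-extends very short inputs and only then reflects.
import torch
from demucs.hdemucs import pad1d

x = torch.randn(1, 2, 3)               # a 3-sample stereo snippet
y = pad1d(x, (4, 4), mode='reflect')   # plain F.pad would raise here
print(y.shape)                         # torch.Size([1, 2, 11])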
+ """ + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + if self.rewrite: + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + else: + z = y + return z + + +class MultiWrap(nn.Module): + """ + Takes one layer and replicate it N times. each replica will act + on a frequency band. All is done so that if the N replica have the same weights, + then this is exactly equivalent to applying the original module on all frequencies. + + This is a bit over-engineered to avoid edge artifacts when splitting + the frequency bands, but it is possible the naive implementation would work as well... + """ + def __init__(self, layer, split_ratios): + """ + Args: + layer: module to clone, must be either HEncLayer or HDecLayer. + split_ratios: list of float indicating which ratio to keep for each band. + """ + super().__init__() + self.split_ratios = split_ratios + self.layers = nn.ModuleList() + self.conv = isinstance(layer, HEncLayer) + assert not layer.norm + assert layer.freq + assert layer.pad + if not self.conv: + assert not layer.context_freq + for k in range(len(split_ratios) + 1): + lay = deepcopy(layer) + if self.conv: + lay.conv.padding = (0, 0) + else: + lay.pad = False + for m in lay.modules(): + if hasattr(m, 'reset_parameters'): + m.reset_parameters() + self.layers.append(lay) + + def forward(self, x, skip=None, length=None): + B, C, Fr, T = x.shape + + ratios = list(self.split_ratios) + [1] + start = 0 + outs = [] + for ratio, layer in zip(ratios, self.layers): + if self.conv: + pad = layer.kernel_size // 4 + if ratio == 1: + limit = Fr + frames = -1 + else: + limit = int(round(Fr * ratio)) + le = limit - start + if start == 0: + le += pad + frames = round((le - layer.kernel_size) / layer.stride + 1) + limit = start + (frames - 1) * layer.stride + layer.kernel_size + if start == 0: + limit -= pad + assert limit - start > 0, (limit, start) + assert limit <= Fr, (limit, Fr) + y = x[:, :, start:limit, :] + if start == 0: + y = F.pad(y, (0, 0, pad, 0)) + if ratio == 1: + y = F.pad(y, (0, 0, 0, pad)) + outs.append(layer(y)) + start = limit - layer.kernel_size + layer.stride + else: + if ratio == 1: + limit = Fr + else: + limit = int(round(Fr * ratio)) + last = layer.last + layer.last = True + + y = x[:, :, start:limit] + s = skip[:, :, start:limit] + out, _ = layer(y, s, None) + if outs: + outs[-1][:, :, -layer.stride:] += ( + out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) + out = out[:, :, layer.stride:] + if ratio == 1: + out = out[:, :, :-layer.stride // 2, :] + if start == 0: + out = out[:, :, layer.stride // 2:, :] + outs.append(out) + layer.last = last + start = limit + out = torch.cat(outs, dim=2) + if not self.conv and not last: + out = F.gelu(out) + if self.conv: + return out + else: + return out, None + + +class HDecLayer(nn.Module): + def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=1, 
dconv_kw={}, pad=True, + context_freq=True, rewrite=True): + """ + Same as HEncLayer but for decoder. See `HEncLayer` for documentation. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + self.norm = norm + self.context_freq = context_freq + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + return + self.rewrite = None + if rewrite: + if context_freq: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + else: + self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, + [0, context]) + self.norm1 = norm_fn(2 * chin) + + self.dconv = None + if dconv: + self.dconv = DConv(chin, **dconv_kw) + + def forward(self, x, skip, length): + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + + if self.rewrite: + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = x + assert skip is None + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad:-self.pad, :] + else: + z = z[..., self.pad:self.pad + length] + assert z.shape[-1] == length, (z.shape[-1], length) + if not self.last: + z = F.gelu(z) + return z, y + + +class HDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. 
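# Quick check (ours) of the shape contract documented for these layers: with
# `pad=True`, an encoder layer returns exactly input_length / stride. Example on a
# time-branch layer (freq=False) with arbitrary channel sizes:
import torch
from demucs.hdemucs import HEncLayer

enc = HEncLayer(2, 16, kernel_size=8, stride=4, freq=False, dconv=False, norm=False)
x = torch.randn(1, 2, 4096)
print(enc(x).shape)                    # torch.Size([1, 16, 1024]) -> 4096 / 4 time steps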
+ """ + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=6, + rewrite=True, + hybrid=True, + hybrid_old=False, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=2, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. + hybrid_old: some models trained for MDX had a padding bug. This replicates + this bug to avoid retraining them. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. 
+ rescale: weight recaling trick + + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.channels = channels + self.samplerate = samplerate + self.segment = segment + + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + self.hybrid = hybrid + self.hybrid_old = hybrid_old + if hybrid_old: + assert hybrid, "hybrid_old must come with hybrid=True" + if hybrid: + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if hybrid: + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + 'kernel_size': ker, + 'stride': stri, + 'freq': freq, + 'pad': pad, + 'norm': norm, + 'rewrite': rewrite, + 'norm_groups': norm_groups, + 'dconv_kw': { + 'lstm': lstm, + 'attn': attn, + 'depth': dconv_depth, + 'compress': dconv_comp, + 'init': dconv_init, + 'gelu': True, + } + } + kwt = dict(kw) + kwt['freq'] = 0 + kwt['kernel_size'] = kernel_size + kwt['stride'] = stride + kwt['pad'] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec['context_freq'] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer(chin_z, chout_z, + dconv=dconv_mode & 1, context=context_enc, **kw) + if hybrid and freq: + tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, + empty=last_freq, **kwt) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, + last=index == 0, context=context, **kw_dec) + if multi: + dec = MultiWrap(dec, multi_freqs) + if hybrid and freq: + tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, + last=index == 0, context=context, **kwt) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + if self.hybrid: + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). 
+ # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + if not self.hybrid_old: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') + else: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1])) + + z = spectro(x, nfft, hl)[..., :-1, :] + if self.hybrid: + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2:2+le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4 ** scale) + z = F.pad(z, (0, 0, 0, 1)) + if self.hybrid: + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + if not self.hybrid_old: + le = hl * int(math.ceil(length / hl)) + 2 * pad + else: + le = hl * int(math.ceil(length / hl)) + x = ispectro(z, hl, length=le) + if not self.hybrid_old: + x = x[..., pad:pad + length] + else: + x = x[..., :length] + else: + x = ispectro(z, hl, length) + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], mix_stft[sample, frame], niters, + residual=residual) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + z = self._spec(mix) + mag = self._magnitude(z).to(mix.device) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + if self.hybrid: + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. 
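# Sketch (ours) of the complex-as-channels (CaC) round trip implemented by
# `_magnitude` and `_mask` above: the complex STFT (B, C, F, T) is flattened to a
# real tensor with 2*C channels for the network, and each per-source map is folded
# back into a complex spectrogram afterwards.
import torch

B, C, Fr, T = 1, 2, 2048, 44
z = torch.randn(B, C, Fr, T, dtype=torch.complex64)

m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(B, C * 2, Fr, T)  # _magnitude, cac=True

S = 4                                          # pretend the net produced one map per source
m_sources = m[:, None].expand(B, S, C * 2, Fr, T)
out = m_sources.reshape(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())  # _mask, cac=True
assert torch.equal(out[:, 0], z)               # exact round trip per source
print(out.shape)                               # torch.Size([1, 4, 2, 2048, 44])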
+ for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if self.hybrid and idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + if self.hybrid: + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + if self.hybrid: + offset = self.depth - len(self.tdecoder) + if self.hybrid and idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + + zout = self._mask(z, x) + x = self._ispec(zout, length) + + # back to mps device + if x_is_mps: + x = x.to('mps') + + if self.hybrid: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..c79668425daf423bc71c6c83ad92a0f126d63e93 --- /dev/null +++ b/demucs/htdemucs.py @@ -0,0 +1,660 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +import math + +from openunmix.filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from .transformer import CrossTransformerEncoder + +from .demucs import rescale_module +from .states import capture_init +from .spec import spectro, ispectro +from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. 
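# Sketch (assuming the demucs package from this diff is installed): building the
# hybrid model above and separating a dummy mixture, just to show the I/O contract:
# input (batch, audio_channels, time), output (batch, len(sources), audio_channels, time).
import torch
from demucs.hdemucs import HDemucs

model = HDemucs(sources=["drums", "bass", "other", "vocals"])
mix = torch.randn(1, 2, 2 * 44100)     # two seconds of stereo noise
with torch.no_grad():
    out = model(mix)
print(out.shape)                        # torch.Size([1, 4, 2, 88200])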
+ Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=True, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. 
+ wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: 
https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + 
last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. 
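# Sketch (ours): with the defaults above (channels=48, growth=2, depth=4), the
# tensors reaching the cross-domain transformer carry 48 * 2**3 = 384 channels.
# When `bottom_channels` > 0, 1x1 Conv1d layers (as defined in __init__ above)
# remap them before and after the transformer. The shapes below are arbitrary.
import torch
from torch import nn

channels, growth, depth, bottom_channels = 48, 2, 4, 512
transformer_channels = channels * growth ** (depth - 1)      # 384

up = nn.Conv1d(transformer_channels, bottom_channels, 1)
down = nn.Conv1d(bottom_channels, transformer_channels, 1)

xt = torch.randn(1, transformer_channels, 344)                # time branch at the bottleneck
print(down(up(xt)).shape)                                     # torch.Size([1, 384, 344])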
+ if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + z = self._spec(mix) + mag = self._magnitude(z).to(mix.device) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. 
+ # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + # back to mps device + if x_is_mps: + x = x.to("mps") + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..38b48917aa1cd5ef6571087124a40138d651682e --- /dev/null +++ b/demucs/pretrained.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading pretrained models. 
+""" + +import logging +from pathlib import Path +import typing as tp + +from dora.log import fatal, bold + +from .hdemucs import HDemucs +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa +from .states import _check_diffq + +logger = logging.getLogger(__name__) +ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/" +REMOTE_ROOT = Path(__file__).parent / 'remote' + +SOURCES = ["drums", "bass", "other", "vocals"] +DEFAULT_MODEL = 'htdemucs' + + +def demucs_unittest(): + model = HDemucs(channels=4, sources=SOURCES) + return model + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("-s", "--sig", help="Locally trained XP signature.") + group.add_argument("-n", "--name", default="htdemucs", + help="Pretrained model name or signature. Default is htdemucs.") + parser.add_argument("--repo", type=Path, + help="Folder containing all pre-trained models for use with -n.") + + +def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: + root: str = '' + models: tp.Dict[str, str] = {} + for line in remote_file_list.read_text().split('\n'): + line = line.strip() + if line.startswith('#'): + continue + elif len(line) == 0: + continue + elif line.startswith('root:'): + root = line.split(':', 1)[1].strip() + else: + sig = line.split('-', 1)[0] + assert sig not in models + models[sig] = ROOT_URL + root + line + return models + + +def get_model(name: str, + repo: tp.Optional[Path] = None): + """`name` must be a bag of models name or a pretrained signature + from the remote AWS model repo or the specified local repo if `repo` is not None. + """ + if name == 'demucs_unittest': + return demucs_unittest() + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + any_repo = AnyModelRepo(model_repo, bag_repo) + try: + model = any_repo.get_model(name) + except ImportError as exc: + if 'diffq' in exc.args[0]: + _check_diffq() + raise + + model.eval() + return model + + +def get_model_from_args(args): + """ + Load local model package or pre-trained model. + """ + if args.name is None: + args.name = DEFAULT_MODEL + print(bold("Important: the default model was recently changed to `htdemucs`"), + "the latest Hybrid Transformer Demucs model. In some cases, this model can " + "actually perform worse than previous models. 
To get back the old default model " + "use `-n mdx_extra_q`.") + return get_model(name=args.name, repo=args.repo) diff --git a/demucs/py.typed b/demucs/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/demucs/remote/files.txt b/demucs/remote/files.txt new file mode 100644 index 0000000000000000000000000000000000000000..bef4947bf9bfb2dcc9fdaf94ad36a6267b7295ec --- /dev/null +++ b/demucs/remote/files.txt @@ -0,0 +1,32 @@ +# MDX Models +root: mdx_final/ +0d19c1c6-0f06f20e.th +5d2d6c55-db83574e.th +7d865c68-3d5dd56b.th +7ecf8ec1-70f50cc9.th +a1d90b5c-ae9d2452.th +c511e2ab-fe698775.th +cfa93e08-61801ae1.th +e51eebcc-c1b80bdd.th +6b9c2ca1-3fd82607.th +b72baf4e-8778635e.th +42e558d4-196e0e1b.th +305bc58f-18378783.th +14fc6a69-a89dd0ee.th +464b36d7-e5a9386e.th +7fd6ef75-a905dd85.th +83fc094f-4a16d450.th +1ef250f1-592467ce.th +902315c2-b39ce9c9.th +9a6b4851-03af0aa6.th +fa0cb7f9-100d8bf4.th +# Hybrid Transformer models +root: hybrid_transformer/ +955717e8-8726e21a.th +f7e0c4bc-ba3fe64a.th +d12395a8-e57c48e6.th +92cfc3b6-ef3bcb9c.th +04573f0d-f3cf25b2.th +75fc33f5-1941ce65.th +# Experimental 6 sources model +5c90dfd2-34c22ccb.th diff --git a/demucs/remote/hdemucs_mmi.yaml b/demucs/remote/hdemucs_mmi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f12224cfa2896e9c700b8639ce55ceaea1d20b2 --- /dev/null +++ b/demucs/remote/hdemucs_mmi.yaml @@ -0,0 +1,2 @@ +models: ['75fc33f5'] +segment: 44 diff --git a/demucs/remote/htdemucs.yaml b/demucs/remote/htdemucs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88f21093f342224f557ab20b1c5fa2a69a1e5da2 --- /dev/null +++ b/demucs/remote/htdemucs.yaml @@ -0,0 +1 @@ +models: ['955717e8'] diff --git a/demucs/remote/htdemucs_6s.yaml b/demucs/remote/htdemucs_6s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9131f96d85aa7bcbb8491aec2dccb9a596588d6f --- /dev/null +++ b/demucs/remote/htdemucs_6s.yaml @@ -0,0 +1 @@ +models: ['5c90dfd2'] diff --git a/demucs/remote/htdemucs_ft.yaml b/demucs/remote/htdemucs_ft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e22c959b41683b8d7e92807fd960d2f3f616cf6 --- /dev/null +++ b/demucs/remote/htdemucs_ft.yaml @@ -0,0 +1,7 @@ +models: ['f7e0c4bc', 'd12395a8', '92cfc3b6', '04573f0d'] +weights: [ + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.], +] \ No newline at end of file diff --git a/demucs/remote/mdx.yaml b/demucs/remote/mdx.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ba43d37cf4669f3004021bfcab2050756e65cb9 --- /dev/null +++ b/demucs/remote/mdx.yaml @@ -0,0 +1,8 @@ +models: ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] +weights: [ + [1., 1., 0., 0.], + [0., 1., 0., 0.], + [1., 0., 1., 1.], + [1., 0., 1., 1.], +] +segment: 44 diff --git a/demucs/remote/mdx_extra.yaml b/demucs/remote/mdx_extra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17addae1f20dccf49672ccb0e0516e0ef5650bb8 --- /dev/null +++ b/demucs/remote/mdx_extra.yaml @@ -0,0 +1,2 @@ +models: ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] +segment: 44 \ No newline at end of file diff --git a/demucs/remote/mdx_extra_q.yaml b/demucs/remote/mdx_extra_q.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a27d6f0db5ece317b33d351809eabba8a6dc4602 --- /dev/null +++ b/demucs/remote/mdx_extra_q.yaml @@ -0,0 +1,2 @@ +models: ['83fc094f', '464b36d7', '14fc6a69', '7fd6ef75'] +segment: 44 diff 
--git a/demucs/remote/mdx_q.yaml b/demucs/remote/mdx_q.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e34e7a699babbc0ff3ecdb86a50a06116ce58066 --- /dev/null +++ b/demucs/remote/mdx_q.yaml @@ -0,0 +1,8 @@ +models: ['6b9c2ca1', 'b72baf4e', '42e558d4', '305bc58f'] +weights: [ + [1., 1., 0., 0.], + [0., 1., 0., 0.], + [1., 0., 1., 1.], + [1., 0., 1., 1.], +] +segment: 44 diff --git a/demucs/remote/repro_mdx_a.yaml b/demucs/remote/repro_mdx_a.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a224cfcb706db0a7345cae0d9e9b346b6776655a --- /dev/null +++ b/demucs/remote/repro_mdx_a.yaml @@ -0,0 +1,2 @@ +models: ['9a6b4851', '1ef250f1', 'fa0cb7f9', '902315c2'] +segment: 44 diff --git a/demucs/remote/repro_mdx_a_hybrid_only.yaml b/demucs/remote/repro_mdx_a_hybrid_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7420c394c8f8fde29484202887e792e8273312b --- /dev/null +++ b/demucs/remote/repro_mdx_a_hybrid_only.yaml @@ -0,0 +1,2 @@ +models: ['fa0cb7f9', '902315c2', 'fa0cb7f9', '902315c2'] +segment: 44 diff --git a/demucs/remote/repro_mdx_a_time_only.yaml b/demucs/remote/repro_mdx_a_time_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb1f442209cabd38c0470aa5a513e269160a7959 --- /dev/null +++ b/demucs/remote/repro_mdx_a_time_only.yaml @@ -0,0 +1,2 @@ +models: ['9a6b4851', '9a6b4851', '1ef250f1', '1ef250f1'] +segment: 44 diff --git a/demucs/repitch.py b/demucs/repitch.py new file mode 100644 index 0000000000000000000000000000000000000000..262c427fab9e646b2ba46e7d3e05172550d30bfb --- /dev/null +++ b/demucs/repitch.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Utility for on the fly pitch/tempo change for data augmentation.""" + +import random +import subprocess as sp +import tempfile + +import torch +import torchaudio as ta + +from .audio import save_audio + + +class RepitchedWrapper: + """ + Wrap a dataset to apply online change of pitch / tempo. + """ + def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, + tempo_std=5, vocals=[3], same=True): + self.dataset = dataset + self.proba = proba + self.max_pitch = max_pitch + self.max_tempo = max_tempo + self.tempo_std = tempo_std + self.same = same + self.vocals = vocals + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + streams = self.dataset[index] + in_length = streams.shape[-1] + out_length = int((1 - 0.01 * self.max_tempo) * in_length) + + if random.random() < self.proba: + outs = [] + for idx, stream in enumerate(streams): + if idx == 0 or not self.same: + delta_pitch = random.randint(-self.max_pitch, self.max_pitch) + delta_tempo = random.gauss(0, self.tempo_std) + delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo) + stream = repitch( + stream, + delta_pitch, + delta_tempo, + voice=idx in self.vocals) + outs.append(stream[:, :out_length]) + streams = torch.stack(outs) + else: + streams = streams[..., :out_length] + return streams + + +def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100): + """ + tempo is a relative delta in percentage, so tempo=10 means tempo at 110%! + pitch is in semi tones. 
+ Requires `soundstretch` to be installed, see + https://www.surina.net/soundtouch/soundstretch.html + """ + infile = tempfile.NamedTemporaryFile(suffix=".wav") + outfile = tempfile.NamedTemporaryFile(suffix=".wav") + save_audio(wav, infile.name, samplerate, clip='clamp') + command = [ + "soundstretch", + infile.name, + outfile.name, + f"-pitch={pitch}", + f"-tempo={tempo:.6f}", + ] + if quick: + command += ["-quick"] + if voice: + command += ["-speech"] + try: + sp.run(command, capture_output=True, check=True) + except sp.CalledProcessError as error: + raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}") + wav, sr = ta.load(outfile.name) + assert sr == samplerate + return wav diff --git a/demucs/repo.py b/demucs/repo.py new file mode 100644 index 0000000000000000000000000000000000000000..5df3ba846d80e80c88cf9a26a5e58c2140acc636 --- /dev/null +++ b/demucs/repo.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Represents a model repository, including pre-trained models and bags of models. +A repo can either be the main remote repository stored in AWS, or a local repository +with your own models. +""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import yaml + +from .apply import BagOfModels, Model +from .states import load_model + + +AnyModel = tp.Union[Model, BagOfModels] + + +class ModelLoadingError(RuntimeError): + pass + + +def check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise ModelLoadingError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +class ModelOnlyRepo: + """Base class for all model only repos. + """ + def has_model(self, sig: str) -> bool: + raise NotImplementedError() + + def get_model(self, sig: str) -> Model: + raise NotImplementedError() + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + raise NotImplementedError() + + +class RemoteRepo(ModelOnlyRepo): + def __init__(self, models: tp.Dict[str, str]): + self._models = models + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + url = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') + pkg = torch.hub.load_state_dict_from_url( + url, map_location='cpu', check_hash=True) # type: ignore + return load_model(pkg) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models # type: ignore + + +class LocalRepo(ModelOnlyRepo): + def __init__(self, root: Path): + self.root = root + self.scan() + + def scan(self): + self._models = {} + self._checksums = {} + for file in self.root.iterdir(): + if file.suffix == '.th': + if '-' in file.stem: + xp_sig, checksum = file.stem.split('-') + self._checksums[xp_sig] = checksum + else: + xp_sig = file.stem + if xp_sig in self._models: + raise ModelLoadingError( + f'Duplicate pre-trained model exist for signature {xp_sig}. 
' + 'Please delete all but one.') + self._models[xp_sig] = file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + file = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') + if sig in self._checksums: + check_checksum(file, self._checksums[sig]) + return load_model(file) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models + + +class BagOnlyRepo: + """Handles only YAML files containing bag of models, leaving the actual + model loading to some Repo. + """ + def __init__(self, root: Path, model_repo: ModelOnlyRepo): + self.root = root + self.model_repo = model_repo + self.scan() + + def scan(self): + self._bags = {} + for file in self.root.iterdir(): + if file.suffix == '.yaml': + self._bags[file.stem] = file + + def has_model(self, name: str) -> bool: + return name in self._bags + + def get_model(self, name: str) -> BagOfModels: + try: + yaml_file = self._bags[name] + except KeyError: + raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' + 'a bag of models.') + bag = yaml.safe_load(open(yaml_file)) + signatures = bag['models'] + models = [self.model_repo.get_model(sig) for sig in signatures] + weights = bag.get('weights') + segment = bag.get('segment') + return BagOfModels(models, weights, segment) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._bags + + +class AnyModelRepo: + def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): + self.model_repo = model_repo + self.bag_repo = bag_repo + + def has_model(self, name_or_sig: str) -> bool: + return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) + + def get_model(self, name_or_sig: str) -> AnyModel: + if self.model_repo.has_model(name_or_sig): + return self.model_repo.get_model(name_or_sig) + else: + return self.bag_repo.get_model(name_or_sig) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + models = self.model_repo.list_model() + for key, value in self.bag_repo.list_model().items(): + models[key] = value + return models diff --git a/demucs/separate.py b/demucs/separate.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2940a1ed449391528fa5590752730720b347ef --- /dev/null +++ b/demucs/separate.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path + +from dora.log import fatal +import torch as th + +from .api import Separator, save_audio, list_models + +from .apply import BagOfModels +from .htdemucs import HTDemucs +from .pretrained import add_model_flags, ModelLoadingError + + +def get_parser(): + parser = argparse.ArgumentParser("demucs.separate", + description="Separate the sources for the given tracks") + parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks') + add_model_flags(parser) + parser.add_argument("--list-models", action="store_true", help="List available models " + "from current repo and exit") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-o", + "--out", + type=Path, + default=Path("separated"), + help="Folder where to put extracted tracks. 
A subfolder " + "with the model name will be created.") + parser.add_argument("--filename", + default="{track}/{stem}.{ext}", + help="Set the name of output file. \n" + 'Use "{track}", "{trackext}", "{stem}", "{ext}" to use ' + "variables of track name without extension, track extension, " + "stem name and default output file extension. \n" + 'Default is "{track}/{stem}.{ext}".') + parser.add_argument("-d", + "--device", + default="cuda" if th.cuda.is_available() else "cpu", + help="Device to use, default is cuda if available else cpu") + parser.add_argument("--shifts", + default=1, + type=int, + help="Number of random shifts for equivariant stabilization." + "Increase separation time but improves quality for Demucs. 10 was used " + "in the original paper.") + parser.add_argument("--overlap", + default=0.25, + type=float, + help="Overlap between the splits.") + split_group = parser.add_mutually_exclusive_group() + split_group.add_argument("--no-split", + action="store_false", + dest="split", + default=True, + help="Doesn't split audio in chunks. " + "This can use large amounts of memory.") + split_group.add_argument("--segment", type=int, + help="Set split size of each chunk. " + "This can help save memory of graphic card. ") + parser.add_argument("--two-stems", + dest="stem", metavar="STEM", + help="Only separate audio into {STEM} and no_{STEM}. ") + parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"], + default="add", help='Decide how to get "no_{STEM}". "none" will not save ' + '"no_{STEM}". "add" will add all the other stems. "minus" will use the ' + "original track minus the selected stem.") + depth_group = parser.add_mutually_exclusive_group() + depth_group.add_argument("--int24", action="store_true", + help="Save wav output as 24 bits wav.") + depth_group.add_argument("--float32", action="store_true", + help="Save wav output as float32 (2x bigger).") + parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"], + help="Strategy for avoiding clipping: rescaling entire signal " + "if necessary (rescale) or hard clipping (clamp).") + format_group = parser.add_mutually_exclusive_group() + format_group.add_argument("--flac", action="store_true", + help="Convert the output wavs to flac.") + format_group.add_argument("--mp3", action="store_true", + help="Convert the output wavs to mp3.") + parser.add_argument("--mp3-bitrate", + default=320, + type=int, + help="Bitrate of converted mp3.") + parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2, + help="Encoder preset of MP3, 2 for highest quality, 7 for " + "fastest speed. Default is 2") + parser.add_argument("-j", "--jobs", + default=0, + type=int, + help="Number of jobs. 
This can increase memory usage but will " + "be much faster when multiple cores are available.") + + return parser + + +def main(opts=None): + parser = get_parser() + args = parser.parse_args(opts) + if args.list_models: + models = list_models(args.repo) + print("Bag of models:", end="\n ") + print("\n ".join(models["bag"])) + print("Single models:", end="\n ") + print("\n ".join(models["single"])) + sys.exit(0) + if len(args.tracks) == 0: + print("error: the following arguments are required: tracks", file=sys.stderr) + sys.exit(1) + + try: + separator = Separator(model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + split=args.split, + overlap=args.overlap, + progress=True, + jobs=args.jobs, + segment=args.segment) + except ModelLoadingError as error: + fatal(error.args[0]) + + max_allowed_segment = float('inf') + if isinstance(separator.model, HTDemucs): + max_allowed_segment = float(separator.model.segment) + elif isinstance(separator.model, BagOfModels): + max_allowed_segment = separator.model.max_allowed_segment + if args.segment is not None and args.segment > max_allowed_segment: + fatal("Cannot use a Transformer model with a longer segment " + f"than it was trained for. Maximum segment is: {max_allowed_segment}") + + if isinstance(separator.model, BagOfModels): + print( + f"Selected model is a bag of {len(separator.model.models)} models. " + "You will see that many progress bars per track." + ) + + if args.stem is not None and args.stem not in separator.model.sources: + fatal( + 'error: stem "{stem}" is not in selected model. ' + "STEM must be one of {sources}.".format( + stem=args.stem, sources=", ".join(separator.model.sources) + ) + ) + out = args.out / args.name + out.mkdir(parents=True, exist_ok=True) + print(f"Separated tracks will be stored in {out.resolve()}") + for track in args.tracks: + if not track.exists(): + print(f"File {track} does not exist. 
If the path contains spaces, " + 'please try again after surrounding the entire path with quotes "".', + file=sys.stderr) + continue + print(f"Separating track {track}") + + origin, res = separator.separate_audio_file(track) + + if args.mp3: + ext = "mp3" + elif args.flac: + ext = "flac" + else: + ext = "wav" + kwargs = { + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "preset": args.mp3_preset, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, + } + if args.stem is None: + for name, source in res.items(): + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=name, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(source, str(stem), **kwargs) + else: + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="minus_" + args.stem, + ext=ext, + ) + if args.other_method == "minus": + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(origin - res[args.stem], str(stem), **kwargs) + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(res.pop(args.stem), str(stem), **kwargs) + # Warning : after poping the stem, selected stem is no longer in the dict 'res' + if args.other_method == "add": + other_stem = th.zeros_like(next(iter(res.values()))) + for i in res.values(): + other_stem += i + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="no_" + args.stem, + ext=ext, + ) + stem.parent.mkdir(parents=True, exist_ok=True) + save_audio(other_stem, str(stem), **kwargs) + + +if __name__ == "__main__": + main() diff --git a/demucs/solver.py b/demucs/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..38537dbadd94bd3d40cb19a73acfcc83049383b2 --- /dev/null +++ b/demucs/solver.py @@ -0,0 +1,405 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Main training loop.""" + +import logging + +from dora import get_xp +from dora.utils import write_and_rename +from dora.log import LogProgress, bold +import torch +import torch.nn.functional as F + +from . import augment, distrib, states, pretrained +from .apply import apply_model +from .ema import ModelEMA +from .evaluate import evaluate, new_sdr +from .svd import svd_penalty +from .utils import pull_metric, EMA + +logger = logging.getLogger(__name__) + + +def _summary(metrics): + return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items()) + + +class Solver(object): + def __init__(self, loaders, model, optimizer, args): + self.args = args + self.loaders = loaders + + self.model = model + self.optimizer = optimizer + self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer) + self.dmodel = distrib.wrap(model) + self.device = next(iter(self.model.parameters())).device + + # Exponential moving average of the model, either updated every batch or epoch. + # The best model from all the EMAs and the original one is kept based on the valid + # loss for the final best model. 
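# A minimal sketch of the parameter-EMA idea behind ModelEMA (see demucs/ema.py); the helper
# names `init_ema_state` and `ema_update` below are hypothetical and shown for illustration only.
import torch

def init_ema_state(model: torch.nn.Module) -> dict:
    # Start the running average from a detached copy of the current weights.
    return {k: v.detach().clone() for k, v in model.state_dict().items()}

@torch.no_grad()
def ema_update(ema_state: dict, model: torch.nn.Module, decay: float = 0.9999):
    # Blend the live weights into the running average: ema <- decay * ema + (1 - decay) * param.
    for name, param in model.state_dict().items():
        if param.dtype.is_floating_point:
            ema_state[name].mul_(decay).add_(param.detach(), alpha=1 - decay)
        else:
            ema_state[name].copy_(param)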
+ self.emas = {'batch': [], 'epoch': []} + for kind in self.emas.keys(): + decays = getattr(args.ema, kind) + device = self.device if kind == 'batch' else 'cpu' + if decays: + for decay in decays: + self.emas[kind].append(ModelEMA(self.model, decay, device=device)) + + # data augment + augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift), + same=args.augment.shift_same)] + if args.augment.flip: + augments += [augment.FlipChannels(), augment.FlipSign()] + for aug in ['scale', 'remix']: + kw = getattr(args.augment, aug) + if kw.proba: + augments.append(getattr(augment, aug.capitalize())(**kw)) + self.augment = torch.nn.Sequential(*augments) + + xp = get_xp() + self.folder = xp.folder + # Checkpoints + self.checkpoint_file = xp.folder / 'checkpoint.th' + self.best_file = xp.folder / 'best.th' + logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve()) + self.best_state = None + self.best_changed = False + + self.link = xp.link + self.history = self.link.history + + self._reset() + + def _serialize(self, epoch): + package = {} + package['state'] = self.model.state_dict() + package['optimizer'] = self.optimizer.state_dict() + package['history'] = self.history + package['best_state'] = self.best_state + package['args'] = self.args + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + package[f'ema_{kind}_{k}'] = ema.state_dict() + with write_and_rename(self.checkpoint_file) as tmp: + torch.save(package, tmp) + + save_every = self.args.save_every + if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs: + with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp: + torch.save(package, tmp) + + if self.best_changed: + # Saving only the latest best model. 
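+ # write_and_rename (from dora.utils) stages the file under a temporary name and renames it into place once torch.save finishes, so an interrupted save should not corrupt best.th.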
+ with write_and_rename(self.best_file) as tmp: + package = states.serialize_model(self.model, self.args) + package['state'] = self.best_state + torch.save(package, tmp) + self.best_changed = False + + def _reset(self): + """Reset state of the solver, potentially using checkpoint.""" + if self.checkpoint_file.exists(): + logger.info(f'Loading checkpoint model: {self.checkpoint_file}') + package = torch.load(self.checkpoint_file, 'cpu') + self.model.load_state_dict(package['state']) + self.optimizer.load_state_dict(package['optimizer']) + self.history[:] = package['history'] + self.best_state = package['best_state'] + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + ema.load_state_dict(package[f'ema_{kind}_{k}']) + elif self.args.continue_pretrained: + model = pretrained.get_model( + name=self.args.continue_pretrained, + repo=self.args.pretrained_repo) + self.model.load_state_dict(model.state_dict()) + elif self.args.continue_from: + name = 'checkpoint.th' + root = self.folder.parent + cf = root / str(self.args.continue_from) / name + logger.info("Loading from %s", cf) + package = torch.load(cf, 'cpu') + self.best_state = package['best_state'] + if self.args.continue_best: + self.model.load_state_dict(package['best_state'], strict=False) + else: + self.model.load_state_dict(package['state'], strict=False) + if self.args.continue_opt: + self.optimizer.load_state_dict(package['optimizer']) + + def _format_train(self, metrics: dict) -> dict: + """Formatting for train/valid metrics.""" + losses = { + 'loss': format(metrics['loss'], ".4f"), + 'reco': format(metrics['reco'], ".4f"), + } + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], ".3f") + if self.quantizer is not None: + losses['ms'] = format(metrics['ms'], ".2f") + if 'grad' in metrics: + losses['grad'] = format(metrics['grad'], ".4f") + if 'best' in metrics: + losses['best'] = format(metrics['best'], '.4f') + if 'bname' in metrics: + losses['bname'] = metrics['bname'] + if 'penalty' in metrics: + losses['penalty'] = format(metrics['penalty'], ".4f") + if 'hloss' in metrics: + losses['hloss'] = format(metrics['hloss'], ".4f") + return losses + + def _format_test(self, metrics: dict) -> dict: + """Formatting for test metrics.""" + losses = {} + if 'sdr' in metrics: + losses['sdr'] = format(metrics['sdr'], '.3f') + if 'nsdr' in metrics: + losses['nsdr'] = format(metrics['nsdr'], '.3f') + for source in self.model.sources: + key = f'sdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + key = f'nsdr_{source}' + if key in metrics: + losses[key] = format(metrics[key], '.3f') + return losses + + def train(self): + # Optimizing the model + if self.history: + logger.info("Replaying metrics from previous run") + for epoch, metrics in enumerate(self.history): + formatted = self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + if 'test' in metrics: + formatted = self._format_test(metrics['test']) + if formatted: + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + + epoch = 0 + for epoch in range(len(self.history), self.args.epochs): + # Train one epoch + self.model.train() # Turn on BatchNorm & Dropout + metrics = {} + logger.info('-' * 70) + logger.info("Training...") + metrics['train'] = self._run_one_epoch(epoch) + formatted = 
self._format_train(metrics['train']) + logger.info( + bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Cross validation + logger.info('-' * 70) + logger.info('Cross validation...') + self.model.eval() # Turn off Batchnorm & Dropout + with torch.no_grad(): + valid = self._run_one_epoch(epoch, train=False) + bvalid = valid + bname = 'main' + state = states.copy_state(self.model.state_dict()) + metrics['valid'] = {} + metrics['valid']['main'] = valid + key = self.args.test.metric + for kind, emas in self.emas.items(): + for k, ema in enumerate(emas): + with ema.swap(): + valid = self._run_one_epoch(epoch, train=False) + name = f'ema_{kind}_{k}' + metrics['valid'][name] = valid + a = valid[key] + b = bvalid[key] + if key.startswith('nsdr'): + a = -a + b = -b + if a < b: + bvalid = valid + state = ema.state + bname = name + metrics['valid'].update(bvalid) + metrics['valid']['bname'] = bname + + valid_loss = metrics['valid'][key] + mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss] + if key.startswith('nsdr'): + best_loss = max(mets) + else: + best_loss = min(mets) + metrics['valid']['best'] = best_loss + if self.args.svd.penalty > 0: + kw = dict(self.args.svd) + kw.pop('penalty') + with torch.no_grad(): + penalty = svd_penalty(self.model, exact=True, **kw) + metrics['valid']['penalty'] = penalty + + formatted = self._format_train(metrics['valid']) + logger.info( + bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}')) + + # Save the best model + if valid_loss == best_loss or self.args.dset.train_valid: + logger.info(bold('New best valid loss %.4f'), valid_loss) + self.best_state = states.copy_state(state) + self.best_changed = True + + # Eval model every `test.every` epoch or on last epoch + should_eval = (epoch + 1) % self.args.test.every == 0 + is_last = epoch == self.args.epochs - 1 + # # Tries to detect divergence in a reliable way and finish job + # # not to waste compute. + # # Commented out as this was super specific to the MDX competition. 
+ # reco = metrics['valid']['main']['reco'] + # div = epoch >= 180 and reco > 0.18 + # div = div or epoch >= 100 and reco > 0.25 + # div = div and self.args.optim.loss == 'l1' + # if div: + # logger.warning("Finishing training early because valid loss is too high.") + # is_last = True + if should_eval or is_last: + # Evaluate on the testset + logger.info('-' * 70) + logger.info('Evaluating on the test set...') + # We switch to the best known model for testing + if self.args.test.best: + state = self.best_state + else: + state = states.copy_state(self.model.state_dict()) + compute_sdr = self.args.test.sdr and is_last + with states.swap_state(self.model, state): + with torch.no_grad(): + metrics['test'] = evaluate(self, compute_sdr=compute_sdr) + formatted = self._format_test(metrics['test']) + logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}")) + self.link.push_metrics(metrics) + + if distrib.rank == 0: + # Save model each epoch + self._serialize(epoch) + logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve()) + if is_last: + break + + def _run_one_epoch(self, epoch, train=True): + args = self.args + data_loader = self.loaders['train'] if train else self.loaders['valid'] + if distrib.world_size > 1 and train: + data_loader.sampler.set_epoch(epoch) + + label = ["Valid", "Train"][train] + name = label + f" | Epoch {epoch + 1}" + total = len(data_loader) + if args.max_batches: + total = min(total, args.max_batches) + logprog = LogProgress(logger, data_loader, total=total, + updates=self.args.misc.num_prints, name=name) + averager = EMA() + + for idx, sources in enumerate(logprog): + sources = sources.to(self.device) + if train: + sources = self.augment(sources) + mix = sources.sum(dim=1) + else: + mix = sources[:, 0] + sources = sources[:, 1:] + + if not train and self.args.valid_apply: + estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0) + else: + estimate = self.dmodel(mix) + if train and hasattr(self.model, 'transform_target'): + sources = self.model.transform_target(mix, sources) + assert estimate.shape == sources.shape, (estimate.shape, sources.shape) + dims = tuple(range(2, sources.dim())) + + if args.optim.loss == 'l1': + loss = F.l1_loss(estimate, sources, reduction='none') + loss = loss.mean(dims).mean(0) + reco = loss + elif args.optim.loss == 'mse': + loss = F.mse_loss(estimate, sources, reduction='none') + loss = loss.mean(dims) + reco = loss**0.5 + reco = reco.mean(0) + else: + raise ValueError(f"Invalid loss {self.args.loss}") + weights = torch.tensor(args.weights).to(sources) + loss = (loss * weights).sum() / weights.sum() + + ms = 0 + if self.quantizer is not None: + ms = self.quantizer.model_size() + if args.quant.diffq: + loss += args.quant.diffq * ms + + losses = {} + losses['reco'] = (reco * weights).sum() / weights.sum() + losses['ms'] = ms + + if not train: + nsdrs = new_sdr(sources, estimate.detach()).mean(0) + total = 0 + for source, nsdr, w in zip(self.model.sources, nsdrs, weights): + losses[f'nsdr_{source}'] = nsdr + total += w * nsdr + losses['nsdr'] = total / weights.sum() + + if train and args.svd.penalty > 0: + kw = dict(args.svd) + kw.pop('penalty') + penalty = svd_penalty(self.model, **kw) + losses['penalty'] = penalty + loss += args.svd.penalty * penalty + + losses['loss'] = loss + + for k, source in enumerate(self.model.sources): + losses[f'reco_{source}'] = reco[k] + + # optimize model in training mode + if train: + loss.backward() + grad_norm = 0 + grads = [] + for p in 
self.model.parameters(): + if p.grad is not None: + grad_norm += p.grad.data.norm()**2 + grads.append(p.grad.data) + losses['grad'] = grad_norm ** 0.5 + if args.optim.clip_grad: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + args.optim.clip_grad) + + if self.args.flag == 'uns': + for n, p in self.model.named_parameters(): + if p.grad is None: + print('no grad', n) + self.optimizer.step() + self.optimizer.zero_grad() + for ema in self.emas['batch']: + ema.update() + losses = averager(losses) + logs = self._format_train(losses) + logprog.update(**logs) + # Just in case, clear some memory + del loss, estimate, reco, ms + if args.max_batches == idx: + break + if self.args.debug and train: + break + if self.args.flag == 'debug': + break + if train: + for ema in self.emas['epoch']: + ema.update() + return distrib.average(losses, idx + 1) diff --git a/demucs/spec.py b/demucs/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..f7669fad436d1fec36e7390ee1fa6179725b9110 --- /dev/null +++ b/demucs/spec.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Conveniance wrapper to perform STFT and iSTFT""" + +import torch as th + + +def spectro(x, n_fft=512, hop_length=None, pad=0): + *other, length = x.shape + x = x.reshape(-1, length) + is_mps = x.device.type == 'mps' + if is_mps: + x = x.cpu() + z = th.stft(x, + n_fft * (1 + pad), + hop_length or n_fft // 4, + window=th.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode='reflect') + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + +def ispectro(z, hop_length=None, length=None, pad=0): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + is_mps = z.device.type == 'mps' + if is_mps: + z = z.cpu() + x = th.istft(z, + n_fft, + hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True) + _, length = x.shape + return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py new file mode 100644 index 0000000000000000000000000000000000000000..54dfcd5a19425bade2b412eef69e8e3d2e2f43af --- /dev/null +++ b/demucs/states.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Utilities to save and load models. 
+""" +from contextlib import contextmanager + +import functools +import hashlib +import inspect +import io +from pathlib import Path +import warnings + +from omegaconf import OmegaConf +from dora.log import fatal +import torch + + +def _check_diffq(): + try: + import diffq # noqa + except ImportError: + fatal('Trying to use DiffQ, but diffq is not installed.\n' + 'On Windows run: python.exe -m pip install diffq \n' + 'On Linux/Mac, run: python3 -m pip install diffq') + + +def get_quantizer(model, args, optimizer=None): + """Return the quantizer given the XP quantization args.""" + quantizer = None + if args.diffq: + _check_diffq() + from diffq import DiffQuantizer + quantizer = DiffQuantizer( + model, min_size=args.min_size, group_size=args.group_size) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + _check_diffq() + from diffq import UniformQuantizer + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.min_size) + return quantizer + + +def load_model(path_or_package, strict=False): + """Load a model from the given serialized model, either given as a dict (already loaded) + or a path to a file on disk.""" + if isinstance(path_or_package, dict): + package = path_or_package + elif isinstance(path_or_package, (str, Path)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + path = path_or_package + package = torch.load(path, 'cpu') + else: + raise ValueError(f"Invalid type for {path_or_package}.") + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + + set_state(model, state) + return model + + +def get_state(model, quantizer, half=False): + """Get the state from a model, potentially with quantization applied. + If `half` is True, model are stored as half precision, which shouldn't impact performance + but half the state size.""" + if quantizer is None: + dtype = torch.half if half else None + state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + state['__quantized'] = True + return state + + +def set_state(model, state, quantizer=None): + """Set the state on a given model.""" + if state.get('__quantized'): + if quantizer is not None: + quantizer.restore_quantized_state(model, state['quantized']) + else: + _check_diffq() + from diffq import restore_quantized_state + restore_quantized_state(model, state) + else: + model.load_state_dict(state) + return state + + +def save_with_checksum(content, path): + """Save the given value on disk, along with a sha256 hash. 
+ Should be used with the output of either `serialize_model` or `get_state`.""" + buf = io.BytesIO() + torch.save(content, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def serialize_model(model, training_args, quantizer=None, half=True): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer, half) + return { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': OmegaConf.to_container(training_args, resolve=True), + } + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, strict=False) + try: + yield + finally: + model.load_state_dict(old_state) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ diff --git a/demucs/svd.py b/demucs/svd.py new file mode 100644 index 0000000000000000000000000000000000000000..7e266cb952caec8cc3cfcec742cef0bcc7dc127a --- /dev/null +++ b/demucs/svd.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Ways to make the model stronger.""" +import random +import torch + + +def power_iteration(m, niters=1, bs=1): + """This is the power method. batch size is used to try multiple starting point in parallel.""" + assert m.dim() == 2 + assert m.shape[0] == m.shape[1] + dim = m.shape[0] + b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) + + for _ in range(niters): + n = m.mm(b) + norm = n.norm(dim=0, keepdim=True) + b = n / (1e-10 + norm) + + return norm.mean() + + +# We need a shared RNG to make sure all the distributed worker will skip the penalty together, +# as otherwise we wouldn't get any speed up. +penalty_rng = random.Random(1234) + + +def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True, + proba=1, conv_only=False, exact=False, bs=1): + """ + Penalty on the largest singular value for a layer. + Args: + - model: model to penalize + - min_size: minimum size in MB of a layer to penalize. + - dim: projection dimension for the svd_lowrank. Higher is better but slower. + - niters: number of iterations in the algorithm used by svd_lowrank. + - powm: use power method instead of lowrank SVD, my own experience + is that it is both slower and less stable. + - convtr: when True, differentiate between Conv and Transposed Conv. + this is kept for compatibility with older experiments. + - proba: probability to apply the penalty. + - conv_only: only apply to conv and conv transposed, not LSTM + (might not be reliable for other models than Demucs). + - exact: use exact SVD (slow but useful at validation). + - bs: batch_size for power method. + """ + total = 0 + if penalty_rng.random() > proba: + return 0. 
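+ # Scan every parameter tensor above min_size (MB), reshape it to 2D where needed, and accumulate an estimate of its largest squared singular value.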
+ + for m in model.modules(): + for name, p in m.named_parameters(recurse=False): + if p.numel() / 2**18 < min_size: + continue + if convtr: + if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): + if p.dim() in [3, 4]: + p = p.transpose(0, 1).contiguous() + if p.dim() == 3: + p = p.view(len(p), -1) + elif p.dim() == 4: + p = p.view(len(p), -1) + elif p.dim() == 1: + continue + elif conv_only: + continue + assert p.dim() == 2, (name, p.shape) + if exact: + estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() + elif powm: + a, b = p.shape + if a < b: + n = p.mm(p.t()) + else: + n = p.t().mm(p) + estimate = power_iteration(n, niters, bs) + else: + estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) + total += estimate + return total / proba diff --git a/demucs/train.py b/demucs/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbc3c6bdc8cd6b136f397af672a0d3d27742fe1 --- /dev/null +++ b/demucs/train.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Main training script entry point""" + +import logging +import os +from pathlib import Path +import sys + +from dora import hydra_main +import hydra +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf +import torch +from torch import nn +import torchaudio +from torch.utils.data import ConcatDataset + +from . import distrib +from .wav import get_wav_datasets, get_musdb_wav_datasets +from .demucs import Demucs +from .hdemucs import HDemucs +from .htdemucs import HTDemucs +from .repitch import RepitchedWrapper +from .solver import Solver +from .states import capture_init +from .utils import random_subset + +logger = logging.getLogger(__name__) + + +class TorchHDemucsWrapper(nn.Module): + """Wrapper around torchaudio HDemucs implementation to provide the proper metadata + for model evaluation. 
+ See https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html""" + + @capture_init + def __init__(self, **kwargs): + super().__init__() + try: + from torchaudio.models import HDemucs as TorchHDemucs + except ImportError: + raise ImportError("Please upgrade torchaudio for using its implementation of HDemucs") + self.samplerate = kwargs.pop('samplerate') + self.segment = kwargs.pop('segment') + self.sources = kwargs['sources'] + self.torch_hdemucs = TorchHDemucs(**kwargs) + + def forward(self, mix): + return self.torch_hdemucs.forward(mix) + + +def get_model(args): + extra = { + 'sources': list(args.dset.sources), + 'audio_channels': args.dset.channels, + 'samplerate': args.dset.samplerate, + 'segment': args.model_segment or 4 * args.dset.segment, + } + klass = { + 'demucs': Demucs, + 'hdemucs': HDemucs, + 'htdemucs': HTDemucs, + 'torch_hdemucs': TorchHDemucsWrapper, + }[args.model] + kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) + model = klass(**extra, **kw) + return model + + +def get_optimizer(model, args): + seen_params = set() + other_params = [] + groups = [] + for n, module in model.named_modules(): + if hasattr(module, "make_optim_group"): + group = module.make_optim_group() + params = set(group["params"]) + assert params.isdisjoint(seen_params) + seen_params |= set(params) + groups.append(group) + for param in model.parameters(): + if param not in seen_params: + other_params.append(param) + groups.insert(0, {"params": other_params}) + parameters = groups + if args.optim.optim == "adam": + return torch.optim.Adam( + parameters, + lr=args.optim.lr, + betas=(args.optim.momentum, args.optim.beta2), + weight_decay=args.optim.weight_decay, + ) + elif args.optim.optim == "adamw": + return torch.optim.AdamW( + parameters, + lr=args.optim.lr, + betas=(args.optim.momentum, args.optim.beta2), + weight_decay=args.optim.weight_decay, + ) + else: + raise ValueError("Invalid optimizer %s", args.optim.optimizer) + + +def get_datasets(args): + if args.dset.backend: + torchaudio.set_audio_backend(args.dset.backend) + if args.dset.use_musdb: + train_set, valid_set = get_musdb_wav_datasets(args.dset) + else: + train_set, valid_set = [], [] + if args.dset.wav: + extra_train_set, extra_valid_set = get_wav_datasets(args.dset) + if len(args.dset.sources) <= 4: + train_set = ConcatDataset([train_set, extra_train_set]) + valid_set = ConcatDataset([valid_set, extra_valid_set]) + else: + train_set = extra_train_set + valid_set = extra_valid_set + + if args.dset.wav2: + extra_train_set, extra_valid_set = get_wav_datasets(args.dset, "wav2") + weight = args.dset.wav2_weight + if weight is not None: + b = len(train_set) + e = len(extra_train_set) + reps = max(1, round(e / b * (1 / weight - 1))) + else: + reps = 1 + train_set = ConcatDataset([train_set] * reps + [extra_train_set]) + if args.dset.wav2_valid: + if weight is not None: + b = len(valid_set) + n_kept = int(round(weight * b / (1 - weight))) + valid_set = ConcatDataset( + [valid_set, random_subset(extra_valid_set, n_kept)] + ) + else: + valid_set = ConcatDataset([valid_set, extra_valid_set]) + if args.dset.valid_samples is not None: + valid_set = random_subset(valid_set, args.dset.valid_samples) + assert len(train_set) + assert len(valid_set) + return train_set, valid_set + + +def get_solver(args, model_only=False): + distrib.init() + + torch.manual_seed(args.seed) + model = get_model(args) + if args.misc.show: + logger.info(model) + mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20 + logger.info('Size: %.1f 
MB', mb) + if hasattr(model, 'valid_length'): + field = model.valid_length(1) + logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000) + sys.exit(0) + + # torch also initialize cuda seed if available + if torch.cuda.is_available(): + model.cuda() + + # optimizer + optimizer = get_optimizer(model, args) + + assert args.batch_size % distrib.world_size == 0 + args.batch_size //= distrib.world_size + + if model_only: + return Solver(None, model, optimizer, args) + + train_set, valid_set = get_datasets(args) + + if args.augment.repitch.proba: + vocals = [] + if 'vocals' in args.dset.sources: + vocals.append(args.dset.sources.index('vocals')) + else: + logger.warning('No vocal source found') + if args.augment.repitch.proba: + train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch) + + logger.info("train/valid set size: %d %d", len(train_set), len(valid_set)) + train_loader = distrib.loader( + train_set, batch_size=args.batch_size, shuffle=True, + num_workers=args.misc.num_workers, drop_last=True) + if args.dset.full_cv: + valid_loader = distrib.loader( + valid_set, batch_size=1, shuffle=False, + num_workers=args.misc.num_workers) + else: + valid_loader = distrib.loader( + valid_set, batch_size=args.batch_size, shuffle=False, + num_workers=args.misc.num_workers, drop_last=True) + loaders = {"train": train_loader, "valid": valid_loader} + + # Construct Solver + return Solver(loaders, model, optimizer, args) + + +def get_solver_from_sig(sig, model_only=False): + inst = GlobalHydra.instance() + hyd = None + if inst.is_initialized(): + hyd = inst.hydra + inst.clear() + xp = main.get_xp_from_sig(sig) + if hyd is not None: + inst.clear() + inst.initialize(hyd) + + with xp.enter(stack=True): + return get_solver(xp.cfg, model_only) + + +@hydra_main(config_path="../conf", config_name="config", version_base="1.1") +def main(args): + global __file__ + __file__ = hydra.utils.to_absolute_path(__file__) + for attr in ["musdb", "wav", "metadata"]: + val = getattr(args.dset, attr) + if val is not None: + setattr(args.dset, attr, hydra.utils.to_absolute_path(val)) + + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + + if args.misc.verbose: + logger.setLevel(logging.DEBUG) + + logger.info("For logs, checkpoints and samples check %s", os.getcwd()) + logger.debug(args) + from dora import get_xp + logger.debug(get_xp().cfg) + + solver = get_solver(args) + solver.train() + + +if '_DORA_TEST_PATH' in os.environ: + main.dora.dir = Path(os.environ['_DORA_TEST_PATH']) + + +if __name__ == "__main__": + main() diff --git a/demucs/transformer.py b/demucs/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbb486633da64800808bcf6ac70ecb2bfce4796 --- /dev/null +++ b/demucs/transformer.py @@ -0,0 +1,839 @@ +# Copyright (c) 2019-present, Meta, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. 
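# A self-contained sketch of the sin/cos positional embedding scheme implemented by
# create_sin_embedding below (shift=0, no batch dimension); `_sin_embedding_sketch` is a
# hypothetical helper shown for illustration only.
import torch

def _sin_embedding_sketch(length: int, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    assert dim % 2 == 0
    pos = torch.arange(length, dtype=torch.float32).view(-1, 1)     # (T, 1)
    adim = torch.arange(dim // 2, dtype=torch.float32).view(1, -1)  # (1, dim // 2)
    # Periods grow geometrically from 1 up to max_period across the embedding dimension.
    phase = pos / (max_period ** (adim / (dim // 2 - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)  # (T, dim)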
+ +import random +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from einops import rearrange + + +def create_sin_embedding( + length: int, dim: int, shift: int = 0, device="cpu", max_period=10000 +): + # We aim for TBC format + assert dim % 2 == 0 + pos = shift + torch.arange(length, device=device).view(-1, 1, 1) + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model) + ) + pe = torch.zeros(d_model, height, width) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = torch.exp( + torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model) + ) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pe[0:d_model:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[1:d_model:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[d_model::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pe[d_model + 1:: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + + return pe[None, :].to(device) + + +def create_sin_embedding_cape( + length: int, + dim: int, + batch_size: int, + mean_normalize: bool, + augment: bool, # True during training + max_global_shift: float = 0.0, # delta max + max_local_shift: float = 0.0, # epsilon max + max_scale: float = 1.0, + device: str = "cpu", + max_period: float = 10000.0, +): + # We aim for TBC format + assert dim % 2 == 0 + pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1) + pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1) + if mean_normalize: + pos -= torch.nanmean(pos, dim=0, keepdim=True) + + if augment: + delta = np.random.uniform( + -max_global_shift, +max_global_shift, size=[1, batch_size, 1] + ) + delta_local = np.random.uniform( + -max_local_shift, +max_local_shift, size=[length, batch_size, 1] + ) + log_lambdas = np.random.uniform( + -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1] + ) + pos = (pos + delta + delta_local) * np.exp(log_lambdas) + + pos = pos.to(device) + + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ).float() + + +def get_causal_mask(length): + pos = torch.arange(length) + return pos > pos[:, None] + + +def get_elementary_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + When the input of the Decoder has length T1 and the output T2 + The mask matrix has shape (T2, T1) + """ + assert mask_type in ["diag", "jmask", "random", "global"] + + if mask_type == "global": + mask = torch.zeros(T2, T1, dtype=torch.bool) + mask[:, :global_window] = True + 
line_window = int(global_window * T2 / T1) + mask[:line_window, :] = True + + if mask_type == "diag": + + mask = torch.zeros(T2, T1, dtype=torch.bool) + rows = torch.arange(T2)[:, None] + cols = ( + (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)) + .long() + .clamp(0, T1 - 1) + ) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + + elif mask_type == "jmask": + mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool) + rows = torch.arange(T2 + 2)[:, None] + t = torch.arange(0, int((2 * T1) ** 0.5 + 1)) + t = (t * (t + 1) / 2).int() + t = torch.cat([-t.flip(0)[:-1], t]) + cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + mask = mask[1:-1, 1:-1] + + elif mask_type == "random": + gene = torch.Generator(device=device) + gene.manual_seed(mask_random_seed) + mask = ( + torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) + > sparsity + ) + + mask = mask.to(device) + return mask + + +def get_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + Return a SparseCSRTensor mask that is a combination of elementary masks + mask_type can be a combination of multiple masks: for instance "diag_jmask_random" + """ + from xformers.sparse import SparseCSRTensor + # create a list + mask_types = mask_type.split("_") + + all_masks = [ + get_elementary_mask( + T1, + T2, + mask, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, + ) + for mask in mask_types + ] + + final_mask = torch.stack(all_masks).sum(axis=0) > 0 + + return SparseCSRTensor.from_dense(final_mask[None]) + + +class ScaledEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + scale: float = 1.0, + boost: float = 3.0, + ): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data *= scale / boost + self.boost = boost + + @property + def weight(self): + return self.embedding.weight * self.boost + + def forward(self, x): + return self.embedding(x) * self.boost + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. 
+ """ + + def __init__(self, channels: int, init: float = 0, channel_last=False): + """ + channel_last = False corresponds to (B, C, T) tensors + channel_last = True corresponds to (T, B, C) tensors + """ + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class MyGroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """ + x: (B, T, C) + if num_groups=1: Normalisation on all T and C together for each B + """ + x = x.transpose(1, 2) + return super().forward(x).transpose(1, 2) + + +class MyTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=F.relu, + group_norm=0, + norm_first=False, + norm_out=False, + layer_norm_eps=1e-5, + layer_scale=False, + init_values=1e-4, + device=None, + dtype=None, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + auto_sparsity=False, + sparsity=0.95, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + norm_first=norm_first, + device=device, + dtype=dtype, + ) + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + if sparse: + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0, + ) + self.__setattr__("src_mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """ + if batch_first = False, src shape is (T, B, C) + the case where batch_first=True is not covered + """ + device = src.device + x = src + T, B, C = x.shape + if self.sparse and not self.auto_sparsity: + assert src_mask is None + src_mask = self.src_mask + if src_mask.shape[-1] != T: + src_mask = get_mask( + T, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("src_mask", src_mask) + + if self.norm_first: + x = x + self.gamma_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + ) + x = x + self.gamma_2(self._ff_block(self.norm2(x))) + + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1( + x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)) + ) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + +class 
CrossTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation=F.relu, + layer_norm_eps: float = 1e-5, + layer_scale: bool = False, + init_values: float = 1e-4, + norm_first: bool = False, + group_norm: bool = False, + norm_out: bool = False, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + sparsity=0.95, + auto_sparsity=None, + device=None, + dtype=None, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + + self.cross_attn: nn.Module + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1: nn.Module + self.norm2: nn.Module + self.norm3: nn.Module + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + else: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. 
+ if isinstance(activation, str): + self.activation = self._get_activation_fn(activation) + else: + self.activation = activation + + if sparse: + self.cross_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0) + if not auto_sparsity: + self.__setattr__("mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, q, k, mask=None): + """ + Args: + q: tensor of shape (T, B, C) + k: tensor of shape (S, B, C) + mask: tensor of shape (T, S) + + """ + device = q.device + T, B, C = q.shape + S, B, C = k.shape + if self.sparse and not self.auto_sparsity: + assert mask is None + mask = self.mask + if mask.shape[-1] != S or mask.shape[-2] != T: + mask = get_mask( + S, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("mask", mask) + + if self.norm_first: + x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask)) + x = x + self.gamma_2(self._ff_block(self.norm3(x))) + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask))) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + # self-attention block + def _ca_block(self, q, k, attn_mask=None): + x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def _get_activation_fn(self, activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +# ----------------- MULTI-BLOCKS MODELS: ----------------------- + + +class CrossTransformerEncoder(nn.Module): + def __init__( + self, + dim: int, + emb: str = "sin", + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 6, + cross_first: bool = False, + dropout: float = 0.0, + max_positions: int = 1000, + norm_in: bool = True, + norm_in_group: bool = False, + group_norm: int = False, + norm_first: bool = False, + norm_out: bool = False, + max_period: float = 10000.0, + weight_decay: float = 0.0, + lr: tp.Optional[float] = None, + layer_scale: bool = False, + gelu: bool = True, + sin_random_shift: int = 0, + weight_pos_embed: float = 1.0, + cape_mean_normalize: bool = True, + cape_augment: bool = True, + cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], + sparse_self_attn: bool = False, + sparse_cross_attn: bool = False, + mask_type: str = "diag", + mask_random_seed: int = 42, + sparse_attn_window: int = 500, + global_window: int = 50, + auto_sparsity: bool = False, + sparsity: float = 0.95, + ): + super().__init__() + """ + """ + assert dim % num_heads == 0 + + hidden_dim = int(dim * hidden_scale) + + self.num_layers = num_layers + # classic parity = 1 means that if idx%2 == 1 there is a + # classical encoder else there is a cross encoder + self.classic_parity = 1 if cross_first else 0 + self.emb = emb + self.max_period = max_period + self.weight_decay = weight_decay + self.weight_pos_embed = weight_pos_embed + self.sin_random_shift = sin_random_shift + if emb == "cape": + self.cape_mean_normalize = cape_mean_normalize + self.cape_augment = cape_augment + self.cape_glob_loc_scale = cape_glob_loc_scale + if emb == "scaled": + self.position_embeddings = ScaledEmbedding(max_positions, dim, 
scale=0.2) + + self.lr = lr + + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + self.norm_in_t: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + self.norm_in_t = nn.LayerNorm(dim) + elif norm_in_group: + self.norm_in = MyGroupNorm(int(norm_in_group), dim) + self.norm_in_t = MyGroupNorm(int(norm_in_group), dim) + else: + self.norm_in = nn.Identity() + self.norm_in_t = nn.Identity() + + # spectrogram layers + self.layers = nn.ModuleList() + # temporal layers + self.layers_t = nn.ModuleList() + + kwargs_common = { + "d_model": dim, + "nhead": num_heads, + "dim_feedforward": hidden_dim, + "dropout": dropout, + "activation": activation, + "group_norm": group_norm, + "norm_first": norm_first, + "norm_out": norm_out, + "layer_scale": layer_scale, + "mask_type": mask_type, + "mask_random_seed": mask_random_seed, + "sparse_attn_window": sparse_attn_window, + "global_window": global_window, + "sparsity": sparsity, + "auto_sparsity": auto_sparsity, + "batch_first": True, + } + + kwargs_classic_encoder = dict(kwargs_common) + kwargs_classic_encoder.update({ + "sparse": sparse_self_attn, + }) + kwargs_cross_encoder = dict(kwargs_common) + kwargs_cross_encoder.update({ + "sparse": sparse_cross_attn, + }) + + for idx in range(num_layers): + if idx % 2 == self.classic_parity: + + self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) + self.layers_t.append( + MyTransformerEncoderLayer(**kwargs_classic_encoder) + ) + + else: + self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) + + self.layers_t.append( + CrossTransformerEncoderLayer(**kwargs_cross_encoder) + ) + + def forward(self, x, xt): + B, C, Fr, T1 = x.shape + pos_emb_2d = create_2d_sin_embedding( + C, Fr, T1, x.device, self.max_period + ) # (1, C, Fr, T1) + pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c") + x = rearrange(x, "b c fr t1 -> b (t1 fr) c") + x = self.norm_in(x) + x = x + self.weight_pos_embed * pos_emb_2d + + B, C, T2 = xt.shape + xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C + pos_emb = self._get_pos_embedding(T2, B, C, x.device) + pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c") + xt = self.norm_in_t(xt) + xt = xt + self.weight_pos_embed * pos_emb + + for idx in range(self.num_layers): + if idx % 2 == self.classic_parity: + x = self.layers[idx](x) + xt = self.layers_t[idx](xt) + else: + old_x = x + x = self.layers[idx](x, xt) + xt = self.layers_t[idx](xt, old_x) + + x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1) + xt = rearrange(xt, "b t2 c -> b c t2") + return x, xt + + def _get_pos_embedding(self, T, B, C, device): + if self.emb == "sin": + shift = random.randrange(self.sin_random_shift + 1) + pos_emb = create_sin_embedding( + T, C, shift=shift, device=device, max_period=self.max_period + ) + elif self.emb == "cape": + if self.training: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=self.cape_augment, + max_global_shift=self.cape_glob_loc_scale[0], + max_local_shift=self.cape_glob_loc_scale[1], + max_scale=self.cape_glob_loc_scale[2], + ) + else: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=False, + ) + + elif self.emb == "scaled": + pos = torch.arange(T, device=device) + pos_emb = self.position_embeddings(pos)[:, None] + + return pos_emb + + def make_optim_group(self): + group = {"params": list(self.parameters()), 
"weight_decay": self.weight_decay} + if self.lr is not None: + group["lr"] = self.lr + return group + + +# Attention Modules + + +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + auto_sparsity=None, + ): + super().__init__() + assert auto_sparsity is not None, "sanity check" + self.num_heads = num_heads + self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_drop = torch.nn.Dropout(dropout) + self.proj = torch.nn.Linear(embed_dim, embed_dim, bias) + self.proj_drop = torch.nn.Dropout(dropout) + self.batch_first = batch_first + self.auto_sparsity = auto_sparsity + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + ): + + if not self.batch_first: # N, B, C + query = query.permute(1, 0, 2) # B, N_q, C + key = key.permute(1, 0, 2) # B, N_k, C + value = value.permute(1, 0, 2) # B, N_k, C + B, N_q, C = query.shape + B, N_k, C = key.shape + + q = ( + self.q(query) + .reshape(B, N_q, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q = q.flatten(0, 1) + k = ( + self.k(key) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = k.flatten(0, 1) + v = ( + self.v(value) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = v.flatten(0, 1) + + if self.auto_sparsity: + assert attn_mask is None + x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity) + else: + x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop) + x = x.reshape(B, self.num_heads, N_q, C // self.num_heads) + + x = x.transpose(1, 2).reshape(B, N_q, C) + x = self.proj(x) + x = self.proj_drop(x) + if not self.batch_first: + x = x.permute(1, 0, 2) + return x, None + + +def scaled_query_key_softmax(q, k, att_mask): + from xformers.ops import masked_matmul + q = q / (k.size(-1)) ** 0.5 + att = masked_matmul(q, k.transpose(-2, -1), att_mask) + att = torch.nn.functional.softmax(att, -1) + return att + + +def scaled_dot_product_attention(q, k, v, att_mask, dropout): + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + att = dropout(att) + y = att @ v + return y + + +def _compute_buckets(x, R): + qq = torch.einsum('btf,bfhi->bhti', x, R) + qq = torch.cat([qq, -qq], dim=-1) + buckets = qq.argmax(dim=-1) + + return buckets.permute(0, 2, 1).byte().contiguous() + + +def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None): + # assert False, "The code for the custom sparse kernel is not ready for release yet." 
+ from xformers.ops import find_locations, sparse_memory_efficient_attention + n_hashes = 32 + proj_size = 4 + query, key, value = [x.contiguous() for x in [query, key, value]] + with torch.no_grad(): + R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device) + bucket_query = _compute_buckets(query, R) + bucket_key = _compute_buckets(key, R) + row_offsets, column_indices = find_locations( + bucket_query, bucket_key, sparsity, infer_sparsity) + return sparse_memory_efficient_attention( + query, key, value, row_offsets, column_indices, attn_bias) diff --git a/demucs/utils.py b/demucs/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c482c102257a57b7fa40302b8a7c3d8d1962f1 --- /dev/null +++ b/demucs/utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from concurrent.futures import CancelledError +from contextlib import contextmanager +import math +import os +import tempfile +import typing as tp + +import torch +from torch.nn import functional as F +from torch.utils.data import Subset + + +def unfold(a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, 'data should be contiguous' + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): + """ + Center trim `tensor` with respect to `reference`, along the last dimension. + `reference` can also be a number, representing the length to trim to. + If the size difference != 0 mod 2, the extra sample is removed on the right side. + """ + ref_size: int + if isinstance(reference, torch.Tensor): + ref_size = reference.size(-1) + else: + ref_size = reference + delta = tensor.size(-1) - ref_size + if delta < 0: + raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") + if delta: + tensor = tensor[..., delta // 2:-(delta - delta // 2)] + return tensor + + +def pull_metric(history: tp.List[dict], name: str): + out = [] + for metrics in history: + metric = metrics + for part in name.split("."): + metric = metric[part] + out.append(metric) + return out + + +def EMA(beta: float = 1): + """ + Exponential Moving Average callback. + Returns a single function that can be called to repeatidly update the EMA + with a dict of metrics. The callback will return + the new averaged dict of metrics. + + Note that for `beta=1`, this is just plain averaging. 
+ """ + fix: tp.Dict[str, float] = defaultdict(float) + total: tp.Dict[str, float] = defaultdict(float) + + def _update(metrics: dict, weight: float = 1) -> dict: + nonlocal total, fix + for key, value in metrics.items(): + total[key] = total[key] * beta + weight * float(value) + fix[key] = fix[key] * beta + weight + return {key: tot / fix[key] for key, tot in total.items()} + return _update + + +def sizeof_fmt(num: float, suffix: str = 'B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +def random_subset(dataset, max_samples: int, seed: int = 42): + if max_samples >= len(dataset): + return dataset + + generator = torch.Generator().manual_seed(seed) + perm = torch.randperm(len(dataset), generator=generator) + return Subset(dataset, perm[:max_samples].tolist()) + + +class DummyPoolExecutor: + class DummyResult: + def __init__(self, func, _dict, *args, **kwargs): + self.func = func + self._dict = _dict + self.args = args + self.kwargs = kwargs + + def result(self): + if self._dict["run"]: + return self.func(*self.args, **self.kwargs) + else: + raise CancelledError() + + def __init__(self, workers=0): + self._dict = {"run": True} + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs) + + def shutdown(self, *_, **__): + self._dict["run"] = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return diff --git a/demucs/wav.py b/demucs/wav.py new file mode 100644 index 0000000000000000000000000000000000000000..977a6cf3fc49b2e25feb3c5bf1be1bb4f61eb0c1 --- /dev/null +++ b/demucs/wav.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading wav based datasets, including MusdbHQ.""" + +from collections import OrderedDict +import hashlib +import math +import json +import os +from pathlib import Path +import tqdm + +import musdb +import julius +import torch as th +from torch import distributed +import torchaudio as ta +from torch.nn import functional as F + +from .audio import convert_audio_channels +from . 
import distrib + +MIXTURE = "mixture" +EXT = ".wav" + + +def _track_metadata(track, sources, normalize=True, ext=EXT): + track_length = None + track_samplerate = None + mean = 0 + std = 1 + for source in sources + [MIXTURE]: + file = track / f"{source}{ext}" + if source == MIXTURE and not file.exists(): + audio = 0 + for sub_source in sources: + sub_file = track / f"{sub_source}{ext}" + sub_audio, sr = ta.load(sub_file) + audio += sub_audio + would_clip = audio.abs().max() >= 1 + if would_clip: + assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile' + ta.save(file, audio, sr, encoding='PCM_F') + + try: + info = ta.info(str(file)) + except RuntimeError: + print(file) + raise + length = info.num_frames + if track_length is None: + track_length = length + track_samplerate = info.sample_rate + elif track_length != length: + raise ValueError( + f"Invalid length for file {file}: " + f"expecting {track_length} but got {length}.") + elif info.sample_rate != track_samplerate: + raise ValueError( + f"Invalid sample rate for file {file}: " + f"expecting {track_samplerate} but got {info.sample_rate}.") + if source == MIXTURE and normalize: + try: + wav, _ = ta.load(str(file)) + except RuntimeError: + print(file) + raise + wav = wav.mean(0) + mean = wav.mean().item() + std = wav.std().item() + + return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate} + + +def build_metadata(path, sources, normalize=True, ext=EXT): + """ + Build the metadata for `Wavset`. + + Args: + path (str or Path): path to dataset. + sources (list[str]): list of sources to look for. + normalize (bool): if True, loads full track and store normalization + values based on the mixture file. + ext (str): extension of audio files (default is .wav). + """ + + meta = {} + path = Path(path) + pendings = [] + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(8) as pool: + for root, folders, files in os.walk(path, followlinks=True): + root = Path(root) + if root.name.startswith('.') or folders or root == path: + continue + name = str(root.relative_to(path)) + pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext))) + # meta[name] = _track_metadata(root, sources, normalize, ext) + for name, pending in tqdm.tqdm(pendings, ncols=120): + meta[name] = pending.result() + return meta + + +class Wavset: + def __init__( + self, + root, metadata, sources, + segment=None, shift=None, normalize=True, + samplerate=44100, channels=2, ext=EXT): + """ + Waveset (or mp3 set for that matter). Can be used to train + with arbitrary sources. Each track should be one folder inside of `path`. + The folder should contain files named `{source}.{ext}`. + + Args: + root (Path or str): root folder for the dataset. + metadata (dict): output from `build_metadata`. + sources (list[str]): list of source names. + segment (None or float): segment length in seconds. If `None`, returns entire tracks. + shift (None or float): stride in seconds bewteen samples. + normalize (bool): normalizes input audio, **based on the metadata content**, + i.e. the entire track is normalized, not individual extracts. + samplerate (int): target sample rate. if the file sample rate + is different, it will be resampled on the fly. + channels (int): target nb of channels. if different, will be + changed onthe fly. + ext (str): extension for audio files (default is .wav). + + samplerate and channels are converted on the fly. 
+ """ + self.root = Path(root) + self.metadata = OrderedDict(metadata) + self.segment = segment + self.shift = shift or segment + self.normalize = normalize + self.sources = sources + self.channels = channels + self.samplerate = samplerate + self.ext = ext + self.num_examples = [] + for name, meta in self.metadata.items(): + track_duration = meta['length'] / meta['samplerate'] + if segment is None or track_duration < segment: + examples = 1 + else: + examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) + self.num_examples.append(examples) + + def __len__(self): + return sum(self.num_examples) + + def get_file(self, name, source): + return self.root / name / f"{source}{self.ext}" + + def __getitem__(self, index): + for name, examples in zip(self.metadata, self.num_examples): + if index >= examples: + index -= examples + continue + meta = self.metadata[name] + num_frames = -1 + offset = 0 + if self.segment is not None: + offset = int(meta['samplerate'] * self.shift * index) + num_frames = int(math.ceil(meta['samplerate'] * self.segment)) + wavs = [] + for source in self.sources: + file = self.get_file(name, source) + wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) + wav = convert_audio_channels(wav, self.channels) + wavs.append(wav) + + example = th.stack(wavs) + example = julius.resample_frac(example, meta['samplerate'], self.samplerate) + if self.normalize: + example = (example - meta['mean']) / meta['std'] + if self.segment: + length = int(self.segment * self.samplerate) + example = example[..., :length] + example = F.pad(example, (0, length - example.shape[-1])) + return example + + +def get_wav_datasets(args, name='wav'): + """Extract the wav datasets from the XP arguments.""" + path = getattr(args, name) + sig = hashlib.sha1(str(path).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('wav_' + sig + ".json") + train_path = Path(path) / "train" + valid_path = Path(path) / "valid" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + train = build_metadata(train_path, args.sources) + valid = build_metadata(valid_path, args.sources) + json.dump([train, valid], open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + train, valid = json.load(open(metadata_file)) + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(train_path, train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set + + +def _get_musdb_valid(): + # Return musdb valid set. 
+ import yaml + setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml' + setup = yaml.safe_load(open(setup_path, 'r')) + return setup['validation_tracks'] + + +def get_musdb_wav_datasets(args): + """Extract the musdb dataset from the XP arguments.""" + sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8] + metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json") + root = Path(args.musdb) / "train" + if not metadata_file.is_file() and distrib.rank == 0: + metadata_file.parent.mkdir(exist_ok=True, parents=True) + metadata = build_metadata(root, args.sources) + json.dump(metadata, open(metadata_file, "w")) + if distrib.world_size > 1: + distributed.barrier() + metadata = json.load(open(metadata_file)) + + valid_tracks = _get_musdb_valid() + if args.train_valid: + metadata_train = metadata + else: + metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} + metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks} + if args.full_cv: + kw_cv = {} + else: + kw_cv = {'segment': args.segment, 'shift': args.shift} + train_set = Wavset(root, metadata_train, args.sources, + segment=args.segment, shift=args.shift, + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize) + valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources), + samplerate=args.samplerate, channels=args.channels, + normalize=args.normalize, **kw_cv) + return train_set, valid_set diff --git a/demucs/wdemucs.py b/demucs/wdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b0552c1d485384ba634a1b4234da36262d9369 --- /dev/null +++ b/demucs/wdemucs.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# For compat +from .hdemucs import HDemucs + +WDemucs = HDemucs diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..f60ce94e74c49fa6ec788571042fcf346023398b --- /dev/null +++ b/docs/api.md @@ -0,0 +1,204 @@ +# Demucs APIs + +## Quick start + +Notes: Type hints have been added to all API functions. It is recommended to check them before passing parameters to a function as some arguments only support limited types (e.g. parameter `repo` of method `load_model` only support type `pathlib.Path`). + +1. The first step is to import api module: + +```python +import demucs.api +``` + +2. Then initialize the `Separator`. Parameters which will be served as default values for methods can be passed. Model should be specified. + +```python +# Initialize with default parameters: +separator = demucs.api.Separator() + +# Use another model and segment: +separator = demucs.api.Separator(model="mdx_extra", segment=12) + +# You can also use other parameters defined +``` + +3. Separate it! + +```python +# Separating an audio file +origin, separated = separator.separate_audio_file("file.mp3") + +# Separating a loaded audio +origin, separated = separator.separate_tensor(audio) + +# If you encounter an error like CUDA out of memory, you can use this to change parameters like `segment`: +separator.update_parameter(segment=smaller_segment) +``` + +4. 
Save audio
+
+```python
+# Remember to create the destination folder before calling `save_audio`,
+# or you are likely to receive a `FileNotFoundError`.
+file = "file.mp3"  # the track that was separated in the previous step
+for stem, source in separated.items():
+    demucs.api.save_audio(source, f"{stem}_{file}", samplerate=separator.samplerate)
+```
+
+## API References
+
+The exact types of parameters and return values are not listed in this document. To know them, please read the type hints in api.py (most modern code editors support inferring types based on type hints).
+
+### `class Separator`
+
+The base separator class.
+
+##### Parameters
+
+model: Pretrained model name or signature. Default is htdemucs.
+
+repo: Folder containing all pre-trained models for use.
+
+segment: Length (in seconds) of each segment (only available if `split` is `True`). If not specified, will use the command line option.
+
+shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points. If not specified, will use the command line option.
+
+split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with a large memory footprint like Tasnet. If not specified, will use the command line option.
+
+overlap: The overlap between the splits. If not specified, will use the command line option.
+
+device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`. If not specified, will use the command line option.
+
+jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available. If not specified, will use the command line option.
+
+callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section.
+
+callback_arg: A dict containing private parameters to be passed to the callback function. For more information, please see the Callback section.
+
+progress: If True, show a progress bar.
+
+##### Notes for callback
+
+The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with information on the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise an exception in `callback`; you will have to handle it yourself if you want your code to keep running.
+
+Progress information contains several keys (these keys will always exist):
+- `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+- `shift_idx`: The index of shifts. Starts from 0.
+- `segment_offset`: The offset of the current segment. The value is a frame index into the tensor, not a time in seconds (441000 means the 441000th frame, not the 441000th second).
+- `state`: Could be `"start"` or `"end"`.
+- `audio_length`: Length of the audio, in frames of the tensor.
+- `models`: Count of submodels in the model.
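+
+For example, a minimal progress callback could look like the following sketch (it only relies on the keys documented above; `htdemucs` is just the documented default model name):
+
+```python
+import demucs.api
+
+
+def print_progress(info: dict):
+    # Only report once a chunk has finished processing.
+    if info["state"] == "end":
+        print(f"model {info['model_idx_in_bag'] + 1}/{info['models']}, "
+              f"shift {info['shift_idx']}, "
+              f"frame {info['segment_offset']}/{info['audio_length']}")
+
+
+separator = demucs.api.Separator(model="htdemucs", callback=print_progress)
+```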
+
+#### `property samplerate`
+
+A read-only property giving the sample rate the model requires. Will raise a warning if the model is not loaded and return the default value.
+
+#### `property audio_channels`
+
+A read-only property giving the number of audio channels the model requires. Will raise a warning if the model is not loaded and return the default value.
+
+#### `property model`
+
+A read-only property exposing the loaded model.
+
+#### `method update_parameter()`
+
+Update the parameters of separation.
+
+##### Parameters
+
+segment: Length (in seconds) of each segment (only available if `split` is `True`). If not specified, will use the command line option.
+
+shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points. If not specified, will use the command line option.
+
+split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with a large memory footprint like Tasnet. If not specified, will use the command line option.
+
+overlap: The overlap between the splits. If not specified, will use the command line option.
+
+device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`. If not specified, will use the command line option.
+
+jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available. If not specified, will use the command line option.
+
+callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section.
+
+callback_arg: A dict containing private parameters to be passed to the callback function. For more information, please see the Callback section.
+
+progress: If True, show a progress bar.
+
+##### Notes for callback
+
+The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with information on the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise an exception in `callback`; you will have to handle it yourself if you want your code to keep running.
+
+Progress information contains several keys (these keys will always exist):
+- `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+- `shift_idx`: The index of shifts. Starts from 0.
+- `segment_offset`: The offset of the current segment. The value is a frame index into the tensor, not a time in seconds (441000 means the 441000th frame, not the 441000th second).
+- `state`: Could be `"start"` or `"end"`.
+- `audio_length`: Length of the audio, in frames of the tensor.
+- `models`: Count of submodels in the model.
+
+#### `method separate_tensor()`
+
+Separate an already loaded audio tensor.
+
+##### Parameters
+
+wav: Waveform of the audio. Should have 2 dimensions: the first is the audio channels and the second is the waveform of each channel, e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
+
+sr: Sample rate of the original audio; the waveform will be resampled if it doesn't match the model.
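+
+For instance, separating a tensor loaded with torchaudio could look like the following sketch (`test.mp3` is only a placeholder file name, and torchaudio is just one way to obtain a `(channels, samples)` tensor and its sample rate):
+
+```python
+import torchaudio
+import demucs.api
+
+separator = demucs.api.Separator()       # default model
+wav, sr = torchaudio.load("test.mp3")    # placeholder path; wav has shape (channels, samples)
+origin, separated = separator.separate_tensor(wav, sr)
+for stem, source in separated.items():
+    print(stem, tuple(source.shape))
+```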
+ +##### Returns + +A tuple, whose first element is the original wave and second element is a dict, whose keys are the name of stems and values are separated waves. The original wave will have already been resampled. + +##### Notes + +Use this function with cautiousness. This function does not provide data verifying. + +#### `method separate_audio_file()` + +Separate an audio file. The method will automatically read the file. + +##### Parameters + +wav: Path of the file to be separated. + +##### Returns + +A tuple, whose first element is the original wave and second element is a dict, whose keys are the name of stems and values are separated waves. The original wave will have already been resampled. + +### `function save_audio()` + +Save audio file. + +##### Parameters + +wav: Audio to be saved + +path: The file path to be saved. Ending must be one of `.mp3` and `.wav`. + +samplerate: File sample rate. + +bitrate: If the suffix of `path` is `.mp3`, it will be used to specify the bitrate of mp3. + +clip: Clipping preventing strategy. + +bits_per_sample: If the suffix of `path` is `.wav`, it will be used to specify the bit depth of wav. + +as_float: If it is True and the suffix of `path` is `.wav`, then `bits_per_sample` will be set to 32 and will write the wave file with float format. + +##### Returns + +None + +### `function list_models()` + +List the available models. Please remember that not all the returned models can be successfully loaded. + +##### Parameters + +repo: The repo whose models are to be listed. + +##### Returns + +A dict with two keys ("single" for single models and "bag" for bag of models). The values are lists whose components are strs. \ No newline at end of file diff --git a/docs/linux.md b/docs/linux.md new file mode 100644 index 0000000000000000000000000000000000000000..7ddb8605134ace4575be70e98c918871881540b7 --- /dev/null +++ b/docs/linux.md @@ -0,0 +1,28 @@ +# Linux support for Demucs + +If your distribution has at least Python 3.8, and you just wish to separate +tracks with Demucs, not train it, you can just run + +```bash +pip3 install --user -U demucs +# Then anytime you want to use demucs, just do +python3 -m demucs -d cpu PATH_TO_AUDIO_FILE_1 +# If you have added the user specific pip bin/ folder to your path, you can also do +demucs -d cpu PATH_TO_AUDIO_FILE_1 +``` + +If Python is too old, or you want to be able to train, I recommend [installing Miniconda][miniconda], with Python 3.8 or more. + +```bash +conda activate +pip3 install -U demucs +# Then anytime you want to use demucs, first do conda activate, then +demucs -d cpu PATH_TO_AUDIO_FILE_1 +``` + +Of course, you can also use a specific env for Demucs. + +**Important, torchaudio 0.12 update:** Torchaudio no longer supports decoding mp3s without ffmpeg installed. You must have ffmpeg installed, either through Anaconda (`conda install ffmpeg -c conda-forge`) or as a distribution package (e.g. `sudo apt-get install ffmpeg`). 
+ + +[miniconda]: https://docs.conda.io/en/latest/miniconda.html#linux-installers diff --git a/docs/mac.md b/docs/mac.md new file mode 100644 index 0000000000000000000000000000000000000000..298cf3cd416f3f55fb1dc91dd997ba71ff95a2f4 --- /dev/null +++ b/docs/mac.md @@ -0,0 +1,28 @@ +# macOS support for Demucs + +If you have a sufficiently recent version of macOS, you can just run + +```bash +python3 -m pip install --user -U demucs +# Then anytime you want to use demucs, just do +python3 -m demucs -d cpu PATH_TO_AUDIO_FILE_1 +# If you have added the user specific pip bin/ folder to your path, you can also do +demucs -d cpu PATH_TO_AUDIO_FILE_1 +``` + +If you do not already have Anaconda installed or much experience with the terminal on macOS, here are some detailed instructions: + +1. Download [Anaconda 3.8 (or more recent) 64-bit for macOS][anaconda]: +2. Open [Anaconda Prompt in macOS][prompt] +3. Follow these commands: +```bash +conda activate +pip3 install -U demucs +# Then anytime you want to use demucs, first do conda activate, then +demucs -d cpu PATH_TO_AUDIO_FILE_1 +``` + +**Important, torchaudio 0.12 update:** Torchaudio no longer supports decoding mp3s without ffmpeg installed. You must have ffmpeg installed, either through Anaconda (`conda install ffmpeg -c conda-forge`) or with Homebrew for instance (`brew install ffmpeg`). + +[anaconda]: https://www.anaconda.com/download +[prompt]: https://docs.anaconda.com/anaconda/user-guide/getting-started/#open-nav-mac diff --git a/docs/mdx.md b/docs/mdx.md new file mode 100644 index 0000000000000000000000000000000000000000..25318d9375ffb5dd03eb6003a254194f9244a8e5 --- /dev/null +++ b/docs/mdx.md @@ -0,0 +1,73 @@ +# Music DemiXing challenge (MDX) + +If you want to use Demucs for the [MDX challenge](https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021), +please follow the instructions hereafter + +## Installing Demucs + +Follow the instructions from the [main README](https://github.com/facebookresearch/demucs#requirements) +in order to setup Demucs using Anaconda. You will need the full setup up for training, including soundstretch. + +## Getting MusDB-HQ + +Download [MusDB-HQ](https://zenodo.org/record/3338373) to some folder and unzip it. + +## Training Demucs + +Train Demucs (you might need to change the batch size depending on the number of GPUs available). +It seems 48 channels is enough to get the best performance on MusDB-HQ, and training will faster +and less memory demanding. In any case, the 64 channels versions is timing out on the challenge. +```bash +./run.py --channels=48 --batch_size 64 --musdb=PATH_TO_MUSDB --is_wav [EXTRA_FLAGS] +``` + +### Post training + +Once the training is completed, a new model file will be exported in `models/`. + +You can look at the SDR on the MusDB dataset using `python result_table.py`. + + +### Evaluate and export a model before training is over + +If you want to export a model before training is complete, use the following command: +```bash +python -m demucs [ALL EXACT TRAINING FLAGS] --save_model +``` +You can also pass the `--half` flag, in order to save weights in half precision. This will divide the model size by 2 and won't impact SDR. + +Once this is done, you can partially evaluate a model with +```bash +./run.py --test NAME_OF_MODEL.th --musdb=PATH_TO_MUSDB --is_wav +``` + +**Note:** `NAME_OF_MODEL.th` is given relative to the models folder (given by `--models`, defaults to `models/`), so don't include it in the name. 
+ + +### Training smaller models + +If you want to quickly test idea, I would recommend training a 16 kHz model, and testing if things work there or not, before training the full 44kHz model. You can train one of those with +```bash +./run.py --channels=32 --samplerate 16000 --samples 160000 --data_stride 16000 --depth=5 --batch_size 64 --repitch=0 --musdb=PATH_TO_MUSDB --is_wav [EXTRA_FLAGS] +``` +(repitch must be turned off, because things will break at 16kHz). + +## Submitting your model + +1. Git clone [the Music Demixing Challenge - Starter Kit - Demucs Edition](https://github.com/adefossez/music-demixing-challenge-starter-kit). +2. Inside the starter kit, create a `models/` folder and copy over the trained model from the Demucs repo (renaming +it for instance `my_model.th`) +3. Inside the `test_demuc.py` file, change the function `prediction_setup`: comment the loading +of the pre-trained model, and uncomment the code to load your own model. +4. Edit the file `aicrowd.json` with your username. +5. Install [git-lfs](https://git-lfs.github.com/). Then run + +```bash +git lfs install +git add models/ +git add -u . +git commit -m "My Demucs submission" +``` +6. Follow the [submission instructions](https://github.com/AIcrowd/music-demixing-challenge-starter-kit/blob/master/docs/SUBMISSION.md). + +Best of luck 🤞 diff --git a/docs/release.md b/docs/release.md new file mode 100644 index 0000000000000000000000000000000000000000..f322d005904ce97bf13d8d4564c4a89b08d134ba --- /dev/null +++ b/docs/release.md @@ -0,0 +1,112 @@ +# Release notes for Demucs + +## V4.1.0a1, TBD + +Get models list + +Check segment of HTDemucs inside BagOfModels + +Added api.py to be called from another program + +Use api in separate.py + +Added `--other-method`: method to get `no_{STEM}`, add up all the other stems (add), original track substract the specific stem (minus), and discard (none) + +Added type `HTDemucs` to type alias `AnyModel`. + +## V4.0.1, 8th of September 2023 + +**From this version, Python 3.7 is no longer supported. This is not a problem since the latest PyTorch 2.0.0 no longer support it either.** + +Various improvements by @CarlGao4. Support for `segment` param inside of HTDemucs +model. + +Made diffq an optional dependency, with an error message if not installed. + +Added output format flac (Free Lossless Audio Codec) + +Will use CPU for complex numbers, when using MPS device (all other computations are performed by mps). + +Optimize codes to save memory + +Allow changing preset of MP3 + +## V4.0.0, 7th of December 2022 + +Adding hybrid transformer Demucs model. + +Added support for [Torchaudio implementation of HDemucs](https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html), thanks @skim0514. + +Added experimental 6 sources model `htdemucs_6s` (`drums`, `bass`, `other`, `vocals`, `piano`, `guitar`). + +## V3.0.6, 16th of November 2022 + +Option to customize output path of stems (@CarlGao4) + +Fixed bug in pad1d leading to failure sometimes. + +## V3.0.5, 17th of August 2022 + +Added `--segment` flag to customize the segment length and use less memory (thanks @CarlGao4). + +Fix reflect padding bug on small inputs. + +Compatible with pyTorch 1.12 + +## V3.0.4, 24th of February 2022 + +Added option to split into two stems (i.e. vocals, vs. non vocals), thanks to @CarlGao4. + +Added `--float32`, `--int24` and `--clip-mode` options to customize how output stems are saved. + +## V3.0.3, 2nd of December 2021 + +Fix bug in weights used for different sources. 
Thanks @keunwoochoi for the report and fix. + +Improving drastically memory usage on GPU for long files. Thanks a lot @famzah for providing this. + +Adding multithread evaluation on CPU (`-j` option). + +(v3.0.2 had a bug with the CPU pool and is skipped.) + +## V3.0.1, 12th of November 2021 + +Release of Demucs v3, featuring hybrid domain separation and much more. +This drops support for Conv-Tasnet and training on the non HQ MusDB dataset. +There is no version 3.0.0 because I messed up. + +## V2.0.2, 26th of May 2021 + +- Fix in Tasnet (PR #178) +- Use ffmpeg in priority when available instead of torchaudio to avoid small shift in MP3 data. +- other minor fixes + +## v2.0.1, 11th of May 2021 + +MusDB HQ support added. Custom wav dataset support added. +Minor changes: issue with padding of mp3 and torchaudio reading, in order to limit that, +Demucs now uses ffmpeg in priority and fallback to torchaudio. +Replaced pre-trained demucs model with one trained on more recent codebase. + +## v2.0.0, 28th of April 2021 + +This is a big release, with at lof of breaking changes. You will likely +need to install Demucs from scratch. + + + +- Demucs now supports on the fly resampling by a factor of 2. +This improves SDR almost 0.3 points. +- Random scaling of each source added (From Uhlich et al. 2017). +- Random pitch and tempo augmentation addded, from [Cohen-Hadria et al. 2019]. +- With extra augmentation, the best performing Demucs model now has only 64 channels +instead of 100, so model size goes from 2.4GB to 1GB. Also SDR is up from 5.6 SDR to 6.3 when trained only on MusDB. +- Quantized model using [DiffQ](https://github.com/facebookresearch/diffq) has been added. Model size is 150MB, no loss in quality as far as I, or the metrics, +can say. +- Pretrained models are now using the TorchHub interface. +- Overlap mode for separation, to limit inconsitencies at + frame boundaries, with linear transition over the overlap. Overlap is currently + at 25%. Not that this is only done for separation, not training, because + I added that quite late to the code. For Conv-TasNet this can improve + SDR quite a bit (+0.3 points, to 6.0). +- PyPI hosting, for separation, not training! diff --git a/docs/sdx23.md b/docs/sdx23.md new file mode 100644 index 0000000000000000000000000000000000000000..2df010650c4869014d246dc77793cc7164e952ca --- /dev/null +++ b/docs/sdx23.md @@ -0,0 +1,61 @@ +# SDX 23 challenge + +Checkout [the challenge page](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023) +for more information. This page is specifically on training models for the [MDX'23 sub-challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23). +There are two tracks: one trained on a dataset with bleeding, and the other with label mixups. + +This gives instructions on training an Hybrid Demucs model on those datasets. +I haven't tried the HT Demucs model, as it typically requires quite a bit of training data but the same could be done with it. + +You will need to work from an up to date clone of this repo. See the [generic training instructions](./training.md) for more information. + +## Getting the data + +Register on the challenge, then checkout the [Resources page](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/dataset_files) and download the dataset you are +interested in. + +Update the `conf/dset/sdx23_bleeding.yaml` and `conf/dset/sdx23_labelnoise.yaml` files to point to the right path. 
+ +**Make sure soundfile** is installed (`conda install -c conda-forge libsndfile; pip install soundfile`). + +### Create proper train / valid structure + +Demucs requires a valid set to work properly. Go to the folder where you extracted the tracks then do: + +```shell +mkdir train +mv * train # there will be a warning saying cannot move train to itself but that's fine the other tracks should have. +mkdir valid +cd train +mv 5640831d-7853-4d06-8166-988e2844b652 bc964128-da16-4e4c-af95-4d1211e78c70 \ + cc7f7675-d3c8-4a49-a2d7-a8959b694004 f40ffd10-4e8b-41e6-bd8a-971929ca9138 \ + bc1f2967-f834-43bd-aadc-95afc897cfe7 cc3e4991-6cce-40fe-a917-81a4fbb92ea6 \ + ed90a89a-bf22-444d-af3d-d9ac3896ebd2 f4b735de-14b1-4091-a9ba-c8b30c0740a7 ../valid +``` + +## Training + +See `dora grid sdx23` for a starting point. You can do `dora grid sdx23 --init --dry_run` then `dora run -f SIG -d` with `SIG` one of the signature +to train on a machine with GPUs if you do not have a SLURM cluster. + +Keep in mind that the valid tracks and train tracks are corrupted in different ways for those tasks, so do not expect +the valid loss to go down as smoothly as with normal training on the clean MusDB. + +I only trained Hybrid Demucs baselines as Hybrid Transformer typically requires more data. + + +## Exporting models + +Run +``` +python -m tools.export SIG +``` + +This will export the trained model into the `release_models` folder. + +## Submitting a model + +Clone the [Demucs Starter Kit for SDX23](https://github.com/adefossez/sdx23). Follow the instructions there. + +You will to copy the models under `release_models` in the `sdx23/models/` folder before you can use them. +Make sure you have git-lfs properly installed and setup before adding those files to your fork of `sdx23`. diff --git a/docs/training.md b/docs/training.md new file mode 100644 index 0000000000000000000000000000000000000000..608218c15a625573af61fc2f4210d52b35a60150 --- /dev/null +++ b/docs/training.md @@ -0,0 +1,290 @@ +# Training (Hybrid) Demucs + +## Install all the dependencies + +You should install all the dependencies either with either Anaconda (using the env file `environment-cuda.yml` ) +or `pip`, with `requirements.txt`. + +## Datasets + +### MusDB HQ + +Note that we do not support MusDB non HQ training anymore. +Get the [Musdb HQ](https://zenodo.org/record/3338373) dataset, and update the path to it in two places: +- The `dset.musdb` key inside `conf/config.yaml`. +- The variable `MUSDB_PATH` inside `tools/automix.py`. + +### Create the fine tuning datasets + +**This is only for the MDX 2021 competition models** + +I use a fine tuning on a dataset crafted by remixing songs in a musically plausible way. +The automix script will make sure that BPM, first beat and pitches are aligned. +In the file `tools/automix.py`, edit `OUTPATH` to suit your setup, as well as the `MUSDB_PATH` +to point to your copy of MusDB HQ. Then run + +```bash +export NUMBA_NUM_THREADS=1; python3 -m tools.automix +``` + +**Important:** the script will show many errors, those are normals. They just indicate when two stems + do not batch due to BPM or music scale difference. + +Finally, edit the file `conf/dset/auto_mus.yaml` and replace `dset.wav` to the value of `OUTPATH`. + +If you have a custom dataset, you can also uncomment the lines `dset2 = ...` and +`dset3 = ...` to add your custom wav data and the test set of MusDB for Track B models. 
+You can then replace the paths in `conf/dset/auto_extra.yaml`, `conf/dset/auto_extra_test.yaml` +and `conf/dset/aetl.yaml` (this last one was using 10 mixes instead of 6 for each song). + +### Dataset metadata cache + +Datasets are scanned the first time they are used to determine the files and their durations. +If you change a dataset and need a rescan, just delete the `metadata` folder. + +## A short intro to Dora + +I use [Dora][dora] for all the of experiments (XPs) management. You should have a look at the Dora README +to learn about the tool. Here is a quick summary of what to know: + +- An XP is a unique set of hyper-parameters with a given signature. The signature is a hash of + those hyper-parameters. I will always refer to an XP with its signature, e.g. `9357e12e`. + We will see after that you can retrieve the hyper-params and re-rerun it in a single command. +- In fact, the hash is defined as a delta between the base config and the one obtained with + the config overrides you passed from the command line. + **This means you must never change the `conf/**.yaml` files directly.**, + except for editing things like paths. Changing the default values in the config files means + the XP signature won't reflect that change, and wrong checkpoints might be reused. + I know, this is annoying, but the reason is that otherwise, any change to the config file would + mean that all XPs ran so far would see their signature change. + +### Dora commands + +Run `tar xvf outputs.tar.gz`. This will initialize the Dora XP repository, so that Dora knows +which hyper-params match the signature like `9357e12e`. Once you have done that, you should be able +to run the following: + +```bash +dora info -f 81de367c # this will show the hyper-parameter used by a specific XP. + # Be careful some overrides might present twice, and the right most one + # will give you the right value for it. +dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c. + # `-d` is for distributed, it will use all available GPUs. +dora run -d -f 81de367c hdemucs.channels=32 # start from the config of XP 81de367c but change some hyper-params. + # This will give you a new XP with a new signature (here 3fe9c332). +``` + +An XP runs from a specific folder based on its signature, by default under the `outputs/` folder. +You can safely interrupt a training and resume it, it will reuse any existing checkpoint, as it will +reuse the same folder. +If you made some change to the code and need to ignore a previous checkpoint you can use `dora run --clear [RUN ARGS]`. + +If you have a Slurm cluster, you can also use the `dora grid` command, e.g. `dora grid mdx`. +Please refer to the [Dora documentation][dora] for more information. + +## Hyper parameters + +Have a look at [conf/config.yaml](../conf/config.yaml) for a list of all the hyper-parameters you can override. +If you are not familiar with [Hydra](https://github.com/facebookresearch/hydra), go checkout their page +to be familiar with how to provide overrides for your trainings. + + +## Model architecture + +A number of architectures are supported. You can select one with `model=NAME`, and have a look +in [conf/config.yaml'(../conf/config.yaml) for each architecture specific hyperparams. +Those specific params will be always prefixed with the architecture name when passing the override +from the command line or in grid files. Here is the list of models: + +- demucs: original time-only Demucs. +- hdemucs: Hybrid Demucs (v3). 
- torch_hdemucs: same as Hybrid Demucs, but using the [torchaudio official implementation](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html).
- htdemucs: Hybrid Transformer Demucs (v4).

### Storing config in files

As mentioned earlier, you should never change the base config files. However, you can use Hydra config groups
in order to store variants you often use. If you want to create a new variant combining multiple hyper-params,
copy the file `conf/variant/example.yaml` to `conf/variant/my_variant.yaml`, and then you can use it with

```bash
dora run -d variant=my_variant
```

Once you have started training models with this variant file, you should no longer edit it.


## Fine tuning

Once a first model is trained, you can fine-tune it with other settings (e.g. the automix dataset) with

```bash
dora run -d -f 81de367c continue_from=81de367c dset=auto_mus variant=finetune
```

Note that you need both `-f 81de367c` and `continue_from=81de367c`. The first one indicates
that the hyper-params of `81de367c` should be used as a starting point for the config.
The second indicates that the weights from `81de367c` should be used as a starting point for the solver.


## Model evaluation

Your model will be evaluated automatically with the new SDR definition from MDX every 20 epochs.
Old-style SDR (which is quite slow) will only be computed at the end of training.

## Model Export


In order to use your models with other commands (such as the `demucs` command for separation) you must
export them. For that, run

```bash
python3 -m tools.export 9357e12e [OTHER SIGS ...]  # replace with the appropriate signatures.
```

The models will be stored under `release_models/`. You can use them with the `demucs` separation command with the following flags:
```bash
demucs --repo ./release_models -n 9357e12e my_track.mp3
```

### Bag of models

If you want to combine multiple models, potentially with different weights for each source, you can copy
`demucs/remote/mdx.yaml` to `./release_models/my_bag.yaml`. You can then edit the list of models (all models used should have been exported first) and the weights per source and model (a list of lists: the outer list is over models, the inner list over sources). You can then use your bag of models as

```bash
demucs --repo ./release_models -n my_bag my_track.mp3
```

## Evaluating pre-trained models

You can evaluate any pre-trained model or bag of models using the following command:
```bash
python3 -m tools.test_pretrained -n NAME_OF_MODEL [EXTRA ARGS]
```
where `NAME_OF_MODEL` is either the name of the bag (e.g. `mdx`, `repro_mdx_a`),
or a single Dora signature of one of the models in the bag. You can pass `EXTRA ARGS` to customize
the test options, like the number of random shifts (e.g. `test.shifts=2`). This will compute the old-style
SDR and can take quite a bit of time.

For custom models that were trained locally, you will need to indicate that you wish
to use the local model repositories, with the `--repo ./release_models` flag, e.g.,
```bash
python3 -m tools.test_pretrained --repo ./release_models -n my_bag
```


## API to retrieve the model

You can retrieve officially released models in Python using the following API:
```python
from demucs import pretrained
from demucs.apply import apply_model
bag = pretrained.get_model('htdemucs')  # for a bag of models or a named model
                                        # (which is just a bag with 1 model).
model = pretrained.get_model('955717e8')  # using the signature for single models.

bag.models                       # list of individual models
stems = apply_model(model, mix)  # apply the model to the given mix.
```
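
Below is a slightly more complete, hedged sketch of a full separation using only this API: it loads a file, applies the model and writes one file per stem. It relies on `demucs.audio.save_audio` (used elsewhere in this repository) and on standard `torchaudio` calls; `my_track.wav` is just a placeholder, and option names such as `split` and `overlap` are reasonable choices rather than requirements. The reference implementation remains `demucs/separate.py`.

```python
import torch
import torchaudio

from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import save_audio

model = pretrained.get_model('htdemucs')  # named model (a bag with a single model)
model.eval()

wav, sr = torchaudio.load('my_track.wav')                       # (channels, time)
wav = torchaudio.functional.resample(wav, sr, model.samplerate)
if wav.shape[0] == 1:                                           # mono -> fake stereo
    wav = wav.repeat(2, 1)

with torch.no_grad():
    # apply_model expects a batch dimension and returns (batch, sources, channels, time).
    sources = apply_model(model, wav[None], device='cpu', split=True, overlap=0.25)[0]

for name, stem in zip(model.sources, sources):
    save_audio(stem, f'{name}.wav', model.samplerate)
```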

## Model Zoo

### Hybrid Transformer Demucs

The configurations for the Hybrid Transformer models are available with:

```shell
dora grid mmi --dry_run --init
dora grid mmi_ft --dry_run --init  # fine-tuned on each source.
```

We release in particular `955717e8`, a Hybrid Transformer Demucs using 5 layers, 512 channels and a 10 second training segment length. We also release its fine-tuned version, with one model
for each source: `f7e0c4bc`, `d12395a8`, `92cfc3b6`, `04573f0d` (drums, bass, other, vocals).
The model `955717e8` is also named `htdemucs`, while the bag of fine-tuned models is provided
as `htdemucs_ft`.

We also release `75fc33f5`, a regular Hybrid Demucs trained on the same dataset,
available as `hdemucs_mmi`.


### Models from the MDX Competition 2021


Here is a short description of the models used for the MDX submission, either Track A (MusDB HQ only)
or Track B (extra training data allowed). Training happens in two stages, the second stage
being the fine-tuning on the automix-generated dataset.
All the fine-tuned models are available on our AWS repository
(you can retrieve them with `demucs.pretrained.get_model(SIG)`). The bags of models are available
by doing `demucs.pretrained.get_model(NAME)` with `NAME` being either `mdx` (for Track A) or `mdx_extra`
(for Track B).

#### Track A

The 4 models are:

- `0d19c1c6`: fine-tuned on the automix dataset from `9357e12e`
- `7ecf8ec1`: fine-tuned on the automix dataset from `e312f349`
- `c511e2ab`: fine-tuned on the automix dataset from `81de367c`
- `7d865c68`: fine-tuned on the automix dataset from `80a68df8`

The 4 initial models (before fine-tuning) are:

- `9357e12e`: 64ch time-domain-only improved Demucs, with new residual branches, group norm,
  and singular value penalty.
- `e312f349`: 64ch time-domain-only improved Demucs, with new residual branches, group norm,
  and singular value penalty, trained with a loss that focuses only on drums and bass.
- `81de367c`: 48ch hybrid model, with residual branches, group norm,
  singular value penalty and amplitude spectrogram.
- `80a68df8`: same as `b5559babb` but using CaC and a different
  random seed, as well as different weights per frequency band in the outermost layers.

The hybrid models are combined with equal weights for all sources except for the bass.
`0d19c1c6` (time domain) is used for both drums and bass. `7ecf8ec1` is used only for the bass.

You can see all the hyper-parameters at once (one common line for all shared hyper-parameters, then only
the hyper-parameters that differ), along with the DiffQ variants used for the `mdx_q` models, with:
```
dora grid mdx --dry_run --init
```

#### Track B

- `e51eebcc`
- `a1d90b5c`
- `5d2d6c55`
- `cfa93e08`

All the models are 48ch hybrid Demucs with different random seeds. Two of them
use CaC, and two use amplitude spectrograms with masking.
All the models are combined with equal weights for all sources.

Things are a bit messy for Track B: there was a lot of fine-tuning
over different datasets. I won't describe the entire genealogy of models here,
but all the information can be accessed with the `dora info -f SIG` command.

Similarly, you can run the following (it will contain a few extra lines, for training without the MusDB test set as training data, and extra DiffQ XPs):
```
dora grid mdx_extra --dry_run --init
```

### Reproducibility and Ablation

I updated the paper to report numbers with a more homogeneous setup than the one used for the competition.
On MusDB HQ, I still need to use a combination of time-only and hybrid models to achieve the best performance.
The experiments are provided in the grids [repro.py](../demucs/grids/repro.py) and
[repro_ft.py](../demucs/grids/repro_ft.py) for the fine-tuning on the realistic mix datasets.

The new bag of models reaches an SDR of 7.64 (vs. 7.68 for the original Track A model). It uses
2 time-only models trained with residual branches, local attention and the SVD penalty,
along with 2 hybrid models with the same features and using the CaC representation.
We average the performance of all the models with the same weight over all sources, unlike
what was done for the original Track A model. We trained for 600 epochs, against 360 before.

The new bag of models is available as part of the pretrained models as `repro_mdx_a`.
The time-only bag is named `repro_mdx_a_time_only`, and the hybrid-only one `repro_mdx_a_hybrid_only`.
Check out the paper for more information on the training.

[dora]: https://github.com/facebookresearch/dora

diff --git a/docs/windows.md b/docs/windows.md new file mode 100644 index 0000000000000000000000000000000000000000..1e57d7c60d5681dd7ba729c1c1be6bbcd01dcd22 --- /dev/null +++ b/docs/windows.md @@ -0,0 +1,67 @@

# Windows support for Demucs

## Installation and usage

If you don't have much experience with Anaconda, Python or the shell, here are more detailed instructions. Note that **Demucs is not supported on 32-bit systems** (as PyTorch is not available there).

- First install Anaconda with **Python 3.8** or more recent, which you can find [here][install].
- Start the [Anaconda prompt][prompt].

Then, all commands that follow must be run from this prompt.

> **I have no coding experience and these are too difficult for me.**
>
> Then a GUI is suitable for you. See [Demucs GUI](https://github.com/CarlGao4/Demucs-Gui).

### If you want to use your GPU

If you have a graphics card produced by NVIDIA with more than 2GiB of memory, you can separate tracks with GPU acceleration. To achieve this, you must install PyTorch with CUDA. If PyTorch was already installed (for instance because you already installed Demucs), first run `python.exe -m pip uninstall torch torchaudio`.
Then visit the [PyTorch Home Page](https://pytorch.org/get-started/locally/) and follow the guide there to install with CUDA support. Please make sure that the installed torchaudio version is no greater than 2.1 (the latest version at the time of writing; 2.2.0 is known to be unsupported).

### Installation

Start the Anaconda prompt, and run the following:

```cmd
conda install -c conda-forge ffmpeg
python.exe -m pip install -U demucs SoundFile
```

### Upgrade

To upgrade Demucs, simply run `python.exe -m pip install -U demucs` from the Anaconda prompt.

### Usage

Then to use Demucs, just start the **Anaconda prompt** and run:
```
demucs -d cpu "PATH_TO_AUDIO_FILE_1" ["PATH_TO_AUDIO_FILE_2" ...]
```
The `"` around the filenames are required if the paths contain spaces. A simple way to input these paths is to drag a file from a folder into the terminal.

To locate the separated files, you can run this command and open the folders:
```
explorer separated
```

### Separating an entire folder

You can use the following command to separate an entire folder of mp3s for instance (replace the extension `.mp3` if need be for other file types):
```
cd FOLDER
for %i in (*.mp3) do (demucs -d cpu "%i")
```

## Potential errors

If you have an error saying that `mkl_intel_thread.dll` cannot be found, you can try running
`conda install -c defaults intel-openmp -f` first, then run the `demucs` command again. If it still doesn't work, try running `set CONDA_DLL_SEARCH_MODIFICATION_ENABLE=1` first, then the `demucs` command again, and hopefully it will work 🙏.

**If you get a permission error**, please try starting the Anaconda Prompt as administrator.
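
If you followed the GPU section above but separation still seems to run on the CPU or errors out, it can help to check whether PyTorch actually sees your graphics card. A small diagnostic sketch (save it as e.g. `check_gpu.py`, a name chosen here only for illustration, then run `python.exe check_gpu.py` from the Anaconda prompt):

```python
# Purely a diagnostic: it only reports what PyTorch can see, it does not change anything.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("No CUDA device visible; reinstall PyTorch with CUDA support (see the GPU section above).")
```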
+ + +[install]: https://www.anaconda.com/download +[prompt]: https://docs.anaconda.com/anaconda/user-guide/getting-started/#open-prompt-win diff --git a/environment-cpu.yml b/environment-cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..bbf39284a620ffb2ba71530b17b3551897f41cbe --- /dev/null +++ b/environment-cpu.yml @@ -0,0 +1,28 @@ +name: demucs + +channels: + - pytorch + - conda-forge + +dependencies: + - python>=3.8,<3.10 + - ffmpeg>=4.2 + - pytorch>=1.8.1 + - torchaudio>=0.8 + - tqdm>=4.36 + - pip + - pip: + - diffq>=0.2 + - dora-search + - einops + - hydra-colorlog>=1.1 + - hydra-core>=1.1 + - julius>=0.2.3 + - lameenc>=1.2 + - openunmix + - musdb>=0.4.0 + - museval>=0.4.0 + - soundfile + - submitit + - treetable>=0.2.3 + diff --git a/environment-cuda.yml b/environment-cuda.yml new file mode 100644 index 0000000000000000000000000000000000000000..585616c967d1bf8bbb3a0f5d966b9b9325e4af67 --- /dev/null +++ b/environment-cuda.yml @@ -0,0 +1,28 @@ +name: demucs + +channels: + - pytorch + - conda-forge + +dependencies: + - python>=3.8,<3.10 + - ffmpeg>=4.2 + - pytorch>=1.8.1 + - torchaudio>=0.8 + - cudatoolkit>=10 + - tqdm>=4.36 + - pip + - pip: + - diffq>=0.2 + - dora-search + - einops + - hydra-colorlog>=1.1 + - hydra-core>=1.1 + - julius>=0.2.3 + - lameenc>=1.2 + - openunmix + - musdb>=0.4.0 + - museval>=0.4.0 + - soundfile + - submitit + - treetable>=0.2.3 diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0308ab5df7d08735dbed58ce53ebd838f51653ab --- /dev/null +++ b/hubconf.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +dependencies = ['dora-search', 'julius', 'lameenc', 'openunmix', 'pyyaml', + 'torch', 'torchaudio', 'tqdm'] + +from demucs.pretrained import get_model + diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..04cd6c892e32f8a14a9d3c67cc08d90885d4ebd3 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] + +[mypy-treetable,torchaudio.*,diffq,yaml,tqdm,lameenc,musdb,museval,openunmix.*,einops,xformers.*] +ignore_missing_imports = True + diff --git a/outputs.tar.gz b/outputs.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..dd14a728b5b61d38fab306b8cb2a403025066bac --- /dev/null +++ b/outputs.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:939b6d50c7d52e72a4c3ad82ae267fe0d4bafcfc00937837a08554936f9c9207 +size 1885 diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d7c6b72a516c4e1f22a9a23bcba70bc3c39eff --- /dev/null +++ b/predict.py @@ -0,0 +1,53 @@ +# predict.py +import torchaudio +import demucs.separate +from fastapi import FastAPI, UploadFile, HTTPException, status +from fastapi.responses import FileResponse +import shutil +import os +import uuid +import tempfile +import logging + +app = FastAPI() +logging.basicConfig(level=logging.INFO) + +STEMS = ["vocals", "drums", "bass", "other"] + +@app.post("/predict") +async def predict(audio: UploadFile): + # Validate file type + if not audio.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported file type.") + + # Use a unique temp directory for each request + with tempfile.TemporaryDirectory() as tmpdir: + audio_path = os.path.join(tmpdir, f"{uuid.uuid4()}_{audio.filename}") + try: + # Save uploaded file + with open(audio_path, "wb") as f: + shutil.copyfileobj(audio.file, f) + + # Run Demucs separation + output_dir = os.path.join(tmpdir, "separated") + os.makedirs(output_dir, exist_ok=True) + demucs.separate.main(["--mp3", "-n", "htdemucs", "-d", "cpu", "-o", output_dir, audio_path]) + + # Find output stems + base = os.path.splitext(os.path.basename(audio_path))[0] + stem_files = {} + for stem in STEMS: + path = os.path.join(output_dir, "htdemucs", base, f"{stem}.mp3") + if not os.path.exists(path): + raise HTTPException(status_code=500, detail=f"Stem {stem} not found.") + stem_files[stem] = path + + # Optionally, return as downloadable files (example: vocals only) + # return FileResponse(stem_files["vocals"], media_type="audio/mpeg", filename=f"{base}_vocals.mp3") + + # Or return all stems as file paths (for demo; in prod, upload to S3/CDN and return URLs) + return {"stems": stem_files} + + except Exception as e: + logging.exception("Error during separation") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..63f22f2284a012d481916031c6da450ea14d24a8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +# please make sure you have already a pytorch install that is cuda enabled! 
+dora-search>=0.1.12 +diffq>=0.2.1 +einops +flake8 +hydra-colorlog>=1.1 +hydra-core>=1.1 +julius>=0.2.3 +lameenc>=1.2 +museval +mypy +openunmix +pyyaml +submitit +torch>=1.8.1 +torchaudio>=0.8,<2.1 +tqdm +treetable +soundfile>=0.10.3;sys_platform=="win32" diff --git a/requirements_minimal.txt b/requirements_minimal.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0cb9e3278e835a4a21b1479b32ab22370536a19 --- /dev/null +++ b/requirements_minimal.txt @@ -0,0 +1,10 @@ +# please make sure you have already a pytorch install that is cuda enabled! +dora-search +einops +julius>=0.2.3 +lameenc>=1.2 +openunmix +pyyaml +torch>=1.8.1 +torchaudio>=0.8,<2.1 +tqdm diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..3429c66349b98f6cfc21c180d3d9a8893c53cd49 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,8 @@ +[pep8] +max-line-length = 100 + +[flake8] +max-line-length = 100 + +[yapf] +column_limit = 100 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7d5bae40f0f935a2e0628b55c722c060c0baf03d --- /dev/null +++ b/setup.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez +# Inspired from https://github.com/kennethreitz/setup.py + +from pathlib import Path + +from setuptools import setup + + +NAME = 'demucs' +DESCRIPTION = 'Music source separation in the waveform domain.' + +URL = 'https://github.com/facebookresearch/demucs' +EMAIL = 'defossez@fb.com' +AUTHOR = 'Alexandre Défossez' +REQUIRES_PYTHON = '>=3.8.0' + +HERE = Path(__file__).parent + +# Get version without explicitely loading the module. 
+for line in open('demucs/__init__.py'): + line = line.strip() + if '__version__' in line: + context = {} + exec(line, context) + VERSION = context['__version__'] + + +def load_requirements(name): + required = [i.strip() for i in open(HERE / name)] + required = [i for i in required if not i.startswith('#')] + return required + + +REQUIRED = load_requirements('requirements_minimal.txt') +ALL_REQUIRED = load_requirements('requirements.txt') + +try: + with open(HERE / "README.md", encoding='utf-8') as f: + long_description = '\n' + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=long_description, + long_description_content_type='text/markdown', + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=['demucs'], + extras_require={ + 'dev': ALL_REQUIRED, + }, + install_requires=REQUIRED, + include_package_data=True, + entry_points={ + 'console_scripts': ['demucs=demucs.separate:main'], + }, + license='MIT License', + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + 'License :: OSI Approved :: MIT License', + 'Topic :: Multimedia :: Sound/Audio', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], +) diff --git a/test.mp3 b/test.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..719c2bccd50f2a08c1872ed1e0a6895c2c8006d6 --- /dev/null +++ b/test.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa8d8e51ca1b65e4e04f0be52c7e63b38c17aaaeded9193bc6ce0acf6d6a6d9a +size 802480 diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4196294309799347172dba54a17360698071ca8 --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tools/automix.py b/tools/automix.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c351fe68cc9b59f12692f5bf930c50a3d09653 --- /dev/null +++ b/tools/automix.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script creates realistic mixes with stems from different songs. +In particular, it will align BPM, sync up the first beat and perform pitch +shift to maximize pitches overlap. +In order to limit artifacts, only parts that can be mixed with less than 15% +tempo shift, and 3 semitones of pitch shift are mixed together. +""" +from collections import namedtuple +from concurrent.futures import ProcessPoolExecutor +import hashlib +from pathlib import Path +import random +import shutil +import tqdm +import pickle + +from librosa.beat import beat_track +from librosa.feature import chroma_cqt +import numpy as np +import torch +from torch.nn import functional as F + +from dora.utils import try_load +from demucs.audio import save_audio +from demucs.repitch import repitch +from demucs.pretrained import SOURCES +from demucs.wav import build_metadata, Wavset, _get_musdb_valid + + +MUSDB_PATH = '/checkpoint/defossez/datasets/musdbhq' +EXTRA_WAV_PATH = "/checkpoint/defossez/datasets/allstems_44" +# WARNING: OUTPATH will be completely erased. 
+OUTPATH = Path.home() / 'tmp/demucs_mdx/automix_musdb/' +CACHE = Path.home() / 'tmp/automix_cache' # cache BPM and pitch information. +CHANNELS = 2 +SR = 44100 +MAX_PITCH = 3 # maximum allowable pitch shift in semi tones +MAX_TEMPO = 0.15 # maximum allowable tempo shift + + +Spec = namedtuple("Spec", "tempo onsets kr track index") + + +def rms(wav, window=10000): + """efficient rms computed for each time step over a given window.""" + half = window // 2 + window = 2 * half + 1 + wav = F.pad(wav, (half, half)) + tot = wav.pow(2).cumsum(dim=-1) + return ((tot[..., window - 1:] - tot[..., :-window + 1]) / window).sqrt() + + +def analyse_track(dset, index): + """analyse track, extract bpm and distribution of notes from the bass line.""" + track = dset[index] + mix = track.sum(0).mean(0) + ref = mix.std() + + starts = (abs(mix) >= 1e-2 * ref).float().argmax().item() + track = track[..., starts:] + + cache = CACHE / dset.sig + cache.mkdir(exist_ok=True, parents=True) + + cache_file = cache / f"{index}.pkl" + cached = None + if cache_file.exists(): + cached = try_load(cache_file) + if cached is not None: + tempo, events, hist_kr = cached + + if cached is None: + drums = track[0].mean(0) + if drums.std() > 1e-2 * ref: + tempo, events = beat_track(y=drums.numpy(), units='time', sr=SR) + else: + print("failed drums", drums.std(), ref) + return None, track + + bass = track[1].mean(0) + r = rms(bass) + peak = r.max() + mask = r >= 0.05 * peak + bass = bass[mask] + if bass.std() > 1e-2 * ref: + kr = torch.from_numpy(chroma_cqt(y=bass.numpy(), sr=SR)) + hist_kr = (kr.max(dim=0, keepdim=True)[0] == kr).float().mean(1) + else: + print("failed bass", bass.std(), ref) + return None, track + + pickle.dump([tempo, events, hist_kr], open(cache_file, 'wb')) + spec = Spec(tempo, events, hist_kr, track, index) + return spec, None + + +def best_pitch_shift(kr_a, kr_b): + """find the best pitch shift between two chroma distributions.""" + deltas = [] + for p in range(12): + deltas.append((kr_a - kr_b).abs().mean()) + kr_b = kr_b.roll(1, 0) + + ps = np.argmin(deltas) + if ps > 6: + ps = ps - 12 + return ps + + +def align_stems(stems): + """Align the first beats of the stems. + This is a naive implementation. A grid with a time definition 10ms is defined and + each beat onset is represented as a gaussian over this grid. + Then, we try each possible time shift to make two grids align the best. + We repeat for all sources. 
+ """ + sources = len(stems) + width = 5e-3 # grid of 10ms + limit = 5 + std = 2 + x = torch.arange(-limit, limit + 1, 1).float() + gauss = torch.exp(-x**2 / (2 * std**2)) + + grids = [] + for wav, onsets in stems: + le = wav.shape[-1] + dur = le / SR + grid = torch.zeros(int(le / width / SR)) + for onset in onsets: + pos = int(onset / width) + if onset >= dur - 1: + continue + if onset < 1: + continue + grid[pos - limit:pos + limit + 1] += gauss + grids.append(grid) + + shifts = [0] + for s in range(1, sources): + max_shift = int(4 / width) + dots = [] + for shift in range(-max_shift, max_shift): + other = grids[s] + ref = grids[0] + if shift >= 0: + other = other[shift:] + else: + ref = ref[shift:] + le = min(len(other), len(ref)) + dots.append((ref[:le].dot(other[:le]), int(shift * width * SR))) + + _, shift = max(dots) + shifts.append(-shift) + + outs = [] + new_zero = min(shifts) + for (wav, _), shift in zip(stems, shifts): + offset = shift - new_zero + wav = F.pad(wav, (offset, 0)) + outs.append(wav) + + le = min(x.shape[-1] for x in outs) + + outs = [w[..., :le] for w in outs] + return torch.stack(outs) + + +def find_candidate(spec_ref, catalog, pitch_match=True): + """Given reference track, this finds a track in the catalog that + is a potential match (pitch and tempo delta must be within the allowable limits). + """ + candidates = list(catalog) + random.shuffle(candidates) + + for spec in candidates: + ok = False + for scale in [1/4, 1/2, 1, 2, 4]: + tempo = spec.tempo * scale + delta_tempo = spec_ref.tempo / tempo - 1 + if abs(delta_tempo) < MAX_TEMPO: + ok = True + break + if not ok: + print(delta_tempo, spec_ref.tempo, spec.tempo, "FAILED TEMPO") + # too much of a tempo difference + continue + spec = spec._replace(tempo=tempo) + + ps = 0 + if pitch_match: + ps = best_pitch_shift(spec_ref.kr, spec.kr) + if abs(ps) > MAX_PITCH: + print("Failed pitch", ps) + # too much pitch difference + continue + return spec, delta_tempo, ps + + +def get_part(spec, source, dt, dp): + """Apply given delta of tempo and delta of pitch to a stem.""" + wav = spec.track[source] + if dt or dp: + wav = repitch(wav, dp, dt * 100, samplerate=SR, voice=source == 3) + spec = spec._replace(onsets=spec.onsets / (1 + dt)) + return wav, spec + + +def build_track(ref_index, catalog): + """Given the reference track index and a catalog of track, builds + a completely new track. One of the source at random from the ref track will + be kept and other sources will be drawn from the catalog. 
+ """ + order = list(range(len(SOURCES))) + random.shuffle(order) + + stems = [None] * len(order) + indexes = [None] * len(order) + origs = [None] * len(order) + dps = [None] * len(order) + dts = [None] * len(order) + + first = order[0] + spec_ref = catalog[ref_index] + stems[first] = (spec_ref.track[first], spec_ref.onsets) + indexes[first] = ref_index + origs[first] = spec_ref.track[first] + dps[first] = 0 + dts[first] = 0 + + pitch_match = order != 0 + + for src in order[1:]: + spec, dt, dp = find_candidate(spec_ref, catalog, pitch_match=pitch_match) + if not pitch_match: + spec_ref = spec_ref._replace(kr=spec.kr) + pitch_match = True + dps[src] = dp + dts[src] = dt + wav, spec = get_part(spec, src, dt, dp) + stems[src] = (wav, spec.onsets) + indexes[src] = spec.index + origs.append(spec.track[src]) + print("FINAL CHOICES", ref_index, indexes, dps, dts) + stems = align_stems(stems) + return stems, origs + + +def get_musdb_dataset(part='train'): + root = Path(MUSDB_PATH) / part + ext = '.wav' + metadata = build_metadata(root, SOURCES, ext=ext, normalize=False) + valid_tracks = _get_musdb_valid() + metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} + train_set = Wavset( + root, metadata_train, SOURCES, samplerate=SR, channels=CHANNELS, + normalize=False, ext=ext) + sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] + train_set.sig = sig + return train_set + + +def get_wav_dataset(): + root = Path(EXTRA_WAV_PATH) + ext = '.wav' + metadata = _build_metadata(root, SOURCES, ext=ext, normalize=False) + train_set = Wavset( + root, metadata, SOURCES, samplerate=SR, channels=CHANNELS, + normalize=False, ext=ext) + sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] + train_set.sig = sig + return train_set + + +def main(): + random.seed(4321) + if OUTPATH.exists(): + shutil.rmtree(OUTPATH) + OUTPATH.mkdir(exist_ok=True, parents=True) + (OUTPATH / 'train').mkdir(exist_ok=True, parents=True) + (OUTPATH / 'valid').mkdir(exist_ok=True, parents=True) + out = OUTPATH / 'train' + + dset = get_musdb_dataset() + # dset2 = get_wav_dataset() + # dset3 = get_musdb_dataset('test') + dset2 = None + dset3 = None + pendings = [] + copies = 6 + copies_rej = 2 + + with ProcessPoolExecutor(20) as pool: + for index in range(len(dset)): + pendings.append(pool.submit(analyse_track, dset, index)) + + if dset2: + for index in range(len(dset2)): + pendings.append(pool.submit(analyse_track, dset2, index)) + if dset3: + for index in range(len(dset3)): + pendings.append(pool.submit(analyse_track, dset3, index)) + + catalog = [] + rej = 0 + for pending in tqdm.tqdm(pendings, ncols=120): + spec, track = pending.result() + if spec is not None: + catalog.append(spec) + else: + mix = track.sum(0) + for copy in range(copies_rej): + folder = out / f'rej_{rej}_{copy}' + folder.mkdir() + save_audio(mix, folder / "mixture.wav", SR) + for stem, source in zip(track, SOURCES): + save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') + rej += 1 + + for copy in range(copies): + for index in range(len(catalog)): + track, origs = build_track(index, catalog) + mix = track.sum(0) + mx = mix.abs().max() + scale = max(1, 1.01 * mx) + mix = mix / scale + track = track / scale + folder = out / f'{copy}_{index}' + folder.mkdir() + save_audio(mix, folder / "mixture.wav", SR) + for stem, source, orig in zip(track, SOURCES, origs): + save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') + # save_audio(stem.std() * orig / (1e-6 + orig.std()), folder / f"{source}_orig.wav", + # SR, 
clip='clamp') + + +if __name__ == '__main__': + main() diff --git a/tools/bench.py b/tools/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb1ec5d70ab64cffc3afe4408d8eb1944094024 --- /dev/null +++ b/tools/bench.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +benchmarking script, useful to check for OOM, reasonable train time, +and for the MDX competion, estimate if we will match the time limit.""" +from contextlib import contextmanager +import logging +import sys +import time +import torch + +from demucs.train import get_solver, main +from demucs.apply import apply_model + +logging.basicConfig(level=logging.INFO, stream=sys.stderr) + + +class Result: + pass + + +@contextmanager +def bench(): + import gc + gc.collect() + torch.cuda.reset_max_memory_allocated() + torch.cuda.empty_cache() + result = Result() + # before = torch.cuda.memory_allocated() + before = 0 + begin = time.time() + try: + yield result + finally: + torch.cuda.synchronize() + mem = (torch.cuda.max_memory_allocated() - before) / 2 ** 20 + tim = time.time() - begin + result.mem = mem + result.tim = tim + + +xp = main.get_xp_from_sig(sys.argv[1]) +xp = main.get_xp(xp.argv + sys.argv[2:]) +with xp.enter(): + solver = get_solver(xp.cfg) + if getattr(solver.model, 'use_train_segment', False): + batch = solver.augment(next(iter(solver.loaders['train']))) + solver.model.segment = Fraction(batch.shape[-1], solver.model.samplerate) + train_segment = solver.model.segment + solver.model.eval() + model = solver.model + model.cuda() + x = torch.randn(2, xp.cfg.dset.channels, int(10 * model.samplerate), device='cuda') + with bench() as res: + y = model(x) + y.sum().backward() + del y + for p in model.parameters(): + p.grad = None + print(f"FB: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms") + + x = torch.randn(1, xp.cfg.dset.channels, int(model.segment * model.samplerate), device='cuda') + with bench() as res: + with torch.no_grad(): + y = model(x) + del y + print(f"FV: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms") + + model.cpu() + torch.set_num_threads(1) + test = torch.randn(1, xp.cfg.dset.channels, model.samplerate * 40) + b = time.time() + apply_model(model, test, split=True, shifts=1) + print("CPU 40 sec:", time.time() - b) diff --git a/tools/convert.py b/tools/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..a91b3cc26e96a534db62a6554370940a512bd538 --- /dev/null +++ b/tools/convert.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Script to convert option names and model args from the dev branch to +# the cleanup release one. There should be no reaso to use that anymore. 
+ +import argparse +import io +import json +from pathlib import Path +import subprocess as sp + +import torch + +from demucs import train, pretrained, states + +DEV_REPO = Path.home() / 'tmp/release_demucs_mdx' + + +TO_REMOVE = [ + 'demucs.dconv_kw.gelu=True', + 'demucs.dconv_kw.nfreqs=0', + 'demucs.dconv_kw.nfreqs=0', + 'demucs.dconv_kw.version=4', + 'demucs.norm=gn', + 'wdemucs.nice=True', + 'wdemucs.good=True', + 'wdemucs.freq_emb=-0.2', + 'special=True', + 'special=False', +] + +TO_REPLACE = [ + ('power', 'svd'), + ('wdemucs', 'hdemucs'), + ('hdemucs.hybrid=True', 'hdemucs.hybrid_old=True'), + ('hdemucs.hybrid=2', 'hdemucs.hybrid=True'), +] + +TO_INJECT = [ + ('model=hdemucs', ['hdemucs.cac=False']), + ('model=hdemucs', ['hdemucs.norm_starts=999']), +] + + +def get_original_argv(sig): + return json.load(open(Path(DEV_REPO) / f'outputs/xps/{sig}/.argv.json')) + + +def transform(argv, mappings, verbose=False): + for rm in TO_REMOVE: + while rm in argv: + argv.remove(rm) + + for old, new in TO_REPLACE: + argv[:] = [a.replace(old, new) for a in argv] + + for condition, args in TO_INJECT: + if condition in argv: + argv[:] = args + argv + + for idx, arg in enumerate(argv): + if 'continue_from=' in arg: + dep_sig = arg.split('=')[1] + if dep_sig.startswith('"'): + dep_sig = eval(dep_sig) + if verbose: + print("Need to recursively convert dependency XP", dep_sig) + new_sig = convert(dep_sig, mappings, verbose).sig + argv[idx] = f'continue_from="{new_sig}"' + + +def convert(sig, mappings, verbose=False): + argv = get_original_argv(sig) + if verbose: + print("Original argv", argv) + transform(argv, mappings, verbose) + if verbose: + print("New argv", argv) + xp = train.main.get_xp(argv) + train.main.init_xp(xp) + if verbose: + print("Mapping", sig, "->", xp.sig) + mappings[sig] = xp.sig + return xp + + +def _eval_old(old_sig, x): + script = ( + 'from demucs import pretrained; import torch; import sys; import io; ' + 'buf = io.BytesIO(sys.stdin.buffer.read()); ' + 'x = torch.load(buf); m = pretrained.load_pretrained_model(' + f'"{old_sig}"); torch.save(m(x), sys.stdout.buffer)') + + buf = io.BytesIO() + torch.save(x, buf) + proc = sp.run( + ['python3', '-c', script], input=buf.getvalue(), capture_output=True, cwd=DEV_REPO) + if proc.returncode != 0: + print("Error", proc.stderr.decode()) + assert False + + buf = io.BytesIO(proc.stdout) + return torch.load(buf) + + +def compare(old_sig, model): + test = torch.randn(1, 2, 44100 * 10) + old_out = _eval_old(old_sig, test) + out = model(test) + + delta = 20 * torch.log10((out - old_out).norm() / out.norm()).item() + return delta + + +def main(): + torch.manual_seed(1234) + parser = argparse.ArgumentParser('convert') + parser.add_argument('sigs', nargs='*') + parser.add_argument('-o', '--output', type=Path, default=Path('release_models')) + parser.add_argument('-d', '--dump', action='store_true') + parser.add_argument('-c', '--compare', action='store_true') + parser.add_argument('-v', '--verbose', action='store_true') + args = parser.parse_args() + + args.output.mkdir(exist_ok=True, parents=True) + mappings = {} + for sig in args.sigs: + xp = convert(sig, mappings, args.verbose) + if args.dump or args.compare: + old_pkg = pretrained._load_package(sig, old=True) + model = train.get_model(xp.cfg) + model.load_state_dict(old_pkg['state']) + if args.dump: + pkg = states.serialize_model(model, xp.cfg) + states.save_with_checksum(pkg, args.output / f'{xp.sig}.th') + if args.compare: + delta = compare(sig, model) + print("Delta for", sig, xp.sig, delta) + + 
mappings[sig] = xp.sig + + print("FINAL MAPPINGS") + for old, new in mappings.items(): + print(old, " ", new) + + +if __name__ == '__main__': + main() diff --git a/tools/export.py b/tools/export.py new file mode 100644 index 0000000000000000000000000000000000000000..5113dff18c0a6848e84a97c8759e524bc3fc6d70 --- /dev/null +++ b/tools/export.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Export a trained model from the full checkpoint (with optimizer etc.) to +a final checkpoint, with only the model itself. The model is always stored as +half float to gain space, and because this has zero impact on the final loss. +When DiffQ was used for training, the model will actually be quantized and bitpacked.""" +from argparse import ArgumentParser +from fractions import Fraction +import logging +from pathlib import Path +import sys +import torch + +from demucs import train +from demucs.states import serialize_model, save_with_checksum + + +logger = logging.getLogger(__name__) + + +def main(): + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + + parser = ArgumentParser("tools.export", description="Export trained models from XP sigs.") + parser.add_argument('signatures', nargs='*', help='XP signatures.') + parser.add_argument('-o', '--out', type=Path, default=Path("release_models"), + help="Path where to store release models (default release_models)") + parser.add_argument('-s', '--sign', action='store_true', + help='Add sha256 prefix checksum to the filename.') + + args = parser.parse_args() + args.out.mkdir(exist_ok=True, parents=True) + + for sig in args.signatures: + xp = train.main.get_xp_from_sig(sig) + name = train.main.get_name(xp) + logger.info('Handling %s/%s', sig, name) + + out_path = args.out / (sig + ".th") + + solver = train.get_solver_from_sig(sig) + if len(solver.history) < solver.args.epochs: + logger.warning( + 'Model %s has less epoch than expected (%d / %d)', + sig, len(solver.history), solver.args.epochs) + + solver.model.load_state_dict(solver.best_state) + pkg = serialize_model(solver.model, solver.args, solver.quantizer, half=True) + if getattr(solver.model, 'use_train_segment', False): + batch = solver.augment(next(iter(solver.loaders['train']))) + pkg['kwargs']['segment'] = Fraction(batch.shape[-1], solver.model.samplerate) + print("Override", pkg['kwargs']['segment']) + valid, test = None, None + for m in solver.history: + if 'valid' in m: + valid = m['valid'] + if 'test' in m: + test = m['test'] + pkg['metrics'] = (valid, test) + if args.sign: + save_with_checksum(pkg, out_path) + else: + torch.save(pkg, out_path) + + +if __name__ == '__main__': + main() diff --git a/tools/test_pretrained.py b/tools/test_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..84886598300709294b3528bdbf79770020f1cf28 --- /dev/null +++ b/tools/test_pretrained.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Script to evaluate pretrained models. 
+ +from argparse import ArgumentParser +import logging +import sys + +import torch + +from demucs import train, pretrained, evaluate + + +def main(): + torch.set_num_threads(1) + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + parser = ArgumentParser("tools.test_pretrained", + description="Evaluate pre-trained models or bags of models " + "on MusDB.") + pretrained.add_model_flags(parser) + parser.add_argument('overrides', nargs='*', + help='Extra overrides, e.g. test.shifts=2.') + args = parser.parse_args() + + xp = train.main.get_xp(args.overrides) + with xp.enter(): + solver = train.get_solver(xp.cfg) + + model = pretrained.get_model_from_args(args) + solver.model = model.to(solver.device) + solver.model.eval() + + with torch.no_grad(): + results = evaluate.evaluate(solver, xp.cfg.test.sdr) + print(results) + + +if __name__ == '__main__': + main()