From 774561a3683271e1b83cccd15fb5f92086002d55 Mon Sep 17 00:00:00 2001 From: Yoshiki Masuyama Date: Wed, 20 Nov 2024 09:43:34 -0500 Subject: [PATCH] Create release 1.0.0 --- .github/workflows/build_and_test.yaml | 38 + .github/workflows/static_checks.yaml | 77 ++ .gitignore | 179 +++++ .pre-commit-config.yaml | 51 ++ CONTRIBUTING.md | 9 + LICENSE.md | 660 ++++++++++++++++++ README.md | 99 +++ .../hrtf_selection/original_config.yaml | 15 + .../nearest_neighbor/original_config.yaml | 15 + config_template/nfcbc/original_config.yaml | 55 ++ config_template/nflora/original_config.yaml | 55 ++ config_template/ranf/original_config.yaml | 60 ++ preprocess_sonicom.sh | 27 + ranf/1_pretraining_neural_field.py | 181 +++++ ranf/2_adapting_neural_field.py | 191 +++++ ranf/3_evaluating_neural_field.py | 109 +++ ranf/__init__.py | 3 + .../compute_distance_matrices_for_spec_itd.py | 88 +++ ...mpute_spec_ild_itd_for_sonicom_datasets.py | 58 ++ ranf/evaluating_hrtf_selection.py | 110 +++ ranf/evaluating_nearest_neighbor.py | 104 +++ ranf/prepare_single_fold.py | 64 ++ ranf/summarize_evaluation_result.py | 44 ++ ranf/utils/__init__.py | 3 + ranf/utils/config.py | 380 ++++++++++ ranf/utils/loss_functions.py | 39 ++ ranf/utils/neural_field_icassp.py | 588 ++++++++++++++++ ranf/utils/reconstruction.py | 68 ++ ranf/utils/sonicom_dataset_retrieval.py | 201 ++++++ ranf/utils/util.py | 80 +++ requirements-dev.txt | 8 + requirements.txt | 13 + run_example.sh | 69 ++ run_learningfree_methods.sh | 62 ++ tests/__init__.py | 3 + tests/loss_test.py | 48 ++ tests/model_init_test.py | 175 +++++ tests/post_processing_test.py | 34 + 38 files changed, 4063 insertions(+) create mode 100644 .github/workflows/build_and_test.yaml create mode 100644 .github/workflows/static_checks.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 
config_template/hrtf_selection/original_config.yaml create mode 100644 config_template/nearest_neighbor/original_config.yaml create mode 100644 config_template/nfcbc/original_config.yaml create mode 100644 config_template/nflora/original_config.yaml create mode 100644 config_template/ranf/original_config.yaml create mode 100644 preprocess_sonicom.sh create mode 100644 ranf/1_pretraining_neural_field.py create mode 100644 ranf/2_adapting_neural_field.py create mode 100644 ranf/3_evaluating_neural_field.py create mode 100644 ranf/__init__.py create mode 100644 ranf/compute_distance_matrices_for_spec_itd.py create mode 100644 ranf/compute_spec_ild_itd_for_sonicom_datasets.py create mode 100644 ranf/evaluating_hrtf_selection.py create mode 100644 ranf/evaluating_nearest_neighbor.py create mode 100644 ranf/prepare_single_fold.py create mode 100644 ranf/summarize_evaluation_result.py create mode 100644 ranf/utils/__init__.py create mode 100644 ranf/utils/config.py create mode 100644 ranf/utils/loss_functions.py create mode 100644 ranf/utils/neural_field_icassp.py create mode 100644 ranf/utils/reconstruction.py create mode 100644 ranf/utils/sonicom_dataset_retrieval.py create mode 100644 ranf/utils/util.py create mode 100644 requirements-dev.txt create mode 100644 requirements.txt create mode 100755 run_example.sh create mode 100644 run_learningfree_methods.sh create mode 100644 tests/__init__.py create mode 100644 tests/loss_test.py create mode 100644 tests/model_init_test.py create mode 100644 tests/post_processing_test.py diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml new file mode 100644 index 0000000..e750910 --- /dev/null +++ b/.github/workflows/build_and_test.yaml @@ -0,0 +1,38 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Build and Test + +on: + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +jobs: + build: + 
runs-on: ubuntu-latest + + steps: + - name: Checkout repo + uses: actions/checkout@v3 + + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: 'requirements.txt' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + + - name: Run unit tests + run: | + python -m pytest tests diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml new file mode 100644 index 0000000..2f89f1c --- /dev/null +++ b/.github/workflows/static_checks.yaml @@ -0,0 +1,77 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Static code checks + +on: + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +env: + LICENSE: AGPL-3.0-or-later + FETCH_DEPTH: 1 + FULL_HISTORY: 0 + SKIP_WORD_PRESENCE_CHECK: 0 + +jobs: + static-code-check: + if: endsWith(github.event.repository.name, 'private') + + name: Run static code checks + # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu1804-Readme.md for list of packages + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Setup history + if: github.ref == 'refs/heads/oss' + run: | + echo "FETCH_DEPTH=0" >> $GITHUB_ENV + echo "FULL_HISTORY=1" >> $GITHUB_ENV + + - name: Setup version + if: github.ref == 'refs/heads/melco' + run: | + echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV + + - name: Check out code + uses: actions/checkout@v3 + with: + fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history + + - name: Set up environment + run: git config user.email github-bot@merl.com + + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: 'requirements-dev.txt' + + - name: Install python packages + run: pip install -r requirements-dev.txt 
+ + - name: Ensure lint and pre-commit steps have been run + uses: pre-commit/action@v3.0.0 + + - name: Check files + uses: merl-oss-private/merl-file-check-action@v1 + with: + license: ${{ env.LICENSE }} + full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above + skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} + + - name: Check license compatibility + if: github.ref != 'refs/heads/melco' + uses: merl-oss-private/merl_license_compatibility_checker@v1 + with: + input-filename: requirements.txt + license: ${{ env.LICENSE }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0b6c0d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Data +*.wav +*.png +*.pdf +*.npy +*.npz +*.npys +*.pickle +*.ckpt +*.ipynb +*.mat +*.zip +*.sofa diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..39aabdb --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,51 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# Pre-commit configuration. 
See https://pre-commit.com + +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-json + exclude: templates.*/python/.vscode/settings.json + - id: check-added-large-files + args: ['--maxkb=5000'] + + - repo: https://github.com/homebysix/pre-commit-macadmin + rev: v1.16.2 + hooks: + - id: check-git-config-email + args: ['--domains', 'merl.com'] + + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + args: + - --line-length=120 + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] + + - repo: https://github.com/asottile/pyupgrade + rev: v3.16.0 + hooks: + - id: pyupgrade + + - repo: https://github.com/pycqa/flake8 + rev: 7.1.0 + hooks: + - id: flake8 + # Black compatibility + args: ["--max-line-length=120", "--extend-ignore=E203, E704"] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9c8cb63 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,9 @@ + +# Contributing + +Sorry, but we do not currently accept contributions in the form of pull requests to this repository. +However, you are welcome to post issues (bug reports, feature requests, questions, etc). diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..cba6f6a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,660 @@ +### GNU AFFERO GENERAL PUBLIC LICENSE + +Version 3, 19 November 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. + + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. 
+ +### Preamble + +The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains +free software for all its users. + +When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + +A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + +The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. 
It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + +An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing +under this license. + +The precise terms and conditions for copying, distribution and +modification follow. + +### TERMS AND CONDITIONS + +#### 0. Definitions. + +"This License" refers to version 3 of the GNU Affero General Public +License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. 
Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. + +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +#### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +#### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. +You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. 
Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +#### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +#### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +#### 5. Conveying Modified Source Versions. 
+ +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. +- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +#### 6. Conveying Non-Source Forms. 
+ +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. 
If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. 
+ +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +#### 7. Additional Terms. 
+ +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material 
(or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +#### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. 
+ +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +#### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +#### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. 
If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +#### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. 
+ +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. 
You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +#### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +#### 13. Remote Network Interaction; Use with the GNU General Public License. 
+ +Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your +version supports such interaction) an opportunity to receive the +Corresponding Source of your version by providing access to the +Corresponding Source from a network server at no charge, through some +standard or customary means of facilitating copying of software. This +Corresponding Source shall include the Corresponding Source for any +work covered by version 3 of the GNU General Public License that is +incorporated pursuant to the following paragraph. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + +#### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU Affero General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever +published by the Free Software Foundation. 
+ +If the Program specifies that a proxy can decide which future versions +of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +#### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +#### 16. Limitation of Liability. + +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +#### 17. Interpretation of Sections 15 and 16. 
+ +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +END OF TERMS AND CONDITIONS + +### How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively state +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as + published by the Free Software Foundation, either version 3 of the + License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper +mail. + +If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. 
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for +the specific requirements. + +You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU AGPL, see . diff --git a/README.md b/README.md new file mode 100644 index 0000000..7941552 --- /dev/null +++ b/README.md @@ -0,0 +1,99 @@ + +# Retrieval-Augmented Neural Field for HRTF Upsampling and Personalization + +This repository includes source code for training and evaluating the retrieval-augmented neural field (RANF) proposed in the following ICASSP 2025 submission: + + @InProceedings{Masuyama2024ICASSP_ranf, + author = {Masuyama, Yoshiki and Wichern, Gordon and Germain, Fran\c{c}ois G. and Ick, Christopher and {Le Roux}, Jonathan}, + title = {Retrieval-Augmented Neural Field for HRTF Upsampling and Personalization}, + booktitle = {Submitted to IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + year = 2025, + month = apr + } + +## Table of contents + +1. [Environment setup](#environment-setup) +2. [Supported sparsity levels and models](#supported-sparsity-levels-and-models) +3. [Training and evaluating RANF](#training-and-evaluating-ranf) +4. [Evaluating learning-free baseline methods](#evaluating-learning-free-baseline-methods) +5. [Contributing](#contributing) +6. [Copyright and license](#copyright-and-license) + +## Environment setup + +The code has been tested using `python 3.10.0` on Linux. 
+Necessary dependencies can be installed using the included `requirements.txt`: + +```bash +pip install -r requirements.txt +``` + +## Supported sparsity levels and models +- Our HRTF upsampling experiments were performed on [the SONICOM dataset](https://www.sonicom.eu/tools-and-resources/hrtf-dataset/) that is [released under the MIT license](https://www.axdesign.co.uk/tools-and-devices/sonicom-hrtf-dataset). +- We performed HRTF upsampling with four sparsity levels following [Task 2 of the Listener Acoustic Personalization Challenge 2024](https://www.sonicom.eu/lap-challenge/). The number of measured directions, `sp_level` in `run_example.sh`, should be selected from `{3, 5, 19, 100}`, where smaller is more challenging to upsample. +- We currently support three NF-based methods. Please refer to our paper for their details. + - NF with conditioning by concatenation (CbC): NF takes a subject-specific latent vector as an auxiliary input in addition to the sound source direction. + - NF with low-rank adaptation (LoRA): The model weights will be updated by adding a subject-specific low-rank matrix. + - RANF: NF takes HRTF magnitude and ITDs of the retrieved subjects in addition to the sound source direction. LoRA is also used to adapt the model. + +## Training and evaluating RANF +In order to train and evaluate RANF and the existing NF-based methods on the SONICOM dataset, please execute `run_example.sh` after following Stage 0. Then, `run_example.sh` consists of five stages. You can run each stage one by one by changing `stage` and `stop_stage` in the script. + + +- **Stage 0:** + - Before starting the training and evaluation, download the SONICOM dataset into a directory specified in `original_path` in `run_example.sh` and unzip the dataset. + - The directory is assumed to contain `KEMAR`, `P0001-P0005`, ..., `P0196-P0200`. 
+    - The configuration file in `$config_path/original_config.yaml` will be modified based on the sparsity level and the data split, and then the updated configuration file will be saved in `$exp_path/config.yaml`.
+In order to evaluate the learning-free methods, HRTF selection and nearest neighbor, please execute `run_learningfree_methods.sh` after specifying the paths as explained for RANF above. Stages 1 and 2 are the same as for RANF, and both inference and evaluation are performed in Stage 3.
+ +All files: +``` +Copyright (c) 2024 Mitsubishi Electric Research Laboratories (MERL) + +SPDX-License-Identifier: AGPL-3.0-or-later +``` diff --git a/config_template/hrtf_selection/original_config.yaml b/config_template/hrtf_selection/original_config.yaml new file mode 100644 index 0000000..171dfd2 --- /dev/null +++ b/config_template/hrtf_selection/original_config.yaml @@ -0,0 +1,15 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +dataset: + upsample: + train_subjects: + valid_subjects: + test_subjects: + features: + retrieval: + azimuth_calibration: false + retrieval_priority: itdd +device: +seed: 0 diff --git a/config_template/nearest_neighbor/original_config.yaml b/config_template/nearest_neighbor/original_config.yaml new file mode 100644 index 0000000..92459b8 --- /dev/null +++ b/config_template/nearest_neighbor/original_config.yaml @@ -0,0 +1,15 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +dataset: + upsample: + train_subjects: + valid_subjects: + test_subjects: + features: + retrieval: + azimuth_calibration: false + retrieval_priority: +device: +seed: 0 diff --git a/config_template/nfcbc/original_config.yaml b/config_template/nfcbc/original_config.yaml new file mode 100644 index 0000000..2986795 --- /dev/null +++ b/config_template/nfcbc/original_config.yaml @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +dataset: + upsample: + train_subjects: + valid_subjects: + test_subjects: + features: + retrieval: + npool: 1 # not used + nretrieval: 1 # not used + azimuth_calibration: false + retrieval_priority: itdd # not used +device: cuda:0 +seed: 0 +persistent_workers: false +learning: + num_epoch: 200 + batch_size: 64 + num_workers: 1 + clip: 3.0 + patience: 20 +adaptation: + num_epoch: 500 + batch_size: 5 + 
num_workers: 1 + clip: 3.0 +model: + name: CbCNeuralField + config: + hidden_features: 256 + embed_features: 32 + hidden_layers: 4 + out_features: 258 + scale: 1.0 + dropout: 0.1 + n_listeners: 200 + activation: "GELU" + itd_skip_connection: true +optimizer: + name: RAdam + config: + lr: 0.001 +scheduler: + name: ReduceLROnPlateau + config: + mode: min + factor: 0.9 + patience: 10 +loss: + weight_itd: 1.0 + eps: 1.0e-05 + threshold_itd: 0.5 diff --git a/config_template/nflora/original_config.yaml b/config_template/nflora/original_config.yaml new file mode 100644 index 0000000..0ebc132 --- /dev/null +++ b/config_template/nflora/original_config.yaml @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +dataset: + upsample: + train_subjects: + valid_subjects: + test_subjects: + features: + retrieval: + npool: 1 + nretrieval: 1 + azimuth_calibration: false + retrieval_priority: itdd +device: cuda:0 +seed: 0 +persistent_workers: false +learning: + num_epoch: 200 + batch_size: 64 + num_workers: 1 + clip: 3.0 + patience: 20 +adaptation: + num_epoch: 500 + batch_size: 5 + num_workers: 1 + clip: 3.0 +model: + name: PEFTNeuralField + config: + hidden_features: 256 + hidden_layers: 4 + out_features: 258 + scale: 1.0 + dropout: 0.1 + n_listeners: 200 + activation: "GELU" + peft: lora + itd_skip_connection: true +optimizer: + name: RAdam + config: + lr: 0.001 +scheduler: + name: ReduceLROnPlateau + config: + mode: min + factor: 0.9 + patience: 10 +loss: + weight_itd: 1.0 + eps: 1.0e-05 + threshold_itd: 0.5 diff --git a/config_template/ranf/original_config.yaml b/config_template/ranf/original_config.yaml new file mode 100644 index 0000000..242f6a4 --- /dev/null +++ b/config_template/ranf/original_config.yaml @@ -0,0 +1,60 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +dataset: + upsample: + train_subjects: + 
valid_subjects: + test_subjects: + features: + retrieval: + npool: 5 + nretrieval: 5 + azimuth_calibration: false + retrieval_priority: itdd +device: cuda:0 +seed: 0 +persistent_workers: false +learning: + num_epoch: 200 + batch_size: 64 + num_workers: 1 + clip: 3.0 + patience: 20 +adaptation: + num_epoch: 500 + batch_size: 5 + num_workers: 1 + clip: 3.0 +model: + name: RANF + config: + hidden_features: 128 + hidden_layers: 4 + conv_layers: 4 + spec_hidden_layers: 2 + itd_hidden_layers: 2 + scale: 1.0 + dropout: 0.1 + n_listeners: 200 + rnn: LSTM + peft: lora + norm: LayerNorm + spec_res: false + itd_res: true + itd_skip_connection: true +optimizer: + name: RAdam + config: + lr: 0.001 +scheduler: + name: ReduceLROnPlateau + config: + mode: min + factor: 0.9 + patience: 10 +loss: + weight_itd: 1.0 + eps: 1.0e-05 + threshold_itd: 0.5 diff --git a/preprocess_sonicom.sh b/preprocess_sonicom.sh new file mode 100644 index 0000000..fb3b6a4 --- /dev/null +++ b/preprocess_sonicom.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +original_path=$1 +sonicom_path=$2 + +mkdir -p "${sonicom_path}/kemar" +mkdir -p "${sonicom_path}/myhrtf" +mkdir -p "${sonicom_path}/subjects" + +fnames=`find -L $original_path -type f -name *_FreeFieldCompMinPhase_48kHz.sofa` + +for fname in $fnames; do + fbase=`basename "$fname"` + + if [ "`echo $fname | grep 'KEMAR'`" ]; then + cp -n $fname "${sonicom_path}/kemar/${fbase}" + + elif [ "`echo $fname | grep 'MyHRTF'`" ]; then + cp -n $fname "${sonicom_path}/myhrtf/${fbase}" + + else + cp -n $fname "${sonicom_path}/subjects/${fbase}" + fi +done diff --git a/ranf/1_pretraining_neural_field.py b/ranf/1_pretraining_neural_field.py new file mode 100644 index 0000000..8e199cc --- /dev/null +++ b/ranf/1_pretraining_neural_field.py @@ -0,0 +1,181 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# 
+    # These metrics are only for validation
dataloader + tr_dataset = SONICOMMulti( + config.dataset, + stage="pretrain", + mode="train", + ) + dev_dataset = SONICOMMulti( + config.dataset, + stage="pretrain", + mode="valid", + ) + + tr_data_loader = DataLoader( + tr_dataset, + batch_size=config.learning.batch_size, + shuffle=True, + drop_last=True, + num_workers=config.learning.num_workers, + pin_memory=True, + persistent_workers=config.persistent_workers, + ) + dev_data_loader = DataLoader( + dev_dataset, + batch_size=config.learning.batch_size, + shuffle=False, + drop_last=False, + num_workers=config.learning.num_workers, + pin_memory=True, + persistent_workers=config.persistent_workers, + ) + assert len(tr_dataset) > 0 and len(dev_dataset) > 0, len(dev_dataset) + + # Prepare model + model = getattr(neural_field, config.model.name)(**config.model.config) + model = model.to(config.device) + logging.info(f"Number of trainable parameters: {count_parameters(model)}") + + if hasattr(config.model, "init_path"): + model.load_state_dict(torch.load(config.init_path, map_location=config.device)) + + # Prepare the optimizer and scheduler + optimizer = getattr(optim, config.optimizer.name)(model.parameters(), **config.optimizer.config) + + if hasattr(config, "scheduler"): + scheduler = getattr(optim.lr_scheduler, config.scheduler.name)(optimizer, **config.scheduler.config) + else: + scheduler = None + + tr_loss, dev_loss = [], [] + dev_loss_min = 1.0e15 + early_stop = 0 + + logging.info("Start training...") + for epoch in range(config.learning.num_epoch): + + # Training + running_loss = [] + model.train() + for data in tqdm(tr_data_loader): + data = [x.to(config.device) for x in data] + loss_val = forward(data, model, config.loss) + optimizer.zero_grad() + loss_val.backward() + nn.utils.clip_grad_norm_(model.parameters(), config.learning.clip) + optimizer.step() + running_loss.append(loss_val.item()) + + tr_loss.append(np.mean(running_loss)) + + # Validation + running_loss, running_lsd = [], [] + running_ild_diff, 
running_itd_diff = [], [] + model.eval() + for data in tqdm(dev_data_loader): + data = [x.to(config.device) for x in data] + loss_val, lsd_val, ild_diff_val, itd_diff_val = forward(data, model, config.loss) + running_loss.append(loss_val.item()) + running_lsd.append(lsd_val.item()) + running_itd_diff.append(itd_diff_val.item()) + running_ild_diff.append(ild_diff_val.item()) + + # Visualization + with torch.no_grad(): + data = dev_dataset[epoch % len(dev_dataset)] + data = [torch.tensor(x)[None, ...].to(config.device) for x in data] + tgt_spec, _, _, tgt_locs, ret_specs, ret_itds, _, tgt_sidx, ret_sidxs = data + spec_db, ret_specs_db = linear2db(tgt_spec), linear2db(ret_specs) + pred_db, _ = model(ret_specs_db, ret_itds, tgt_locs, tgt_sidx, ret_sidxs) + lsd = torch.mean(lsd_loss(spec_db, pred_db)).item() + plot_hrtf(fig_name, spec_db, pred_db, lsd) + + dev_loss.append(np.mean(running_loss)) + + if scheduler is not None: + scheduler.step(dev_loss[-1]) + + logging.info(f"Epoch {epoch}") + logging.info(f"tr_loss: {tr_loss[-1]}, dev_loss: {dev_loss[-1]}") + logging.info(f"dev lsd: {np.mean(running_lsd)}") + logging.info(f"dev ild diff: {np.mean(running_ild_diff)}") + logging.info(f"dev itd diff: {np.mean(running_itd_diff)}") + + if dev_loss[-1] <= dev_loss_min: + dev_loss_min = dev_loss[-1] + early_stop = 0 + torch.save(model.state_dict(), path.joinpath("best.ckpt")) + else: + early_stop += 1 + + if early_stop == config.learning.patience: + logging.info(f"Early stopping at epoch {epoch}") + break + + if np.isnan(dev_loss[-1]): + logging.info("Loss is Nan. 
Training is stopped") + break + + +if __name__ == "__main__": + main() diff --git a/ranf/2_adapting_neural_field.py b/ranf/2_adapting_neural_field.py new file mode 100644 index 0000000..128e8c5 --- /dev/null +++ b/ranf/2_adapting_neural_field.py @@ -0,0 +1,191 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +import logging +import pathlib + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from omegaconf import OmegaConf +from torch.utils.data import DataLoader +from tqdm import tqdm + +from ranf.utils import neural_field_icassp as neural_field +from ranf.utils.loss_functions import ild_diff_loss, itd_diff_loss, lsd_loss +from ranf.utils.sonicom_dataset_retrieval import SONICOMMulti +from ranf.utils.util import count_parameters, db2linear, linear2db, seed_everything + + +def freeze_model_for_peft(model): + for name, param in model.named_parameters(): + if "embed_layer" not in name: + # This param is independent of the target subject and should be frozen + param.requires_grad = False + continue + + if "embed_layer_bitfit" in name: + # In BitFit, the subject-dependent bias is computed by an FC layer + if "bias" in name: + param.requires_grad = False + else: + param.requires_grad = True + logging.info(f"{name} is trainable") + + elif "embed_layer_lorau" in name: + # One low-rank matrix for LoRA "u" always depends on the target subject + if "bias" in name: + param.requires_grad = False + else: + param.requires_grad = True + logging.info(f"{name} is trainable") + + elif "embed_layer_lorav" in name: + # Another low-rank matrix for LoRA "v" depends on the retrieved subject in the core-block of RANF + if "bias" in name: + param.requires_grad = False + else: + if isinstance(model, neural_field.RANF): + if int(name.split(".")[1]) < model.hidden_layers: + param.requires_grad = False + else: + param.requires_grad = True + logging.info(f"{name} is 
trainable") + else: + param.requires_grad = True + logging.info(f"{name} is trainable") + + else: + raise ValueError("Invalid parameter name") + + logging.info(f"Number of trainable parameters: {count_parameters(model)}") + + +def forward(data, model, config): + tgt_spec, tgt_ild, tgt_itd, tgt_loc, ret_specs, ret_itds, _, tgt_sidx, ret_sidxs = data + + spec_db, ret_specs_db = linear2db(tgt_spec), linear2db(ret_specs) + pred_db, pred_itd = model(ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs) + pred = db2linear(pred_db) + + if model.training: + threshold = config.threshold_itd + else: + threshold = 0.0 + + loss_val = torch.mean(lsd_loss(spec_db, pred_db, use_index=False)) + itd_diff_loss_val = torch.mean(itd_diff_loss(tgt_itd, pred_itd[:, 0], threshold=threshold)) + loss_val = loss_val + config.weight_itd * itd_diff_loss_val + if model.training: + return loss_val + + # These metrics are only for validation + lsd_loss_val = torch.mean(lsd_loss(spec_db, pred_db, use_index=True)) + ild_diff_loss_val = torch.mean(ild_diff_loss(tgt_spec, pred, tgt_ild)) + return loss_val, lsd_loss_val, ild_diff_loss_val, itd_diff_loss_val + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str) + args = parser.parse_args() + + path = pathlib.Path(args.config_path) + config = OmegaConf.load(path.joinpath("config.yaml")) + seed_everything(config.seed) + + log = path.joinpath("log") + log.mkdir(parents=True, exist_ok=True) + log.joinpath("adaptation").mkdir(parents=True, exist_ok=True) + log_name = log.joinpath("adaptation").joinpath("adaptation.log") + logging.basicConfig(filename=log_name, level=logging.INFO) + + # Prepare dataset and dataloader + tr_dataset = SONICOMMulti( + config.dataset, + stage="adaptation", + mode="train", + ) + dev_dataset = SONICOMMulti( + config.dataset, + stage="adaptation", + mode="valid", + ) + tr_data_loader = DataLoader( + tr_dataset, + batch_size=config.adaptation.batch_size, + shuffle=True, + 
drop_last=True, + num_workers=config.adaptation.num_workers, + pin_memory=True, + persistent_workers=config.persistent_workers, + ) + dev_data_loader = DataLoader( + dev_dataset, + batch_size=config.adaptation.batch_size, + shuffle=False, + drop_last=False, + num_workers=config.adaptation.num_workers, + pin_memory=True, + persistent_workers=config.persistent_workers, + ) + + # Prepare model + model = getattr(neural_field, config.model.name)(**config.model.config) + model = model.to(config.device) + model.load_state_dict(torch.load(path.joinpath("best.ckpt"), map_location=config.device)) + + # Freezing the model except for the subject-specific parameters + freeze_model_for_peft(model) + + # Prepare the optimizer and scheduler + optimizer = getattr(optim, config.optimizer.name)(model.parameters(), **config.optimizer.config) + + tr_loss, tr_loss_min = [], 1.0e15 + logging.info("Start adaptation...") + for epoch in range(config.adaptation.num_epoch): + + # Adaptation + model.train() + for data in tqdm(tr_data_loader): + data = [x.to(config.device) for x in data] + loss_val = forward(data, model, config.loss) + optimizer.zero_grad() + loss_val.backward() + nn.utils.clip_grad_norm_(model.parameters(), config.adaptation.clip) + optimizer.step() + + # Validation but on adaptation data itself + running_loss = [] + running_lsd, running_ild_diff, running_itd_diff = [], [], [] + model.eval() + for data in tqdm(dev_data_loader): + data = [x.to(config.device) for x in data] + loss_val, lsd_val, ild_diff_val, itd_diff_val = forward(data, model, config.loss) + running_loss.append(loss_val.item()) + running_lsd.append(lsd_val.item()) + running_itd_diff.append(itd_diff_val.item()) + running_ild_diff.append(ild_diff_val.item()) + + tr_loss.append(np.mean(running_loss)) + + logging.info(f"Epoch {epoch}") + logging.info(f"tr_loss: {tr_loss[-1]}") + logging.info(f"dev lsd: {np.mean(running_lsd)}") + logging.info(f"dev ild diff: {np.mean(running_ild_diff)}") + logging.info(f"dev itd 
diff: {np.mean(running_itd_diff)}") + + if tr_loss[-1] <= tr_loss_min: + tr_loss_min = tr_loss[-1] + torch.save(model.state_dict(), path.joinpath("adaptation.ckpt")) + + if np.isnan(tr_loss[-1]): + logging.info("Loss is Nan. Training is stopped") + break + + +if __name__ == "__main__": + main() diff --git a/ranf/3_evaluating_neural_field.py b/ranf/3_evaluating_neural_field.py new file mode 100644 index 0000000..91346ac --- /dev/null +++ b/ranf/3_evaluating_neural_field.py @@ -0,0 +1,109 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +import logging +import os +import pathlib + +import numpy as np +import sofar as sf +import torch +from omegaconf import OmegaConf +from spatialaudiometrics import lap_challenge as lap + +from ranf.utils import neural_field_icassp as neural_field +from ranf.utils.config import TGTDIDXS003, TGTDIDXS005, TGTDIDXS019, TGTDIDXS100 +from ranf.utils.reconstruction import hrtf2hrir_minph +from ranf.utils.sonicom_dataset_retrieval import SONICOMMultiInference +from ranf.utils.util import db2linear, linear2db, seed_everything + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str) + parser.add_argument("--save_results", action="store_true") + args = parser.parse_args() + + path = pathlib.Path(args.config_path) + config = OmegaConf.load(path.joinpath("config.yaml")) + seed_everything(config.seed) + + if config.dataset.upsample == 3: + seen_didxs = TGTDIDXS003 + + elif config.dataset.upsample == 5: + seen_didxs = TGTDIDXS005 + + elif config.dataset.upsample == 19: + seen_didxs = TGTDIDXS019 + + elif config.dataset.upsample == 100: + seen_didxs = TGTDIDXS100 + + else: + raise ValueError(f"Invalid upsampling target: {config.dataset.upsample}.") + + unseen_didxs = sorted(list(set(range(793)) - set(seen_didxs))) + + log = path.joinpath("log") + log.mkdir(parents=True, exist_ok=True) + 
log.joinpath("eval").mkdir(parents=True, exist_ok=True) + log_name = log.joinpath("eval").joinpath("eval.log") + logging.basicConfig(filename=log_name, level=logging.INFO) + + eval_dataset = SONICOMMultiInference(config.dataset) + + model = getattr(neural_field, config.model.name)(**config.model.config) + model = model.to(config.device) + model.load_state_dict(torch.load(path.joinpath("adaptation.ckpt"), map_location=config.device)) + + logging.info("Start evaluation...") + model.eval() + pred_dbs, pred_itds = [], [] + for data in eval_dataset: + sofa_file, hrir = data[:2] + tgt_loc, ret_specs, ret_itds, _, tgt_sidx, ret_sidxs = (torch.tensor(x).to(config.device) for x in data[2:]) + ret_specs_db = linear2db(ret_specs) + pred_db, pred_itd = model(ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs) + + pred_db = pred_db.detach().to("cpu").numpy() + pred = db2linear(pred_db) + pred_itd = pred_itd.detach().to("cpu").numpy() + pred_hrir = hrtf2hrir_minph(pred, itd=pred_itd, nfft=hrir.shape[-1]) + + target_path = log.joinpath("eval").joinpath(f"target_p{data[-2][0]+1:04}.sofa") + pred_path = log.joinpath("eval").joinpath(f"pred_p{data[-2][0]+1:04}.sofa") + + sofa_file.Data_IR = sofa_file.Data_IR[unseen_didxs, :, :] + sofa_file.SourcePosition = sofa_file.SourcePosition[unseen_didxs, :] + sofa_file.MeasurementSourceAudioChannel = sofa_file.MeasurementSourceAudioChannel[unseen_didxs] + sf.write_sofa(target_path, sofa_file) + + sofa_file.Data_IR = pred_hrir.astype(np.float64)[unseen_didxs, :] + sf.write_sofa(pred_path, sofa_file) + + metrics = lap.calculate_task_two_metrics(str(target_path), str(pred_path))[0] + + logging.info(f"P{data[-2][0]+1:04} evaluation") + logging.info(f"ITD difference (µs): {metrics[0]}") + logging.info(f"ILD difference (dB): {metrics[1]}") + logging.info(f"LSD (dB): {metrics[2]}") + + os.remove(target_path) + os.remove(pred_path) + + pred_dbs.append(pred_db) + pred_itds.append(pred_itd) + + if args.save_results: + np.savez( + 
log.joinpath("eval").joinpath("prediction.npz"), + pred_dbs=np.array(pred_dbs), + pred_itds=np.array(pred_itds), + ) + + +if __name__ == "__main__": + main() diff --git a/ranf/__init__.py b/ranf/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/ranf/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/ranf/compute_distance_matrices_for_spec_itd.py b/ranf/compute_distance_matrices_for_spec_itd.py new file mode 100644 index 0000000..04e7e0a --- /dev/null +++ b/ranf/compute_distance_matrices_for_spec_itd.py @@ -0,0 +1,88 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from ranf.utils.config import LSDFREQIDX, TGTDIDXS003, TGTDIDXS005, TGTDIDXS019, TGTDIDXS100 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_path", type=str) + parser.add_argument("ouput_path", type=str) + parser.add_argument("upsample", type=int) + parser.add_argument("test_size", type=int) + parser.add_argument("--skip_78", action="store_true") + parser.add_argument("--calibrate_itdoffset", action="store_true") + args = parser.parse_args() + + Path(args.ouput_path).mkdir(parents=True, exist_ok=True) + + if args.upsample == 3: + seen_didxs = TGTDIDXS003 + + elif args.upsample == 5: + seen_didxs = TGTDIDXS005 + + elif args.upsample == 19: + seen_didxs = TGTDIDXS019 + + elif args.upsample == 100: + seen_didxs = TGTDIDXS100 + else: + raise ValueError(f"given upsample is invalid. 
It should be in (3, 5, 19, 100) but {args.upsample}.") + + # Load specs and ITDs + if args.calibrate_itdoffset: + npz_path = Path(args.input_path).joinpath("features_and_locs_with_azimuth_calibration.npz") + else: + npz_path = Path(args.input_path).joinpath("features_and_locs_wo_azimuth_calibration.npz") + + npz = np.load(npz_path) + specs = npz["specs"][:, seen_didxs, :, :] + itds = npz["itds"][:, seen_didxs] + nsubjects = itds.shape[0] + nsubjects_train_valid = nsubjects - args.test_size + assert ( + nsubjects_train_valid > 79 + ), "The number of training subjects is assumed to be large enough to exclude the 78th subject" + + # Compute the distance matrices + specs = 20 * np.log10(specs[..., LSDFREQIDX] + 1e-15) + lsd_mat = np.inf * np.ones((nsubjects, nsubjects), dtype=np.float32) + itdd_mat = np.inf * np.ones((nsubjects, nsubjects), dtype=np.float32) + + for n in tqdm(range(nsubjects_train_valid)): + if args.skip_78 and n == 78: + continue + + for m in range(n + 1, nsubjects): + if args.skip_78 and m == 78: + continue + + mse = np.mean(np.square(specs[n, ...] 
- specs[m, ...]), -1) + lsd = np.mean(np.sqrt(mse)) + lsd_mat[n, m] = lsd + lsd_mat[m, n] = lsd + + itdd = np.mean(np.abs(itds[n, :] - itds[m, :])) + itdd_mat[n, m] = itdd + itdd_mat[m, n] = itdd + + lsd_mat[:, nsubjects_train_valid:] = np.inf + itdd_mat[:, nsubjects_train_valid:] = np.inf + + np.savez( + Path(args.ouput_path).joinpath("lsd_itdd_mats.npz"), + lsd_mat=lsd_mat, + itdd_mat=itdd_mat, + ) + + +if __name__ == "__main__": + main() diff --git a/ranf/compute_spec_ild_itd_for_sonicom_datasets.py b/ranf/compute_spec_ild_itd_for_sonicom_datasets.py new file mode 100644 index 0000000..60229f8 --- /dev/null +++ b/ranf/compute_spec_ild_itd_for_sonicom_datasets.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from ranf.utils.util import extract_features + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_path", type=str) + parser.add_argument("output_path", type=str) + parser.add_argument("--calibrate_itdoffset", action="store_true") + args = parser.parse_args() + + sofa_paths = Path(args.input_path).glob("*.sofa") + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + specs, ilds, itds, locs = [], [], [], [] + for idx, path in enumerate(tqdm(sorted(sofa_paths))): + spec, ild, itd, loc = extract_features(path) + + assert int(path.stem.split("_")[0][1:]) - 1 == idx, (idx, path) + assert len(itd) == 793, "HRTF should follow the SONICOM spatial grid" + + if args.calibrate_itdoffset: + # 4 and 414 correspond to the front and left, respectively. 
+ alpha = (itd[414] - itd[4]) / (loc[414, 0] - loc[4, 0]) + beta = itd[4] / alpha + loc[:, 0] = (loc[:, 0] + beta + 360) % 360 + + specs.append(spec.astype(np.float32)) + ilds.append(ild.astype(np.float32)) + itds.append(np.array(itd)) + locs.append(loc.astype(np.float32)) + + if args.calibrate_itdoffset: + output_path = output_path.joinpath("features_and_locs_with_azimuth_calibration.npz") + else: + output_path = output_path.joinpath("features_and_locs_wo_azimuth_calibration.npz") + + np.savez( + output_path, + specs=np.array(specs), + ilds=np.array(ilds), + itds=np.array(itds), + locs=np.array(locs), + ) + + +if __name__ == "__main__": + main() diff --git a/ranf/evaluating_hrtf_selection.py b/ranf/evaluating_hrtf_selection.py new file mode 100644 index 0000000..3a11d7a --- /dev/null +++ b/ranf/evaluating_hrtf_selection.py @@ -0,0 +1,110 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +import logging +import os +from pathlib import Path + +import numpy as np +import sofar as sf +from omegaconf import OmegaConf +from spatialaudiometrics import lap_challenge as lap +from spatialaudiometrics import load_data as ld + +from ranf.utils.config import TGTDIDXS003, TGTDIDXS005, TGTDIDXS019, TGTDIDXS100 +from ranf.utils.util import seed_everything + + +def load_hrtf(fname): + hrtf = ld.HRTF(fname) + return hrtf.hrir, hrtf.locs + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str) + parser.add_argument("--save_results", action="store_true") + args = parser.parse_args() + + path = Path(args.config_path) + config = OmegaConf.load(path.joinpath("config.yaml")) + + seed_everything(config.seed) + + if config.dataset.upsample == 3: + seen_didxs = TGTDIDXS003 + + elif config.dataset.upsample == 5: + seen_didxs = TGTDIDXS005 + + elif config.dataset.upsample == 19: + seen_didxs = TGTDIDXS019 + + elif config.dataset.upsample == 100: + seen_didxs 
= TGTDIDXS100 + + else: + raise ValueError(f"dataset.upsample should be in (3, 5, 19, 100) but is {config.dataset.upsample}.") + + unseen_didxs = sorted(list(set(range(793)) - set(seen_didxs))) + + eval_path = path.joinpath("log").joinpath("eval") + eval_path.mkdir(parents=True, exist_ok=True) + log_name = eval_path.joinpath("eval.log") + logging.basicConfig(filename=log_name, level=logging.INFO) + + sonicom_path = Path(config.dataset.features).parent.parent.joinpath("subjects") + hrtf_type = "FreeFieldCompMinPhase_48kHz" + + npz = np.load(config.dataset.retrieval) + lsd_mat = npz["lsd_mat"] + itdd_mat = npz["itdd_mat"] + + for sidx in config.dataset.test_subjects: + if config.dataset.retrieval_priority == "itdd": + _sidxs = np.where(itdd_mat[sidx, :] == np.min(itdd_mat[sidx, :]))[0] + pred_sidx = _sidxs[np.argmin(lsd_mat[sidx, _sidxs])] + + elif config.dataset.retrieval_priority == "lsd": + pred_sidx = np.argmin(lsd_mat[sidx, :]) + + else: + raise ValueError(f"Invalid retrieval option: {config.dataset.retrieval_priority}") + + prediction = sonicom_path.joinpath(f"P{pred_sidx+1:04}_{hrtf_type}.sofa") + pred_hrir, _ = load_hrtf(prediction) + + fname = sonicom_path.joinpath(f"P{sidx+1:04}_{hrtf_type}.sofa") + hrir, locs_deg = load_hrtf(fname) + + sofa_file = sf.read_sofa(fname) + sofa_file.SourcePosition = locs_deg[unseen_didxs, :] + sofa_file.MeasurementSourceAudioChannel = sofa_file.MeasurementSourceAudioChannel[unseen_didxs] + + target_path = eval_path.joinpath(f"target_p{sidx+1:04}.sofa") + sofa_file.Data_IR = hrir[unseen_didxs, :] + sf.write_sofa(target_path, sofa_file) + + pred_path = eval_path.joinpath(f"pred_p{sidx+1:04}.sofa") + sofa_file.Data_IR = pred_hrir[unseen_didxs, :] + sf.write_sofa(pred_path, sofa_file) + + metrics = lap.calculate_task_two_metrics( + str(target_path), + str(pred_path), + )[0] + + logging.info(f"P{sidx+1:04} evaluation") + logging.info(f"ITD difference (µs): {metrics[0]}") + logging.info(f"ILD difference (dB): {metrics[1]}") + 
logging.info(f"LSD (dB): {metrics[2]}") + + if not args.save_results: + os.remove(target_path) + os.remove(pred_path) + + +if __name__ == "__main__": + main() diff --git a/ranf/evaluating_nearest_neighbor.py b/ranf/evaluating_nearest_neighbor.py new file mode 100644 index 0000000..171abbc --- /dev/null +++ b/ranf/evaluating_nearest_neighbor.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +import logging +import os +from pathlib import Path + +import numpy as np +import sofar as sf +from omegaconf import OmegaConf +from spatialaudiometrics import lap_challenge as lap +from spatialaudiometrics import load_data as ld + +from ranf.utils.config import TGTDIDXS003, TGTDIDXS005, TGTDIDXS019, TGTDIDXS100 +from ranf.utils.util import seed_everything, to_cartesian + + +def load_hrtf(fname): + hrtf = ld.HRTF(fname) + return hrtf.hrir, hrtf.locs + + +def search_nn(loc, measured_locs): + dist = np.linalg.norm(measured_locs - loc[None, :], axis=-1) + best_idx = np.argsort(dist)[0] + return best_idx + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str) + parser.add_argument("--save_results", action="store_true") + args = parser.parse_args() + + path = Path(args.config_path) + config = OmegaConf.load(path.joinpath("config.yaml")) + + seed_everything(config.seed) + + if config.dataset.upsample == 3: + seen_didxs = TGTDIDXS003 + + elif config.dataset.upsample == 5: + seen_didxs = TGTDIDXS005 + + elif config.dataset.upsample == 19: + seen_didxs = TGTDIDXS019 + + elif config.dataset.upsample == 100: + seen_didxs = TGTDIDXS100 + + else: + raise ValueError(f"dataset.upsample should be in (3, 5, 19, 100) but is {config.dataset.upsample}.") + + unseen_didxs = sorted(list(set(range(793)) - set(seen_didxs))) + + eval_path = path.joinpath("log").joinpath("eval") + eval_path.mkdir(parents=True, exist_ok=True) + log_name = 
eval_path.joinpath("eval.log") + logging.basicConfig(filename=log_name, level=logging.INFO) + + sonicom_path = Path(config.dataset.features).parent.parent.joinpath("subjects") + hrtf_type = "FreeFieldCompMinPhase_48kHz" + for sidx in config.dataset.test_subjects: + fname = sonicom_path.joinpath(f"P{sidx+1:04}_{hrtf_type}.sofa") + hrir, locs_deg = load_hrtf(fname) + + locs_cart = to_cartesian(np.deg2rad(locs_deg)) + pred_hrir = np.zeros_like(hrir[unseen_didxs, :]) + for idx, didx in enumerate(unseen_didxs): + best_idx = search_nn(locs_cart[didx, :], locs_cart[seen_didxs, :]) + pred_hrir[idx, :] = hrir[seen_didxs[best_idx], :] + + sofa_file = sf.read_sofa(fname) + sofa_file.SourcePosition = locs_deg[unseen_didxs, :] + sofa_file.MeasurementSourceAudioChannel = sofa_file.MeasurementSourceAudioChannel[unseen_didxs] + + target_path = eval_path.joinpath(f"target_p{sidx+1:04}.sofa") + sofa_file.Data_IR = hrir[unseen_didxs, :] + sf.write_sofa(target_path, sofa_file) + + pred_path = eval_path.joinpath(f"pred_p{sidx+1:04}.sofa") + sofa_file.Data_IR = pred_hrir.astype(np.float64) + sf.write_sofa(pred_path, sofa_file) + + metrics = lap.calculate_task_two_metrics( + str(target_path), + str(pred_path), + )[0] + + logging.info(f"P{sidx+1:04} evaluation") + logging.info(f"ITD difference (µs): {metrics[0]}") + logging.info(f"ILD difference (dB): {metrics[1]}") + logging.info(f"LSD (dB): {metrics[2]}") + + if not args.save_results: + os.remove(target_path) + os.remove(pred_path) + + +if __name__ == "__main__": + main() diff --git a/ranf/prepare_single_fold.py b/ranf/prepare_single_fold.py new file mode 100644 index 0000000..ed3bbe3 --- /dev/null +++ b/ranf/prepare_single_fold.py @@ -0,0 +1,64 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +from pathlib import Path + +import numpy as np +from omegaconf import OmegaConf + + +def main(): + parser = argparse.ArgumentParser() + 
parser.add_argument("dump_path", type=str) + parser.add_argument("conf_path", type=str) + parser.add_argument("exp_path", type=str) + parser.add_argument("sonicom_path", type=str) + parser.add_argument("upsample", type=int) + parser.add_argument("valid_size", type=int) + parser.add_argument("test_size", type=int) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--skip_78", action="store_true") + parser.add_argument("--calibrate_itdoffset", action="store_true") + args = parser.parse_args() + + conf_path = Path(args.conf_path) + config = OmegaConf.load(conf_path.joinpath("original_config.yaml")) + config.dataset.upsample = args.upsample + + config.dataset.retrieval = str(Path(args.dump_path).joinpath("lsd_itdd_mats.npz")) + + npz_path = Path(args.sonicom_path).joinpath("npzs") + if args.calibrate_itdoffset: + config.dataset.features = str(npz_path.joinpath("features_and_locs_with_azimuth_calibration.npz")) + else: + config.dataset.features = str(npz_path.joinpath("features_and_locs_wo_azimuth_calibration.npz")) + + npz = np.load(config.dataset.features) + nsubjects = npz["itds"].shape[0] + assert nsubjects == 200, "The number of subjects in the SONICOM dataset should be 200." 
+ + config.seed = 0 + train_size = nsubjects - (args.valid_size + args.test_size) + train_subjects = list(range(train_size)) + if args.skip_78: + train_subjects.remove(78) + config.dataset.train_subjects = train_subjects + + train_valid_size = train_size + args.valid_size + valid_subjects = list(range(train_size, train_valid_size)) + config.dataset.valid_subjects = valid_subjects + + test_subjects = list(range(train_valid_size, nsubjects)) + config.dataset.test_subjects = test_subjects + assert len(test_subjects) == args.test_size + + exp_path = Path(args.exp_path) + exp_path.mkdir(parents=True, exist_ok=True) + with open(exp_path.joinpath("config.yaml"), "w") as f: + OmegaConf.save(config=config, f=f) + + +if __name__ == "__main__": + main() diff --git a/ranf/summarize_evaluation_result.py b/ranf/summarize_evaluation_result.py new file mode 100644 index 0000000..a338666 --- /dev/null +++ b/ranf/summarize_evaluation_result.py @@ -0,0 +1,44 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import argparse +import re +from pathlib import Path + +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str) + args = parser.parse_args() + + eval_path = Path(args.config_path).joinpath("log").joinpath("eval") + with open(eval_path.joinpath("eval.log")) as f: + lines = f.readlines() + + results = {"ITD": [[], []], "ILD": [[], []], "LSD": [[], []]} + threshold = {"ITD": 62.5, "ILD": 4.4, "LSD": 7.4} + + for line in lines: + tmp = line.split(":")[2].split()[0] + if re.match(r"^P\d{4}$", tmp): + pidx = int(tmp[1:]) + + for key in results.keys(): + if key in line: + x = float(line.rstrip().split(":")[-1]) + results[key][0].append(x) + if x > threshold[key]: + results[key][1].append(pidx) + + for key in results.keys(): + print(key) + print(f"Mean: {np.mean(results[key][0])}") + print(f"Max: {np.max(results[key][0])}") + print(f"Subjects over 
threshold: {results[key][1]}") + + +if __name__ == "__main__": + main() diff --git a/ranf/utils/__init__.py b/ranf/utils/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/ranf/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/ranf/utils/config.py b/ranf/utils/config.py new file mode 100644 index 0000000..39dfbff --- /dev/null +++ b/ranf/utils/config.py @@ -0,0 +1,380 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import numpy as np + +FS = 48000 +FREQS = np.array( + [ + 0.0, + 187.5, + 375.0, + 562.5, + 750.0, + 937.5, + 1125.0, + 1312.5, + 1500.0, + 1687.5, + 1875.0, + 2062.5, + 2250.0, + 2437.5, + 2625.0, + 2812.5, + 3000.0, + 3187.5, + 3375.0, + 3562.5, + 3750.0, + 3937.5, + 4125.0, + 4312.5, + 4500.0, + 4687.5, + 4875.0, + 5062.5, + 5250.0, + 5437.5, + 5625.0, + 5812.5, + 6000.0, + 6187.5, + 6375.0, + 6562.5, + 6750.0, + 6937.5, + 7125.0, + 7312.5, + 7500.0, + 7687.5, + 7875.0, + 8062.5, + 8250.0, + 8437.5, + 8625.0, + 8812.5, + 9000.0, + 9187.5, + 9375.0, + 9562.5, + 9750.0, + 9937.5, + 10125.0, + 10312.5, + 10500.0, + 10687.5, + 10875.0, + 11062.5, + 11250.0, + 11437.5, + 11625.0, + 11812.5, + 12000.0, + 12187.5, + 12375.0, + 12562.5, + 12750.0, + 12937.5, + 13125.0, + 13312.5, + 13500.0, + 13687.5, + 13875.0, + 14062.5, + 14250.0, + 14437.5, + 14625.0, + 14812.5, + 15000.0, + 15187.5, + 15375.0, + 15562.5, + 15750.0, + 15937.5, + 16125.0, + 16312.5, + 16500.0, + 16687.5, + 16875.0, + 17062.5, + 17250.0, + 17437.5, + 17625.0, + 17812.5, + 18000.0, + 18187.5, + 18375.0, + 18562.5, + 18750.0, + 18937.5, + 19125.0, + 19312.5, + 19500.0, + 19687.5, + 19875.0, + 20062.5, + 20250.0, + 20437.5, + 20625.0, + 20812.5, + 21000.0, + 21187.5, + 21375.0, + 21562.5, + 21750.0, + 21937.5, + 22125.0, + 22312.5, + 22500.0, + 22687.5, + 22875.0, + 23062.5, + 
23250.0, + 23437.5, + 23625.0, + 23812.5, + 24000.0, + ] +) + +LSDFREQIDX = np.array( + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + ] +) + +TGTDIDXS003 = [4, 11, 414] + +TGTDIDXS005 = [0, 4, 8, 203, 612] + +TGTDIDXS019 = [ + 0, + 4, + 8, + 11, + 14, + 18, + 22, + 265, + 269, + 273, + 278, + 282, + 286, + 529, + 533, + 537, + 542, + 546, + 550, +] + +TGTDIDXS100 = [ + 0, + 8, + 19, + 25, + 33, + 38, + 50, + 57, + 65, + 67, + 75, + 84, + 92, + 103, + 117, + 122, + 130, + 134, + 142, + 149, + 159, + 168, + 176, + 184, + 195, + 201, + 209, + 214, + 226, + 233, + 241, + 243, + 251, + 260, + 268, + 279, + 293, + 298, + 306, + 310, + 318, + 325, + 335, + 344, + 352, + 360, + 371, + 377, + 385, + 390, + 402, + 409, + 417, + 419, + 427, + 436, + 444, + 455, + 469, + 474, + 482, + 486, + 494, + 501, + 511, + 520, + 528, + 536, + 547, + 553, + 561, + 566, + 578, + 585, + 593, + 595, + 603, + 612, + 620, + 631, + 645, + 650, + 658, + 662, + 670, + 677, + 687, + 696, + 704, + 712, + 723, + 729, + 737, + 742, + 754, + 761, + 769, + 771, + 779, + 788, +] diff --git a/ranf/utils/loss_functions.py b/ranf/utils/loss_functions.py new file mode 100644 index 0000000..27aa12c --- /dev/null +++ b/ranf/utils/loss_functions.py @@ -0,0 +1,39 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import torch + +from ranf.utils.config 
import LSDFREQIDX + + +def lsd_loss(target, pred, dim=-1, use_index=False): + if use_index: + freqidx = torch.tensor(LSDFREQIDX, device=target.device) + else: + freqidx = torch.arange(target.shape[-1], device=target.device) + + mse = torch.mean(torch.square(target - pred)[:, :, freqidx], dim=dim) + rmse = torch.sqrt(mse) + return torch.mean(rmse, -1) + + +def itd_diff_loss(target, pred, sr=48000, threshold=0.0): + error = torch.abs(target - pred) + retval = torch.clamp(error, min=threshold) + return (1.0e6 / sr) * retval + + +def ild_diff_loss(target, pred, target_ild=None): + if target_ild is None: + target_ild = compute_ild(target) + + return torch.abs(target_ild - compute_ild(pred)) + + +def compute_ild(hrtf): + hrtf = torch.cat([hrtf, hrtf[..., 1:-1]], dim=-1) + rms = torch.linalg.norm(hrtf, dim=-1) + logrms = 20 * torch.log10(rms) + ild = logrms[:, 0] - logrms[:, 1] + return ild diff --git a/ranf/utils/neural_field_icassp.py b/ranf/utils/neural_field_icassp.py new file mode 100644 index 0000000..6e9d0dc --- /dev/null +++ b/ranf/utils/neural_field_icassp.py @@ -0,0 +1,588 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +class MLP(nn.Module): + def __init__(self, inch, outch, dropout=0.0, activation="GELU", bias=True): + super().__init__() + self.fc = nn.Linear(inch, outch, bias=bias) + self.activatin = getattr(nn, activation)() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.activatin(self.fc(x)) + x = self.dropout(x) + return x + + +class LoRAMLP(nn.Module): + def __init__(self, inch, outch, dropout, activation="GELU", bias=True): + super().__init__() + self.fc = nn.Linear(inch, outch, bias=bias) + self.activatin = getattr(nn, activation)() + self.dropout = nn.Dropout(dropout) + + def forward(self, x, u=0.0, v=0.0, b=0.0): + """Forward + Args: + x (torch.Tensor): 
Feature tensor of [batch, inch] + u (torch.Tensor | float, optional): Vector to construct rank-1 matrix for LoRA [batch, inch] or 0.0 + v (torch.Tensor | float, optional): Vector to construct rank-1 matrix for LoRA [batch, outch] or 0.0 + b (torch.Tensor | float, optional): Additional bias for BitFit [batch, outch] or 0.0 + + Returns: + x (torch.Tensor) + """ + z = u * torch.mean(v * x, -1, keepdim=True) + x = self.fc(x) + z + b + x = self.dropout(self.activatin(x)) + return x + + +class LoRAMLP4RANF(nn.Module): + def __init__(self, inch, outch, dropout, activation="GELU", bias=True): + super().__init__() + self.fc = nn.Linear(inch, outch, bias=bias) + self.activatin = getattr(nn, activation)() + self.dropout = nn.Dropout(dropout) + + def forward(self, x, u=0.0, v=0.0, b=0.0): + """Forward + Args: + x (torch.Tensor): Feature tensor of [batch, freqs (downsampled), inch] + u (torch.Tensor | float, optional): Vector to construct rank-R matrix for LoRA [batch, inch, R] or 0.0 + v (torch.Tensor | float, optional): Vector to construct rank-R matrix for LoRA [batch, outch, R] or 0.0 + b (torch.Tensor | float, optional): Additional bias for BitFit [batch, outch] or 0.0 + + Returns: + x (torch.Tensor) + """ + z = torch.mean(v * x[..., None], -2) + z = torch.mean(u * z[..., None, :], -1) + x = self.fc(x) + z + b + x = self.dropout(self.activatin(x)) + return x + + +class LoRATACMLP4(nn.Module): + def __init__(self, inch, outch, dropout, activation="GELU", bias=True): + super().__init__() + self.in_linear_pass = nn.Linear(inch, inch // 2) + self.in_linear_ave = nn.Linear(inch, inch // 2) + self.in_activatin = getattr(nn, activation)() + self.in_dropout = nn.Dropout(dropout) + + self.out_linear_lora = nn.Linear(inch // 2 * 2, outch, bias=bias) + self.out_activatin = getattr(nn, activation)() + self.out_dropout = nn.Dropout(dropout) + + def forward(self, x, u=0, v=0, b=0): + """Forward + Args: + x (torch.Tensor): feature tensor of [batch, K, freqs (downsampled), inch] where K 
is # of retrievals + u (torch.Tensor | float, optional): vector to construct rank-R matrix for LoRA [batch, K, inch, R] or 0.0 + v (torch.Tensor | float, optional): vector to construct rank-R matrix for LoRA [batch, K, outch, R] or 0.0 + b (torch.Tensor | float, optional): additional bias for BitFit [batch, outch] or 0.0 + + Returns: + x (torch.Tensor) + """ + K = x.shape[1] + + # Modified transform-average-concatenation (TAC) + y = self.in_linear_ave(x) + y = torch.tile(torch.mean(y, dim=1, keepdim=True), (1, K, 1, 1)) + x = self.in_linear_pass(x) + x = torch.cat([x, y], dim=-1) + x = self.in_dropout(self.in_activatin(x)) + + # MLP with LoRA + z = torch.mean(v[..., None, :, :] * x[..., None], -2) + z = torch.mean(u[..., None, :, :] * z[..., None, :], -1) + x = self.out_linear_lora(x) + z + b + x = self.out_dropout(self.out_activatin(x)) + return x + + +class LSTMLoRATAC(nn.Module): + def __init__( + self, + hidden_features, + dropout=0.0, + rnn="LSTM", + activation="GELU", + norm="LayerNorm", + ): + super().__init__() + self.freq_blstm = getattr(nn, rnn)( + hidden_features, hidden_features // 2, bidirectional=True, num_layers=1, batch_first=True + ) + self.lora_mlp = LoRATACMLP4(hidden_features, hidden_features, dropout=dropout, bias=True, activation=activation) + self.norm = getattr(nn, norm)(hidden_features) + + def forward(self, x, u=0.0, v=0.0, b=0.0): + """Forward + Args: + x (torch.Tensor): feature tensor of [batch, K, freqs (downsampled), inch] where K is # of retrievals + u (torch.Tensor | float, optional): vector to construct rank-R matrix for LoRA [batch, K, inch, R] or 0.0 + v (torch.Tensor | float, optional): vector to construct rank-R matrix for LoRA [batch, K, outch, R] or 0.0 + b (torch.Tensor | float, optional): additional bias for BitFit [batch, outch] or 0.0 + + Returns: + x (torch.Tensor) + """ + batch, nretrieval, _, hidden_features = x.shape + x = x.reshape(batch * nretrieval, -1, hidden_features) + x = self.freq_blstm(x)[0] + x = 
x.reshape(batch, nretrieval, -1, hidden_features) + x = self.lora_mlp(x, u=u, v=v, b=b) + return self.norm(x) + + +class PEFTNeuralField(nn.Module): + def __init__( + self, + hidden_features=128, + hidden_layers=1, + out_features=258, + scale=1, + dropout=0.1, + n_listeners=210, + activation="GELU", + peft="lora", + itd_skip_connection=False, + ): + super().__init__() + + assert hidden_features % 2 == 0 + assert peft in {"lora", "bitfit"} + + self.hidden_layers = hidden_layers + self.hidden_features = hidden_features + self.out_features = out_features + self.itd_skip_connection = itd_skip_connection + self.peft = peft + + # For random Fourier feature + self.rng = np.random.default_rng(0) + bmat = scale * self.rng.normal(0.0, 1.0, (hidden_features // 2, 4)) + self.bmat = torch.nn.Parameter(torch.tensor(bmat.astype(np.float32)), requires_grad=False) + + # For MLP + self.hidden_loramlps = nn.ModuleList( + [ + LoRAMLP(hidden_features, hidden_features, dropout=dropout, bias=peft != "bitfit", activation=activation) + for _ in range(hidden_layers) + ] + ) + self.out_linear = LoRAMLP( + hidden_features, out_features, dropout=0.0, bias=peft != "bitfit", activation="Identity" + ) + + self.itd_net = nn.Sequential( + MLP(hidden_features, hidden_features // 2, dropout), MLP(hidden_features // 2, 1, activation="Identity") + ) + + # For PEFT + self.n_listeners = n_listeners + + if peft == "bitfit": + bitfit = [] + for _ in range(hidden_layers): + bitfit.append(nn.Linear(n_listeners, hidden_features, bias=False)) + + bitfit.append(nn.Linear(n_listeners, out_features, bias=False)) + self.embed_layer_bitfit = nn.ModuleList(bitfit) + + if peft == "lora": + lorau, lorav = [], [] + for _ in range(hidden_layers): + lorau.append(nn.Linear(n_listeners, hidden_features, bias=False)) + lorav.append(nn.Linear(n_listeners, hidden_features, bias=False)) + nn.init.zeros_(lorau[-1].weight) + + lorau.append(nn.Linear(n_listeners, out_features, bias=False)) + lorav.append(nn.Linear(n_listeners, 
hidden_features, bias=False)) + nn.init.zeros_(lorau[-1].weight) + + self.embed_layer_lorau = nn.ModuleList(lorau) + self.embed_layer_lorav = nn.ModuleList(lorav) + + def forward(self, ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs): + """Forward + + Args: + ret_specs_db: Unused in neural field without retrievals + ret_itds: Unused in neural field without retrievals + tgt_loc (torch.Tensor): Sound source location (azimuth, elevation, distance) in [batch, 3] + tgt_sidx (torch.Tensor): Indices of the target subject in integer + ret_sidxs: Unused in neural field without retrievals + + Returns: + estimate (torch.Tensor): Estimated magnitude in [batch, 2, nfreqs] + itd (torch.Tensor): Estimated ITD in [batch, 1] + """ + batch = tgt_loc.shape[0] + azimuth, elevation = tgt_loc[:, 0], tgt_loc[:, 1] + sidxs = tgt_sidx + + onehot = F.one_hot(sidxs, self.n_listeners).type(torch.float32) + + emb = [azimuth.sin(), azimuth.cos(), elevation.sin(), elevation.cos()] + emb = torch.stack(emb, -1) @ self.bmat.T + emb = torch.concatenate([emb.sin(), emb.cos()], axis=-1) + + x = emb + for n in range(self.hidden_layers): + x = self.hidden_loramlps[n]( + x, + u=self.embed_layer_lorau[n](onehot) if self.peft == "lora" else 0, + v=self.embed_layer_lorav[n](onehot) if self.peft == "lora" else 0, + b=self.embed_layer_bitfit[n](onehot) if self.peft == "bitfit" else 0, + ) + + estimate = self.out_linear( + x, + u=self.embed_layer_lorau[-1](onehot) if self.peft == "lora" else 0, + v=self.embed_layer_lorav[-1](onehot) if self.peft == "lora" else 0, + b=self.embed_layer_bitfit[-1](onehot) if self.peft == "bitfit" else 0, + ) + estimate = estimate.reshape(batch, 2, -1) + + if self.itd_skip_connection: + x = x + emb + + itd = self.itd_net(x) + return estimate, itd + + +class CbCNeuralField(nn.Module): + def __init__( + self, + hidden_features=128, + embed_features=32, + hidden_layers=1, + out_features=258, + scale=1, + dropout=0.1, + n_listeners=210, + activation="GELU", + 
itd_skip_connection=False, + ): + super().__init__() + + assert hidden_features % 2 == 0 + + self.hidden_layers = hidden_layers + self.hidden_features = hidden_features + self.itd_skip_connection = itd_skip_connection + + # For random Fourier feature + self.rng = np.random.default_rng(0) + bmat = scale * self.rng.normal(0.0, 1.0, (hidden_features // 2, 4)) + self.bmat = torch.nn.Parameter(torch.tensor(bmat.astype(np.float32)), requires_grad=False) + + # For MLP + self.n_listeners = n_listeners + + # Latent vector for PEFT + self.embed_layer_lorau = nn.Linear( + n_listeners, + embed_features, + bias=False, + ) + self.hidden_mlps = nn.ModuleList( + [ + MLP( + hidden_features + embed_features if n == 0 else hidden_features, + hidden_features, + dropout=dropout, + bias=True, + activation=activation, + ) + for n in range(hidden_layers) + ] + ) + self.out_linear = MLP(hidden_features, out_features, dropout=0.0, bias=True, activation="Identity") + + self.itd_net = nn.Sequential( + MLP(hidden_features, hidden_features // 2, dropout), MLP(hidden_features // 2, 1, activation="Identity") + ) + + def forward(self, ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs): + """Forward + + Args: + ret_specs_db: Unused in neural field without retrievals + ret_itds: Unused in neural field without retrievals + tgt_loc (torch.Tensor): Sound source location (azimuth, elevation, distance) in [batch, 3] + tgt_sidx (torch.Tensor): Indices of the target subject in integer + ret_sidxs: Unused in neural field without retrievals + + Returns: + estimate (torch.Tensor): Estimated magnitude in [batch, 2, nfreqs] + itd (torch.Tensor): Estimated ITD in [batch, 1] + """ + azimuth, elevation = tgt_loc[:, 0], tgt_loc[:, 1] + sidxs = tgt_sidx + + batch = azimuth.shape[0] + onehot = F.one_hot(sidxs, self.n_listeners).type(torch.float32) + listener_emb = self.embed_layer_lorau(onehot) + + emb = [azimuth.sin(), azimuth.cos(), elevation.sin(), elevation.cos()] + emb = torch.stack(emb, -1) @ self.bmat.T + 
emb = torch.concatenate([emb.sin(), emb.cos()], axis=-1) + + x = torch.concatenate([emb, listener_emb], axis=-1) + for n in range(self.hidden_layers): + x = self.hidden_mlps[n](x) + + estimate = self.out_linear(x) + estimate = estimate.reshape(batch, 2, -1) + + if self.itd_skip_connection: + x = x + emb + + itd = self.itd_net(x) + return estimate, itd + + +class RANF(nn.Module): + def __init__( + self, + hidden_features=128, + hidden_layers=1, + spec_hidden_layers=1, + itd_hidden_layers=1, + conv_layers=3, + scale=1, + dropout=0.1, + n_listeners=200, + rnn="LSTM", + activation="GELU", + norm="LayerNorm", + peft="lora", + lora_rank=1, + itd_scale=np.pi / 45, + conv_in=2, + emb_in=3, + spec_res=False, + itd_res=False, + itd_activation="Identity", + itd_skip_connection=True, + **kwargs, + ): + super().__init__() + + assert hidden_features % 2 == 0 + assert conv_layers > 2 + assert peft == "lora", "Our current implementation supports only LoRA" + assert lora_rank == 1, "Our current implementation supports the rank-1 case of LoRA" + + self.hidden_layers = hidden_layers + self.spec_hidden_layers = spec_hidden_layers + self.itd_hidden_layers = itd_hidden_layers + self.hidden_features = hidden_features + self.n_listeners = n_listeners + self.spec_res = spec_res + self.itd_res = itd_res + self.itd_skip_connection = itd_skip_connection + + # For random Fourier feature + self.itd_scale = itd_scale + self.rng = np.random.default_rng(0) + + bmat = scale * self.rng.normal(0.0, 1.0, (hidden_features // 2, emb_in * 2)) + self.bmat = torch.nn.Parameter(torch.tensor(bmat.astype(np.float32)), requires_grad=False) + + loc_bmat = scale * self.rng.normal(0.0, 1.0, (hidden_features // 2, 4)) + self.loc_bmat = torch.nn.Parameter(torch.tensor(loc_bmat.astype(np.float32)), requires_grad=False) + + self.emb_mlp = MLP(hidden_features, hidden_features * 2, dropout, activation=activation) + self.itd_net = nn.Sequential( + MLP(hidden_features, hidden_features // 2, dropout, 
activation=activation), + MLP(hidden_features // 2, 1, dropout=0.0, activation=itd_activation), + ) + + # Spec input and output layers + layer = [ + nn.Conv1d(conv_in, hidden_features // 2, 3, stride=1, padding=1), + nn.PReLU(), + ] + for _ in range(conv_layers - 2): + layer += [ + nn.Conv1d(hidden_features // 2, hidden_features // 2, 5, stride=2, padding=2), + nn.PReLU(), + ] + layer += [ + nn.Conv1d(hidden_features // 2, hidden_features, 5, stride=2, padding=2), + nn.PReLU(), + ] + self.spec_enc = nn.Sequential(*layer) + + layer = [ + nn.ConvTranspose1d(hidden_features, hidden_features // 2, 6, stride=2, padding=2), + nn.PReLU(), + ] + for _ in range(conv_layers - 2): + layer += [ + nn.ConvTranspose1d(hidden_features // 2, hidden_features // 2, 6, stride=2, padding=2), + nn.PReLU(), + ] + layer += [ + nn.ConvTranspose1d(hidden_features // 2, 2, 3, stride=1, padding=1), + nn.PReLU(), + ] + self.spec_dec = nn.Sequential(*layer) + + self.hidden_blocks = nn.ModuleList( + [ + LSTMLoRATAC(hidden_features, dropout=dropout, rnn=rnn, activation=activation, norm=norm) + for _ in range(hidden_layers) + ] + ) + + self.spec_hidden_blocks = nn.ModuleList( + [ + LoRAMLP4RANF( + hidden_features, + hidden_features, + dropout=dropout, + bias=True, + activation=activation, + ) + for _ in range(spec_hidden_layers) + ] + ) + self.itd_hidden_blocks = nn.ModuleList( + [ + LoRAMLP4RANF( + hidden_features, + hidden_features, + dropout=dropout, + bias=True, + activation=activation, + ) + for n in range(itd_hidden_layers) + ] + ) + + # For PEFT + lorau, lorav = [], [] + + for _ in range(hidden_layers + spec_hidden_layers + itd_hidden_layers): + lorau.append(nn.Linear(n_listeners, hidden_features * lora_rank, bias=False)) + lorav.append(nn.Linear(n_listeners, hidden_features * lora_rank, bias=False)) + nn.init.zeros_(lorau[-1].weight) + + self.embed_layer_lorau = nn.ModuleList(lorau) + self.embed_layer_lorav = nn.ModuleList(lorav) + + def _compute_uv(self, tgt_onehot, ret_onehot, idx, 
mode): + if mode == "target": + a = self.embed_layer_lorau[idx](tgt_onehot) + b = self.embed_layer_lorav[idx](tgt_onehot) + + elif mode == "alter": + a = self.embed_layer_lorau[idx](tgt_onehot) + b = self.embed_layer_lorav[idx](ret_onehot) + + else: + raise ValueError(f"Invalid mode: {mode}") + + if a.shape[-1] < self.hidden_features: + u = torch.stack(a.split(1, -1), -1) + else: + u = torch.stack(a.split(self.hidden_features, -1), -1) + + if b.shape[-1] < self.hidden_features: + v = torch.stack(b.split(1, -1), -1) + else: + v = torch.stack(b.split(self.hidden_features, -1), -1) + return u, v + + def forward(self, ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs): + """Forward + + Args: + ret_specs_db (torch.Tensor): Retrieved log magnitude spectra in [batch, K, 2, nfreqs] + ret_itds (torch.Tensor): Retrieved log magnitude spectra in [batch, K] + tgt_loc (torch.Tensor): Sound source location (azimuth, elevation, distance) in [batch, 3] + tgt_sidx (torch.Tensor): Indices of the target subject in integer + ret_sidxs (torch.Tensor): Indices of the retrieved subject + + Returns: + estimate (torch.Tensor): Estimated magnitude in [batch, 2, nfreqs] + itd (torch.Tensor): Estimated ITD in [batch, 1] + """ + batch, nretrieval, nch, _ = ret_specs_db.shape + + tgt_onehot = F.one_hot(tgt_sidx[:, None], self.n_listeners) + tgt_onehot = tgt_onehot.tile(1, nretrieval, 1).type(torch.float32) + ret_onehot = F.one_hot(ret_sidxs, self.n_listeners).type(torch.float32) + + tgt_azimuth = tgt_loc[:, :1].tile(1, nretrieval) + tgt_elevation = tgt_loc[:, 1:2].tile(1, nretrieval) + ret_itds_scaled = self.itd_scale * ret_itds + + emb = torch.stack([tgt_azimuth, tgt_elevation, ret_itds_scaled], -1) + ave_itd = torch.mean(ret_itds, 1)[:, None] + ave_spec = torch.mean(ret_specs_db, 1) + + emb = torch.cat([emb.sin(), emb.cos()], -1) + emb = torch.einsum("bkn,mn->bkm", emb, self.bmat) + emb = torch.concatenate([emb.sin(), emb.cos()], axis=-1) + x_embs = torch.split(self.emb_mlp(emb), 
self.hidden_features, dim=-1) + + x_spec = ret_specs_db[..., :-1].reshape(batch * nretrieval, nch, -1) + x_spec = self.spec_enc(x_spec).reshape(batch, nretrieval, self.hidden_features, -1) + x = torch.concatenate([x_embs[0][..., None], x_spec, x_embs[1][..., None]], -1).transpose(-1, -2) + + # Core-processing + for n in range(self.hidden_layers): + u, v = self._compute_uv(tgt_onehot, ret_onehot, n, "alter") + x = x + self.hidden_blocks[n](x, u=u, v=v) + + x = torch.mean(x, dim=1, keepdims=False) + x_itd1, x_itd2, x_spec = x[:, :1, :], x[:, -1:, :], x[:, 1:-1, :] + + # Spectra post-processing + for n in range(self.spec_hidden_layers): + u, v = self._compute_uv(tgt_onehot[:, :1, :], None, n + self.hidden_layers, "target") + x_spec = self.spec_hidden_blocks[n](x_spec, u=u, v=v) + + estimate = self.spec_dec(x_spec.transpose(-1, -2)) + estimate = F.pad(estimate, (0, 1), mode="replicate") + + if self.spec_res: + estimate = estimate + ave_spec + + # ITD post-processing + x_itd = x_itd1 + x_itd2 + for n in range(self.itd_hidden_layers): + m = n + self.hidden_layers + self.spec_hidden_layers + u, v = self._compute_uv(tgt_onehot[:, :1, :], None, m, "target") + x_itd = self.itd_hidden_blocks[n](x_itd, u=u, v=v) + + # Location-related RFF without retrieved ITDs + if self.itd_skip_connection: + locs = [tgt_loc[:, 0].sin(), tgt_loc[:, 0].cos(), tgt_loc[:, 1].sin(), tgt_loc[:, 1].cos()] + loc_emb = torch.stack(locs, -1) @ self.loc_bmat.T + loc_emb = torch.concatenate([loc_emb.sin(), loc_emb.cos()], axis=-1) + x_itd = x_itd[:, 0, :] + loc_emb + else: + x_itd = x_itd[:, 0, :] + + itd = self.itd_net(x_itd) + + if self.itd_res: + itd = itd + ave_itd + + return estimate, itd diff --git a/ranf/utils/reconstruction.py b/ranf/utils/reconstruction.py new file mode 100644 index 0000000..0496bd4 --- /dev/null +++ b/ranf/utils/reconstruction.py @@ -0,0 +1,68 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import 
numpy as np +import scipy.signal as sn +from scipy.fft import ifft +from spatialaudiometrics.hrtf_metrics import itd_estimator_maxiacce + + +def calculate_itd(hrir, fs=48000, upper_cut_freq=3000, filter_order=10): + _, itd_samps, _ = itd_estimator_maxiacce( + hrir[None, ...], np.array(fs), upper_cut_freq=upper_cut_freq, filter_order=filter_order + ) + return np.array(itd_samps)[0, ...] + + +def hrtf2hrir_minph(mag, itd, nfft=None, fs=48000, itd_search_width=5): + """Reconstruct time-domain HRIRs from the predicted magnitude spectra and ITDs + + We first compute HRIRs without ITD with the minimum phase. + The reconstructed HRIRs may contain time offsets, and their ITDs could differ from zero. + The current implementation compensates for the offset by a naive grid search for each sound source direction. + + Args: + mag (np.array): Predicted magnitude in [ndirection, 2, nfreqs] + itd (np.array): Predicted ITD in [ndirection, 1] + nfft (int, optional): FFT points. If None, it will be calcualted from mag. 
+ fs (int, optional): Sampling rate + itd_search_width (int, optional): The width for the grid search + + Returns: + hrir (np.array): reconstructed HRIRs in [ndirection, 2, nfft] + """ + _, _, nfreqs = mag.shape + + if nfft is None: + nfft = 2 * (nfreqs - 1) + + mag = mag.astype(np.float64) + mag = np.concatenate([mag, np.flip(mag[..., 1:-1], -1)], -1) + ph = -np.imag(sn.hilbert(np.log(mag), axis=-1)) + hrir = ifft(mag * np.exp(1j * ph), n=nfft, axis=-1).real + + for d, itd_pred in enumerate(itd[:, 0]): + itd_candidates, itd_errors = [], [] + for shift in range(-itd_search_width, itd_search_width + 1): + + itd_candidate = int(np.round(itd_pred + shift)) + itd_candidates.append(itd_candidate) + if itd_candidate < 0: + hl = hrir[d, 0, :] + hr = np.roll(hrir[d, 1, :], -itd_candidate) + else: + hl = np.roll(hrir[d, 0, :], itd_candidate) + hr = hrir[d, 1, :] + + itd_errors.append(calculate_itd(np.stack([hl, hr], 0), fs=fs) - itd_pred) + + itd_optimal = itd_candidates[np.argmin(np.abs(np.array(itd_errors)))] + + if itd_optimal < 0: + hrir[d, 1, :] = np.roll(hrir[d, 1, :], -itd_optimal) + + else: + hrir[d, 0, :] = np.roll(hrir[d, 0, :], itd_optimal) + + return hrir diff --git a/ranf/utils/sonicom_dataset_retrieval.py b/ranf/utils/sonicom_dataset_retrieval.py new file mode 100644 index 0000000..5e92f56 --- /dev/null +++ b/ranf/utils/sonicom_dataset_retrieval.py @@ -0,0 +1,201 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +from copy import deepcopy +from pathlib import Path + +import numpy as np +import sofar as sf +import torch + +from ranf.utils.config import TGTDIDXS003, TGTDIDXS005, TGTDIDXS019, TGTDIDXS100 + + +class SONICOMMulti(torch.utils.data.Dataset): + def __init__(self, config, stage="pretrain", mode="train"): + self.nretrieval = config.nretrieval + self.mode = mode + + upsample = config.upsample + assert upsample in {3, 5, 19, 100} + + if upsample == 3: + seen_didxs = TGTDIDXS003 + + 
elif upsample == 5: + seen_didxs = TGTDIDXS005 + + elif upsample == 19: + seen_didxs = TGTDIDXS019 + + elif upsample == 100: + seen_didxs = TGTDIDXS100 + else: + raise ValueError(f"config.upsample should be in (3, 5, 19, 100) but is {config.upsample}.") + + self.seen_didxs = seen_didxs + self.unseen_didxs = sorted(list(set(range(793)) - set(seen_didxs))) + + npz = np.load(config.features) + self.specs = npz["specs"] + self.ilds = npz["ilds"] + self.itds = npz["itds"] + self.locs = np.deg2rad(npz["locs"]) + self.locs[:, :, 0] -= np.pi + + npz = np.load(config.retrieval) + lsd_mat = npz["lsd_mat"] + itdd_mat = npz["itdd_mat"] + + self.retrieved_subjects = [] + for sidx in range(lsd_mat.shape[0]): + sidxs = [] + if config.retrieval_priority == "itdd": + for itdd in sorted(list(set(itdd_mat[sidx, :]))): + _sidxs = np.where(itdd_mat[sidx, :] == itdd)[0] + order = np.argsort(lsd_mat[sidx, _sidxs]) + sidxs += _sidxs[order].tolist() + + elif config.retrieval_priority == "lsd": + for lsd in sorted(list(set(lsd_mat[sidx, :]))): + _sidxs = np.where(lsd_mat[sidx, :] == lsd)[0] + sidxs += _sidxs.tolist() + + else: + raise NameError(f"{config.retrieval_priority} is not supported") + + self.retrieved_subjects.append(sidxs[: config.npool]) + + self.sidxs, self.didxs = [], [] + if stage == "pretrain" and mode == "train": + for sidx in config.train_subjects: + for didx in range(793): + self.sidxs.append(sidx) + self.didxs.append(didx) + + for sidx in config.valid_subjects: + for didx in self.seen_didxs: + self.sidxs.append(sidx) + self.didxs.append(didx) + + elif stage == "pretrain" and mode == "valid": + for sidx in config.valid_subjects: + for didx in self.unseen_didxs: + self.sidxs.append(sidx) + self.didxs.append(didx) + + elif stage == "adaptation": + for sidx in config.test_subjects: + for didx in self.seen_didxs: + self.sidxs.append(sidx) + self.didxs.append(didx) + + def __len__(self): + return len(self.sidxs) + + def __getitem__(self, idx): + tgt_sidx, tgt_didx = 
self.sidxs[idx], self.didxs[idx] + tgt_spec = self.specs[tgt_sidx, tgt_didx, :, :] + tgt_ild = self.ilds[tgt_sidx, tgt_didx] + tgt_itd = self.itds[tgt_sidx, tgt_didx].astype(np.float32) + tgt_loc = self.locs[tgt_sidx, tgt_didx, :] + + if self.mode == "train": + rng = np.random.default_rng(idx) + ret_sidxs = rng.choice(self.retrieved_subjects[tgt_sidx], self.nretrieval) + else: + ret_sidxs = np.array(self.retrieved_subjects[tgt_sidx][: self.nretrieval]) + + ret_specs = self.specs[ret_sidxs, tgt_didx, :, :] + ret_itds = self.itds[ret_sidxs, tgt_didx].astype(np.float32) + ret_locs = self.locs[ret_sidxs, tgt_didx, :] + + return tgt_spec, tgt_ild, tgt_itd, tgt_loc, ret_specs, ret_itds, ret_locs, tgt_sidx, ret_sidxs + + +class SONICOMMultiInference(torch.utils.data.Dataset): + def __init__(self, config): + if hasattr(config, "inference_sampling"): + self.sampling = config.inference_sampling + else: + self.sampling = False + + self.fs = 48000 + self.hrtf_type = "FreeFieldCompMinPhase_48kHz" + self.upsample = config.upsample + self.nretrieval = config.nretrieval + self.sonicom_path = Path(config.features).parent.parent.joinpath("subjects") + + npz = np.load(config.features) + self.specs = npz["specs"] + self.ilds = npz["ilds"] + self.itds = npz["itds"] + self.locs = np.deg2rad(npz["locs"]) + self.locs[:, :, 0] -= np.pi + + npz = np.load(config.retrieval) + lsd_mat = npz["lsd_mat"] + itdd_mat = npz["itdd_mat"] + + self.retrieved_subjects = [] + for sidx in range(lsd_mat.shape[0]): + sidxs = [] + if config.retrieval_priority == "itdd": + for itdd in sorted(list(set(itdd_mat[sidx, :]))): + _sidxs = np.where(itdd_mat[sidx, :] == itdd)[0] + order = np.argsort(lsd_mat[sidx, _sidxs]) + sidxs += _sidxs[order].tolist() + + elif config.retrieval_priority == "lsd": + for lsd in sorted(list(set(lsd_mat[sidx, :]))): + _sidxs = np.where(lsd_mat[sidx, :] == lsd)[0] + sidxs += _sidxs.tolist() + + else: + raise NameError(f"{config.retrieval_priority} is not supported") + + 
self.retrieved_subjects.append(sidxs[: config.npool]) + + self.fnames, self.sidxs = [], [] + for sidx in config.test_subjects: + if sidx < 200: + pidx = sidx + 1 + fname = self.sonicom_path.joinpath(f"P{pidx:04}_{self.hrtf_type}.sofa") + self.fnames.append(fname) + self.sidxs.append(sidx) + else: + fname = ( + Path(config.features).parent.parent.joinpath("lap-task2-upsampled").joinpath(f"{self.upsample:003}") + ) + fname = fname.joinpath(f"LAPtask2_{config.upsample}_{sidx+1-200}.sofa") + self.fnames.append(fname) + self.sidxs.append(sidx) + + def __len__(self): + return len(self.fnames) + + def __getitem__(self, idx): + tgt_sidx = self.sidxs[idx] + tgt_sofa_file = sf.read_sofa(self.fnames[idx]) + tgt_hrir = tgt_sofa_file.Data_IR.astype(np.float32) + + loc = deepcopy(tgt_sofa_file.SourcePosition) + + loc = np.deg2rad(loc).astype(np.float32) + loc[:, 0] -= np.pi + retval = [tgt_sofa_file, tgt_hrir, loc] + + if self.sampling: + rng = np.random.default_rng(tgt_sidx) + ret_sidxs = rng.choice(self.retrieved_subjects[tgt_sidx], self.nretrieval) + else: + ret_sidxs = np.array(self.retrieved_subjects[tgt_sidx][: self.nretrieval]) + + ret_specs = self.specs[ret_sidxs, :, :, :].transpose(1, 0, 2, 3) + ret_itds = self.itds[ret_sidxs, :].transpose(1, 0).astype(np.float32) + ret_locs = self.locs[ret_sidxs, :, :].transpose(1, 0, 2) + + retval += [ret_specs, ret_itds, ret_locs] + retval += [np.tile(tgt_sidx, (793)), np.tile(ret_sidxs, (793, 1))] + return retval diff --git a/ranf/utils/util.py b/ranf/utils/util.py new file mode 100644 index 0000000..b178999 --- /dev/null +++ b/ranf/utils/util.py @@ -0,0 +1,80 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import os +import random + +import matplotlib.pyplot as plt +import numpy as np +import torch +from scipy.fft import fft +from spatialaudiometrics import hrtf_metrics as hf +from spatialaudiometrics import load_data as ld + + +def seed_everything(seed): 
+ random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def linear2db(spec, eps=1.0e-5): + return 20 * torch.log10(spec + eps) + + +def db2linear(spec): + return 10.0 ** (spec / 20.0) + + +def plot_hrtf(fname, target, pred, lsd): + target = target.detach().to("cpu") + pred = pred.detach().to("cpu") + + plt.subplot(1, 2, 1) + plt.plot(target[0, 0, :]) + plt.plot(pred[0, 0, :]) + + plt.subplot(1, 2, 2) + plt.plot(target[0, 1, :]) + plt.plot(pred[0, 1, :]) + + plt.title(f"LSD: {lsd} dB") + plt.savefig(fname, format="png", dpi=300) + plt.clf() + plt.close() + + +def extract_features(path): + hrtf = ld.HRTF(path) + spec = np.abs(fft(hrtf.hrir, axis=-1))[..., : hrtf.hrir.shape[-1] // 2 + 1] + ild = hf.ild_estimator_rms(hrtf.hrir) + _, itd, _ = hf.itd_estimator_maxiacce(hrtf.hrir, hrtf.fs) + loc = hrtf.locs + return spec, ild, itd, loc + + +def count_parameters(model): + params = [] + for param in model.parameters(): + if param.requires_grad: + params.append(param.numel()) + return sum(params) + + +def to_cartesian(x): + if x.ndim == 1: + x = x[None, :] + ndim = 1 + else: + ndim = 2 + + y = np.stack([np.cos(x[:, 0]) * np.cos(x[:, 1]), np.sin(x[:, 0]) * np.cos(x[:, 1]), np.sin(x[:, 1])], -1) + y *= x[:, 2, None] + + if ndim == 1: + return y[0, :] + else: + return y diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..c8b4a2c --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +pre-commit +black>=22 +flake8 +pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bb17e40 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +matplotlib==3.8.3 +numpy==1.26.4 
+omegaconf==2.3.0 +pytest==8.3.3 +scipy==1.14.1 +sofar==1.1.3 +spatialaudiometrics==0.0.8 +torch==2.2.2 +tqdm==4.65.0 diff --git a/run_example.sh b/run_example.sh new file mode 100755 index 0000000..b3adaad --- /dev/null +++ b/run_example.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +set -e +set -u +set -o pipefail + +stage=1 +stop_stage=5 + +original_path="YOUR_SONICOM_DATASET_PATH/" +preprocessed_dataset_path="PATH_TO_STORE_PREPROCESSED_SONICOM_DATA" +exp_base_path="PATH_TO_STORE_CHECKPOINTS_AND_LOG_FILES" +config_path="config_template/ranf" +sp_level=3 +valid_size=19 +test_size=20 + +dump_dir="${preprocessed_dataset_path}/sp_level_$(printf "%03d" "$sp_level")_no_azimuth_calibration" +sonicom_path="${preprocessed_dataset_path}/sonicom" +exp_path="${exp_base_path}/$(basename "$config_path")_splevel$(printf "%03d" "$sp_level")" + + +if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # Stage 1: Copying the required HRTF files to local and Extracting features + # This stage is required only once + + echo "Stage 1 ..." + bash preprocess_sonicom.sh $original_path $sonicom_path + python -m ranf.compute_spec_ild_itd_for_sonicom_datasets \ + "${sonicom_path}/subjects" "${sonicom_path}/npzs" +fi + +if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # Stage 2: Splitting the dataset and writing the split into the configuration + # `skip_78` removes an outlier from the training dataset + + echo "Stage 2 ..." 
+ mkdir -p $dump_dir + python -m ranf.compute_distance_matrices_for_spec_itd \ + "${sonicom_path}/npzs" $dump_dir $sp_level $test_size --skip_78 + + python -m ranf.prepare_single_fold \ + $dump_dir $config_path $exp_path $sonicom_path $sp_level $valid_size $test_size --skip_78 + + echo "The config file in $config_path has been modified for the specified sparsity level and saved in $exp_path" +fi + +if [ $stage -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # Stage 3: Pretraining a neural field + echo "Stage 3 ..." + python -m ranf.1_pretraining_neural_field $exp_path +fi + +if [ $stage -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # Stage 4: Adapting the pre-trained neural field to the target subjects + echo "Stage 4 ..." + python -m ranf.2_adapting_neural_field $exp_path +fi + + +if [ $stage -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Stage 5: Performing the inference and evaluation on the test set + echo "Stage 5 ..." + python -m ranf.3_evaluating_neural_field $exp_path + python -m ranf.summarize_evaluation_result $exp_path +fi diff --git a/run_learningfree_methods.sh b/run_learningfree_methods.sh new file mode 100644 index 0000000..2b565eb --- /dev/null +++ b/run_learningfree_methods.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +set -e +set -u +set -o pipefail + +stage=1 +stop_stage=3 + +original_path="YOUR_SONICOM_DATASET_PATH/" +preprocessed_dataset_path="PATH_TO_STORE_PREPROCESSED_SONICOM_DATA" +exp_base_path="PATH_TO_STORE_LOG_FILES" +config_path="config_template/hrtf_selection" +sp_level=19 +valid_size=19 +test_size=20 + +dump_dir="${preprocessed_dataset_path}/sp_level_$(printf "%03d" "$sp_level")_no_azimuth_calibration" +sonicom_path="${preprocessed_dataset_path}/sonicom" +exp_path="${exp_base_path}/$(basename "$config_path")_splevel$(printf "%03d" "$sp_level")" + + +if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # Stage 1: Copying the 
required HRTF files to local and extracting features + # This stage is required only once + + echo "Stage 1 ..." + bash preprocess_sonicom.sh $original_path $sonicom_path + python -m ranf.compute_spec_ild_itd_for_sonicom_datasets \ + "${sonicom_path}/subjects" "${sonicom_path}/npzs" +fi + +if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # Stage 2: Splitting the dataset and writing the split into the configuration + # `skip_78` excludes an outlier from the training dataset + + echo "Stage 2 ..." + mkdir -p $dump_dir + python -m ranf.compute_distance_matrices_for_spec_itd \ + "${sonicom_path}/npzs" $dump_dir $sp_level $test_size --skip_78 + + python -m ranf.prepare_single_fold \ + $dump_dir $config_path $exp_path $sonicom_path $sp_level $valid_size $test_size --skip_78 + + echo "The config file in $config_path has been modified for the specified sparsity level and saved in $exp_path" +fi + +if [ $stage -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # Stage 3: Performing inference and evaluation on the test set + echo "Stage 3 ..." 
+ if [ "$(echo $exp_path | grep 'hrtf_selection')" ]; then + python -m ranf.evaluating_hrtf_selection $exp_path + elif [ "$(echo $exp_path | grep 'nearest_neighbor')" ]; then + python -m ranf.evaluating_nearest_neighbor $exp_path + else + echo "Invalid exp_path" + fi + python -m ranf.summarize_evaluation_result $exp_path +fi diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/tests/loss_test.py b/tests/loss_test.py new file mode 100644 index 0000000..a429a33 --- /dev/null +++ b/tests/loss_test.py @@ -0,0 +1,48 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import pytest +import torch + +from ranf.utils.loss_functions import ild_diff_loss, itd_diff_loss, lsd_loss + + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("f_nyquist", [129]) +@pytest.mark.parametrize("use_index", [False, True]) +def test_lsd_loss(batch_size, f_nyquist, use_index): + ground_truth = torch.randn(batch_size, 2, f_nyquist) + loss = torch.mean(lsd_loss(ground_truth, ground_truth, use_index=use_index)) + torch.testing.assert_close(loss, torch.tensor(0.0), rtol=1.0e-5, atol=1.0e-6) + + loss = torch.mean(lsd_loss(ground_truth, ground_truth + 1.0, use_index=use_index)) + torch.testing.assert_close(loss, torch.tensor(1.0), rtol=1.0e-5, atol=1.0e-6) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_itd_diff_loss(batch_size): + ground_truth = torch.randn(batch_size) + loss = torch.mean(itd_diff_loss(ground_truth, ground_truth)) + torch.testing.assert_close(loss, torch.tensor(0.0), rtol=1.0e-5, atol=1.0e-6) + + loss = torch.mean(itd_diff_loss(ground_truth, ground_truth + 1.0)) + torch.testing.assert_close(loss, torch.tensor(1.0e6 / 48000), rtol=1.0e-5, 
atol=1.0e-6) + + loss = torch.mean(itd_diff_loss(ground_truth, ground_truth, threshold=1.0)) + torch.testing.assert_close(loss, torch.tensor(1.0e6 / 48000), rtol=1.0e-5, atol=1.0e-6) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("f_nyquist", [129]) +@pytest.mark.parametrize("target_ild", [None, 0.0]) +def test_ild_diff_loss(batch_size, f_nyquist, target_ild): + ground_truth = torch.randn(batch_size, 2, f_nyquist) + + if target_ild is None: + loss = torch.mean(ild_diff_loss(ground_truth, ground_truth, target_ild)) + torch.testing.assert_close(loss, torch.tensor(0.0), rtol=1.0e-5, atol=1.0e-6) + + else: + loss = torch.mean(ild_diff_loss(ground_truth, ground_truth, target_ild)) + loss > 0.0 diff --git a/tests/model_init_test.py b/tests/model_init_test.py new file mode 100644 index 0000000..912062e --- /dev/null +++ b/tests/model_init_test.py @@ -0,0 +1,175 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import pytest +import torch + +from ranf.utils.neural_field_icassp import RANF, CbCNeuralField, PEFTNeuralField + + +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("f_nyquist", [129, 257]) +@pytest.mark.parametrize("hidden_features", [16, 64]) +@pytest.mark.parametrize("hidden_layers", [1, 4]) +@pytest.mark.parametrize("scale", [1, 3]) +@pytest.mark.parametrize("dropout", [0, 0.1]) +@pytest.mark.parametrize("n_listeners", [100, 200]) +@pytest.mark.parametrize("activation", ["PReLU", "GELU"]) +@pytest.mark.parametrize("peft", ["bitfit", "lora"]) +@pytest.mark.parametrize("itd_skip_connection", [False, True]) +def test_peft_neural_field( + batch_size, + f_nyquist, + hidden_features, + hidden_layers, + scale, + dropout, + n_listeners, + activation, + peft, + itd_skip_connection, +): + + out_features = 2 * f_nyquist + model = PEFTNeuralField( + hidden_features=hidden_features, + hidden_layers=hidden_layers, + out_features=out_features, + 
scale=scale, + dropout=dropout, + n_listeners=n_listeners, + activation=activation, + peft=peft, + itd_skip_connection=itd_skip_connection, + ) + + model.train() + + ret_specs_db = None + ret_itds = None + ret_sidxs = None + tgt_loc = torch.rand(batch_size, 2) + tgt_sidx = torch.randint(low=0, high=n_listeners, size=(batch_size,)) + + mag, itd = model(ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs) + assert list(mag.shape) == [batch_size, 2, f_nyquist] + assert list(itd.shape) == [batch_size, 1] + + +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("f_nyquist", [129, 257]) +@pytest.mark.parametrize("hidden_features", [16, 64]) +@pytest.mark.parametrize("hidden_layers", [1, 4]) +@pytest.mark.parametrize("scale", [1, 3]) +@pytest.mark.parametrize("dropout", [0, 0.1]) +@pytest.mark.parametrize("n_listeners", [100, 200]) +@pytest.mark.parametrize("activation", ["PReLU", "GELU"]) +@pytest.mark.parametrize("itd_skip_connection", [False, True]) +def test_cbc_neural_field( + batch_size, + f_nyquist, + hidden_features, + hidden_layers, + scale, + dropout, + n_listeners, + activation, + itd_skip_connection, +): + + out_features = 2 * f_nyquist + model = CbCNeuralField( + hidden_features=hidden_features, + hidden_layers=hidden_layers, + out_features=out_features, + scale=scale, + dropout=dropout, + n_listeners=n_listeners, + activation=activation, + itd_skip_connection=itd_skip_connection, + ) + + model.train() + + ret_specs_db = None + ret_itds = None + ret_sidxs = None + tgt_loc = torch.rand(batch_size, 2) + tgt_sidx = torch.randint(low=0, high=n_listeners, size=(batch_size,)) + + mag, itd = model(ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs) + assert list(mag.shape) == [batch_size, 2, f_nyquist] + assert list(itd.shape) == [batch_size, 1] + + +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize("f_nyquist", [257]) +@pytest.mark.parametrize("n_retrievals", [1, 3, 5]) +@pytest.mark.parametrize("hidden_features", 
[64]) +@pytest.mark.parametrize("hidden_layers", [1, 4]) +@pytest.mark.parametrize("spec_hidden_layers", [1, 2]) +@pytest.mark.parametrize("itd_hidden_layers", [1, 2]) +@pytest.mark.parametrize("conv_layers", [4]) +@pytest.mark.parametrize("scale", [1]) +@pytest.mark.parametrize("dropout", [0]) +@pytest.mark.parametrize("n_listeners", [200]) +@pytest.mark.parametrize("rnn", ["GRU", "LSTM"]) +@pytest.mark.parametrize("activation", ["GELU"]) +@pytest.mark.parametrize("norm", ["Identity", "LayerNorm"]) +@pytest.mark.parametrize("spec_res", [False, True]) +@pytest.mark.parametrize("itd_res", [False, True]) +@pytest.mark.parametrize("lora_retrieval", ["alter", "diff", "target"]) +@pytest.mark.parametrize("itd_skip_connection", [False, True]) +def test_ranf( + batch_size, + f_nyquist, + n_retrievals, + hidden_features, + hidden_layers, + spec_hidden_layers, + itd_hidden_layers, + conv_layers, + scale, + dropout, + n_listeners, + rnn, + activation, + norm, + spec_res, + itd_res, + lora_retrieval, + itd_skip_connection, +): + + out_features = 2 * f_nyquist + model = RANF( + hidden_features=hidden_features, + hidden_layers=hidden_layers, + spec_hidden_layers=spec_hidden_layers, + itd_hidden_layers=itd_hidden_layers, + conv_layers=conv_layers, + out_features=out_features, + scale=scale, + dropout=dropout, + n_listeners=n_listeners, + rnn=rnn, + activation=activation, + norm=norm, + spec_res=spec_res, + itd_res=itd_res, + lora_retrieval=lora_retrieval, + itd_skip_connection=itd_skip_connection, + ) + + model.train() + + ret_specs_db = torch.rand(batch_size, n_retrievals, 2, f_nyquist) + ret_itds = torch.rand(batch_size, n_retrievals) + ret_sidxs = torch.randint(low=0, high=n_listeners, size=(batch_size, n_retrievals)) + tgt_loc = torch.rand(batch_size, 2) + tgt_sidx = torch.randint(low=0, high=n_listeners, size=(batch_size,)) + + mag, itd = model(ret_specs_db, ret_itds, tgt_loc, tgt_sidx, ret_sidxs) + assert list(mag.shape) == [batch_size, 2, f_nyquist] + assert 
list(itd.shape) == [batch_size, 1] diff --git a/tests/post_processing_test.py b/tests/post_processing_test.py new file mode 100644 index 0000000..9d242ac --- /dev/null +++ b/tests/post_processing_test.py @@ -0,0 +1,34 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import numpy as np +import pytest +from scipy.fft import fft +from spatialaudiometrics import hrtf_metrics as hf + +from ranf.utils.reconstruction import hrtf2hrir_minph + + +@pytest.mark.parametrize("n_directions", [10, 20]) +@pytest.mark.parametrize("f_nyquist", [129]) +@pytest.mark.parametrize("itd_search_width", [0, 1, 5, 10]) +def test_hrtf2hrir_minph( + n_directions, + f_nyquist, + itd_search_width, +): + mag = np.random.rand(n_directions, 2, f_nyquist) + itd = np.round(np.random.rand(n_directions, 1) * 90 - 45) + hrir = hrtf2hrir_minph(mag, itd, itd_search_width=itd_search_width) + + _mag = np.abs(fft(hrir, axis=-1))[..., : hrir.shape[-1] // 2 + 1] + np.testing.assert_allclose(_mag, mag) + + if itd_search_width > 1: + mag = np.ones((n_directions, 2, f_nyquist)) + itd = np.round(np.random.rand(n_directions, 1) * 90 - 45) + hrir = hrtf2hrir_minph(mag, itd, itd_search_width=itd_search_width) + + _, _itd, _ = hf.itd_estimator_maxiacce(hrir, fs=np.array(48000)) + np.testing.assert_allclose(_itd, itd[:, 0], atol=1.0 + 1.0e-5)