Cannot compile some models to TorchScript with torch version 2 #126

Closed
jintuzhang opened this issue Apr 11, 2024 · 2 comments
Labels
good first issue Good for newcomers

Comments

@jintuzhang

Example input:

import mlcolvar

cv = mlcolvar.cvs.DeepTDA(
    n_cvs=1,
    n_states=2,
    target_centers=[-10.0, 10.0],
    target_sigmas=[0.2, 0.2],
    layers=[4, 3, 2, 1]
)

cv.to_torchscript('model.ptc')
Errors:
Traceback (most recent call last):
  File "/compile.py", line 11, in <module>
    cv.to_torchscript('model.ptc')
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1429, in to_torchscript
    torchscript_module = torch.jit.script(self.eval(), **kwargs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
  File "/calc/miniconda3/envs/TORCH/lib/python3.10/site-packages/torch/jit/annotations.py", line 366, in try_ann_to_type
    assert maybe_type, msg.format(repr(ann), repr(maybe_type))
AssertionError: Unsupported annotation typing.Union[list, torch.Tensor] could not be resolved because None could not be resolved.
conda list:
# packages in environment at /calc/miniconda3/envs/TORCH:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
aiofiles                  22.1.0          py310h06a4308_0
aiohttp                   3.8.5                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
aiosqlite                 0.18.0          py310h06a4308_0
annotated-types           0.5.0              pyhd8ed1ab_0    conda-forge
anyio                     3.6.2              pyhd8ed1ab_0    conda-forge
argon2-cffi               21.3.0             pyhd3eb1b0_0
argon2-cffi-bindings      21.2.0          py310h5764c6d_3    conda-forge
arrow                     1.2.3           py310h06a4308_1
ase                       3.22.1                   pypi_0    pypi
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
async-timeout             4.0.3                    pypi_0    pypi
attrs                     22.2.0             pyh71513ae_0    conda-forge
babel                     2.12.1             pyhd8ed1ab_1    conda-forge
backcall                  0.2.0              pyhd3eb1b0_0
backoff                   2.2.1              pyhd8ed1ab_0    conda-forge
backports                 1.1                pyhd3eb1b0_0
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
beautifulsoup4            4.12.2          py310h06a4308_0
blas                      1.0                         mkl
bleach                    6.0.0              pyhd8ed1ab_0    conda-forge
blessed                   1.19.1             pyhe4f9e05_2    conda-forge
blosc                     1.21.3               h6a678d5_0
bottleneck                1.3.7           py310h0a54255_0    conda-forge
brotli                    1.0.9                h166bdaf_8    conda-forge
brotli-bin                1.0.9                h166bdaf_8    conda-forge
brotlipy                  0.7.0           py310h7f8727e_1002
bzip2                     1.0.8                h7b6447c_0
c-ares                    1.19.1               h5eee18b_0
c-blosc2                  2.8.0                h6a678d5_0
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cachecontrol              0.12.11         py310h06a4308_1
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1          py310h5eee18b_3
cftime                    1.6.2                    pypi_0    pypi
charset-normalizer        2.0.4              pyhd3eb1b0_0
cleo                      2.0.1           py310h06a4308_0
click                     8.1.7           unix_pyh707e725_0    conda-forge
colorama                  0.4.6           py310h06a4308_0
comm                      0.1.2           py310h06a4308_0
contourpy                 1.0.5           py310hdb19cb5_0
cpuonly                   2.0                           0    pytorch
crashtest                 0.4.1           py310h06a4308_0
croniter                  1.4.1              pyhd8ed1ab_0    conda-forge
cryptography              41.0.2          py310h22a60cf_0
cycler                    0.11.0             pyhd3eb1b0_0
cyrus-sasl                2.1.28               h52b45da_1
dateutils                 0.6.12                     py_0    conda-forge
dbus                      1.13.18              hb2f20db_0
debugpy                   1.6.7           py310h6a678d5_0
decorator                 5.1.1              pyhd3eb1b0_0
deepdiff                  6.3.1              pyhd8ed1ab_0    conda-forge
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
distlib                   0.3.7              pyhd8ed1ab_0    conda-forge
dulwich                   0.21.5          py310h2372a71_0    conda-forge
e3nn                      0.4.4                    pypi_0    pypi
einops                    0.7.0                    pypi_0    pypi
entrypoints               0.4             py310h06a4308_0
exceptiongroup            1.1.3              pyhd8ed1ab_0    conda-forge
executing                 1.2.0              pyhd8ed1ab_0    conda-forge
expat                     2.5.0                hcb278e6_1    conda-forge
expect                    5.45.4               h555a92e_0    conda-forge
fastapi                   0.101.1            pyhd8ed1ab_0    conda-forge
filelock                  3.9.0           py310h06a4308_0
fontconfig                2.14.1               h4c34cd2_2
fonttools                 4.39.3          py310h1fa729e_0    conda-forge
freetype                  2.12.1               hca18f0e_1    conda-forge
frozenlist                1.4.0                    pypi_0    pypi
fsspec                    2023.6.0           pyh1a96a4e_0    conda-forge
giflib                    5.2.1                h0b41bf4_3    conda-forge
glib                      2.69.1               he621ea3_2
gmp                       6.2.1                h295c915_3
gmpy2                     2.1.2           py310heeb90bb_0
gst-plugins-base          1.14.1               h6a678d5_1
gstreamer                 1.14.1               h5eee18b_1
h11                       0.14.0             pyhd8ed1ab_0    conda-forge
h5py                      3.9.0           py310he06866b_0
hdf5                      1.12.1               h2b7332f_3
html5lib                  1.1                pyhd3eb1b0_0
icu                       58.2              hf484d3e_1000    conda-forge
idna                      3.4             py310h06a4308_0
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
inquirer                  3.1.3              pyhd8ed1ab_0    conda-forge
intel-openmp              2023.1.0         hdb19cb5_46305
ipykernel                 6.25.0          py310h2f386ee_0
ipython                   8.12.2          py310h06a4308_0
ipython_genutils          0.2.0              pyhd3eb1b0_1
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jaraco.classes            3.3.0              pyhd8ed1ab_0    conda-forge
jedi                      0.18.2             pyhd8ed1ab_0    conda-forge
jeepney                   0.8.0              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2           py310h06a4308_0
joblib                    1.2.0           py310h06a4308_0
jpeg                      9e                   h0b41bf4_3    conda-forge
json5                     0.9.6              pyhd3eb1b0_0
jsonschema                4.17.3          py310h06a4308_0
jupyter_client            8.1.0           py310h06a4308_0
jupyter_core              5.3.0           py310h06a4308_0
jupyter_events            0.6.3           py310h06a4308_0
jupyter_server            1.23.6             pyhd8ed1ab_0    conda-forge
jupyter_server_fileid     0.9.0           py310h06a4308_0
jupyter_server_ydoc       0.8.0           py310h06a4308_1
jupyter_ydoc              0.2.4           py310h06a4308_0
jupyterlab                3.6.3           py310h06a4308_0
jupyterlab_pygments       0.2.2              pyhd8ed1ab_0    conda-forge
jupyterlab_server         2.22.0          py310h06a4308_0
kdepy                     1.1.5           py310h2372a71_0    conda-forge
keyring                   23.13.1         py310h06a4308_0
kiwisolver                1.4.4           py310h6a678d5_0
krb5                      1.20.1               h143b758_1
lcms2                     2.15                 hfd0df8a_0    conda-forge
ld_impl_linux-64          2.38                 h1181459_1
lerc                      3.0                  h295c915_0
libbrotlicommon           1.0.9                h166bdaf_8    conda-forge
libbrotlidec              1.0.9                h166bdaf_8    conda-forge
libbrotlienc              1.0.9                h166bdaf_8    conda-forge
libclang                  14.0.6          default_hc6dbbc7_1
libclang13                14.0.6          default_he11475f_1
libcups                   2.4.2                h2d74bed_1
libcurl                   8.2.1                h251f7ec_0
libdeflate                1.17                 h5eee18b_0
libedit                   3.1.20221030         h5eee18b_0
libev                     4.33                 h7f8727e_1
libevent                  2.1.12               hdbd6064_1
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgfortran-ng            12.2.0              h69a702a_19    conda-forge
libgfortran5              12.2.0              h337968e_19    conda-forge
libllvm14                 14.0.6               hdb19cb5_3
libnghttp2                1.52.0               h2d74bed_1
libpng                    1.6.39               h5eee18b_0
libpq                     12.15                hdbd6064_1
libprotobuf               3.20.3               he621ea3_0
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libssh2                   1.10.0               hdbd6064_2
libstdcxx-ng              13.1.0               hfd8a6a1_0    conda-forge
libtiff                   4.5.1                h6a678d5_0
libuuid                   1.41.5               h5eee18b_0
libwebp                   1.2.4                h11a3e52_1
libwebp-base              1.2.4                h5eee18b_1
libxcb                    1.15                 h7f8727e_0
libxkbcommon              1.0.1                h5eee18b_1
libxml2                   2.10.4               hcbfbd50_0
libxslt                   1.1.37               h2085143_0
libzlib                   1.2.13               h166bdaf_4    conda-forge
lightning                 2.0.7              pyhd8ed1ab_0    conda-forge
lightning-cloud           0.5.37             pyhd8ed1ab_0    conda-forge
lightning-utilities       0.9.0              pyhd8ed1ab_0    conda-forge
llvm-openmp               16.0.1               h417c0b6_0    conda-forge
llvmlite                  0.40.1                   pypi_0    pypi
lockfile                  0.12.2                     py_1    conda-forge
lz4-c                     1.9.4                h6a678d5_0
lzo                       2.10              h516909a_1000    conda-forge
mace                      0.3.2                    pypi_0    pypi
mace-layer                0.0.0                    pypi_0    pypi
mace-torch                0.3.4                    pypi_0    pypi
markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.1           py310h7f8727e_0
matplotlib                3.7.2           py310h06a4308_0
matplotlib-base           3.7.2           py310h1128e8f_0
matplotlib-inline         0.1.6           py310h06a4308_0
matscipy                  0.8.0                    pypi_0    pypi
mdtraj                    1.9.7           py310hd8d60c7_1    conda-forge
mdurl                     0.1.0           py310h06a4308_0
mistune                   2.0.5              pyhd8ed1ab_0    conda-forge
mkl                       2023.1.0         h213fc3f_46343
mkl-service               2.4.0           py310h5eee18b_1
mkl_fft                   1.3.6           py310h1128e8f_1
mkl_random                1.2.2           py310h1128e8f_1
mlcolvar                  1+unknown                pypi_0    pypi
more-itertools            10.1.0             pyhd8ed1ab_0    conda-forge
mpc                       1.1.0                h10f8cd9_1
mpfr                      4.0.2                hb69a4c5_1
mpiplus                   0+unknown                pypi_0    pypi
mpmath                    1.3.0           py310h06a4308_0
msgpack-python            1.0.3           py310hd09550d_0
multidict                 6.0.4                    pypi_0    pypi
munkres                   1.1.4                      py_0
mysql                     5.7.24               h721c034_2
nbclassic                 0.5.5              pyh8b2e9e2_0    conda-forge
nbclient                  0.7.3              pyhd8ed1ab_0    conda-forge
nbconvert-core            7.3.0              pyhd8ed1ab_2    conda-forge
nbformat                  5.8.0              pyhd8ed1ab_0    conda-forge
ncurses                   6.4                  h6a678d5_0
nest-asyncio              1.5.6           py310h06a4308_0
netcdf                    66.0.2                   pypi_0    pypi
netcdf4                   1.6.4                    pypi_0    pypi
networkx                  3.1             py310h06a4308_0
ninja                     1.10.2               h06a4308_5
ninja-base                1.10.2               hd09550d_5
notebook                  6.5.4              pyha770c72_0    conda-forge
notebook-shim             0.2.2           py310h06a4308_0
nspr                      4.35                 h6a678d5_0
nss                       3.89.1               h6a678d5_0
numba                     0.57.1                   pypi_0    pypi
numexpr                   2.8.4           py310h85018f9_1
numpy                     1.24.4                   pypi_0    pypi
openmm                    8.0.0           py310h5728c26_1    <unknown>
openmm-plumed             1.0             py310h552f1b7_9    <unknown>
openmmtools               0.23.1                   pypi_0    pypi
openssl                   3.1.2                hd590300_0    conda-forge
opt-einsum                3.0.0                      py_0    conda-forge
opt_einsum_fx             0.1.4              pyhd8ed1ab_0    conda-forge
ordered-set               4.1.0           py310h06a4308_0
orjson                    3.9.5           py310h1e2579a_0    conda-forge
packaging                 23.1            py310h06a4308_0
pandas                    2.0.3           py310h1128e8f_0
pandocfilters             1.5.0              pyhd3eb1b0_0
parso                     0.8.3              pyhd3eb1b0_0
pcre                      8.45                 h295c915_0
pexpect                   4.8.0              pyhd3eb1b0_3
pickleshare               0.7.5           pyhd3eb1b0_1003
pillow                    9.4.0           py310h6a678d5_0
pint                      0.22                     pypi_0    pypi
pip                       23.2.1          py310h06a4308_0
pkginfo                   1.9.6           py310h06a4308_0
platformdirs              2.5.2           py310h06a4308_0
pluggy                    1.2.0              pyhd8ed1ab_0    conda-forge
plumed                    2.9.0                    pypi_0    pypi
ply                       3.11                       py_1    conda-forge
poetry                    1.4.0           py310h06a4308_0
poetry-core               1.5.1           py310h06a4308_0
poetry-plugin-export      1.3.0           py310h4849bfd_0
prettytable               3.9.0                    pypi_0    pypi
prometheus_client         0.16.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.38             pyha770c72_0    conda-forge
psutil                    5.9.0           py310h5eee18b_0
ptyprocess                0.7.0              pyhd3eb1b0_2
pure_eval                 0.2.2              pyhd3eb1b0_0
py-cpuinfo                8.0.0              pyhd3eb1b0_1
pycparser                 2.21               pyhd3eb1b0_0
pydantic                  2.0.3              pyhd8ed1ab_1    conda-forge
pydantic-core             2.3.0           py310hcb5633a_0    conda-forge
pyg                       2.3.0           py310_torch_2.0.0_cpu    pyg
pygments                  2.15.1          py310h06a4308_1
pyjwt                     2.8.0              pyhd8ed1ab_0    conda-forge
pymbar                    4.0.2                    pypi_0    pypi
pyopenssl                 23.2.0          py310h06a4308_0
pyparsing                 3.0.9           py310h06a4308_0
pyproject_hooks           1.0.0           py310h06a4308_0
pyqt                      5.15.7          py310h6a678d5_1
pyqt5-sip                 12.11.0                  pypi_0    pypi
pyrsistent                0.19.3          py310h1fa729e_0    conda-forge
pysocks                   1.7.1           py310h06a4308_0
pytables                  3.8.0           py310hb8ae3fc_3
pytest                    7.4.0           py310h06a4308_0
python                    3.10.12              h955ad1f_0
python-build              0.10.0             pyhd8ed1ab_1    conda-forge
python-dateutil           2.8.2              pyhd3eb1b0_0
python-editor             1.0.4              pyhd3eb1b0_0
python-fastjsonschema     2.16.3             pyhd8ed1ab_0    conda-forge
python-installer          0.6.0           py310h06a4308_0
python-json-logger        2.0.7           py310h06a4308_0
python-multipart          0.0.6           py310h06a4308_0
python-tzdata             2023.3             pyhd3eb1b0_0
python_abi                3.10                    2_cp310    conda-forge
pytorch                   2.0.1           cpu_py310hdc00b08_0
pytorch-lightning         2.0.7              pyhd8ed1ab_0    conda-forge
pytorch-mutex             1.0                         cpu    pytorch
pytorch-scatter           2.1.1           py310_torch_2.0.0_cpu    pyg
pytz                      2023.3             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.1           py310h2372a71_0    conda-forge
pyzmq                     25.1.0          py310h6a678d5_0
qt-main                   5.15.2               h7358343_9
qt-webengine              5.15.9               h9ab4d14_7
qtwebkit                  5.212                h3fafdc1_5
rapidfuzz                 2.13.7          py310h1128e8f_0
rdkit                     2022.9.5                 pypi_0    pypi
readchar                  4.0.5              pyhd8ed1ab_0    conda-forge
readline                  8.2                  h5eee18b_0
requests                  2.31.0          py310h06a4308_0
requests-toolbelt         0.10.1             pyhd8ed1ab_0    conda-forge
rfc3339-validator         0.1.4           py310h06a4308_0
rfc3986-validator         0.1.1           py310h06a4308_0
rich                      13.5.1             pyhd8ed1ab_0    conda-forge
scikit-learn              1.3.0           py310h1128e8f_0
scipy                     1.10.1                   pypi_0    pypi
secretstorage             3.3.3           py310hff52083_1    conda-forge
send2trash                1.8.0              pyhd3eb1b0_1
setuptools                68.0.0          py310h06a4308_0
shellingham               1.5.3              pyhd8ed1ab_0    conda-forge
sip                       6.6.2           py310h6a678d5_0
six                       1.16.0             pyhd3eb1b0_1
snappy                    1.1.9                h295c915_0
sniffio                   1.3.0              pyhd8ed1ab_0    conda-forge
soupsieve                 2.4             py310h06a4308_0
sqlite                    3.41.2               h5eee18b_0
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
starlette                 0.27.0          py310h06a4308_0
starsessions              1.3.0              pyhd8ed1ab_0    conda-forge
sympy                     1.11.1          py310h06a4308_0
tbb                       2021.8.0             hdb19cb5_0
terminado                 0.17.1          py310h06a4308_0
threadpoolctl             2.2.0              pyh0d69192_0
tinycss2                  1.2.1           py310h06a4308_0
tk                        8.6.12               h1ccaba5_0
toml                      0.10.2             pyhd3eb1b0_0
tomli                     2.0.1           py310h06a4308_0
tomlkit                   0.12.1             pyha770c72_0    conda-forge
torch-ema                 0.3                      pypi_0    pypi
torchmetrics              1.0.3              pyhd8ed1ab_0    conda-forge
tornado                   6.3.2           py310h5eee18b_0
tqdm                      4.66.1             pyhd8ed1ab_0    conda-forge
traitlets                 5.9.0              pyhd8ed1ab_0    conda-forge
trove-classifiers         2023.8.7           pyhd8ed1ab_0    conda-forge
typing-extensions         4.7.1           py310h06a4308_0
typing_extensions         4.7.1           py310h06a4308_0
tzdata                    2023c                h04d1e81_0
unicodedata2              15.0.0          py310h5eee18b_0
urllib3                   1.26.16         py310h06a4308_0
uvicorn                   0.23.2          py310hff52083_0    conda-forge
virtualenv                20.17.1         py310h06a4308_0
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
webencodings              0.5.1           py310h06a4308_1
websocket-client          1.5.1              pyhd8ed1ab_0    conda-forge
websockets                11.0.3          py310h2372a71_0    conda-forge
wheel                     0.38.4          py310h06a4308_0
xz                        5.4.2                h5eee18b_0
y-py                      0.5.9           py310h52d8a92_0
yaml                      0.2.5                h7f98852_2    conda-forge
yarl                      1.9.2                    pypi_0    pypi
ypy-websocket             0.8.2           py310h06a4308_0
zeromq                    4.3.4                h9c3ff4c_1    conda-forge
zipp                      3.15.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h166bdaf_4    conda-forge
zlib-ng                   2.0.7                h5eee18b_0
zstd                      1.5.5                hc292b87_0

I did some tinkering, and here is a version of the tda_loss module that compiles, although I'm not entirely sure it is correct.

code:
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Target Discriminant Analysis Loss Function.
"""

__all__ = ["TDALoss", "tda_loss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Union, List, Tuple
from warnings import warn

import torch


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class TDALoss(torch.nn.Module):
    """Compute a loss function as the distance from a simple Gaussian target distribution."""

    def __init__(
        self,
        n_states: int,
        target_centers: Union[List[float], torch.Tensor],
        target_sigmas: Union[List[float], torch.Tensor],
        alpha: float = 1.0,
        beta: float = 100.0,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to be between 0
            and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
        alpha : float, optional
            Prefactor of the centers loss component, by default 1.
        beta : float, optional
            Prefactor of the sigmas loss component, by default 100.
        """
        super().__init__()
        self.n_states = n_states
        self.target_centers = target_centers
        self.target_sigmas = target_sigmas
        self.alpha = alpha
        self.beta = beta

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the value of the loss function.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the center and standard
            deviations of the target Gaussians are returned as well. Default
            is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the standard deviations of the target Gaussians.
        """
        return tda_loss(
            H,
            labels,
            self.n_states,
            self.target_centers,
            self.target_sigmas,
            self.alpha,
            self.beta,
            return_loss_terms,
        )


def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[List[float], torch.Tensor],
    target_sigmas: Union[List[float], torch.Tensor],
    alpha: float = 1.0,
    beta: float = 100.0,
    return_loss_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Compute a loss function as the distance from a simple Gaussian target distribution.

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_features)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
    alpha : float, optional
        Prefactor of the centers loss component, by default 1.
    beta : float, optional
        Prefactor of the sigmas loss component, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard deviations
        of the target Gaussians are returned as well. Default is ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.tensor(target_centers)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.tensor(target_sigmas)

    device = H.device
    target_centers = target_centers.to(device)
    target_sigmas = target_sigmas.to(device)
    loss_centers = torch.zeros_like(target_centers, device=device)
    loss_sigmas = torch.zeros_like(target_sigmas, device=device)

    for i in range(n_states):
        # check which elements belong to class i
        if not (labels == i).any():
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
        else:
            H_red = H[labels == i]

            # compute mean and standard deviation over the class i
            mu = torch.mean(H_red, 0)
            if len(torch.nonzero(labels == i)) == 1:
                warn(
                    f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
                )
                sigma = torch.zeros_like(mu)  # same shape, dtype and device as the std
            else:
                sigma = torch.std(H_red, 0)

        # compute loss function contributions for class i
        loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
        loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)

    # get total model loss
    loss_centers = torch.sum(loss_centers)
    loss_sigmas = torch.sum(loss_sigmas)
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss
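
The rewrite relies on two TorchScript-friendly patterns: a `Union[List[float], torch.Tensor]` argument that is narrowed with `isinstance` before use, and a `Union` return type for the optional loss terms. A minimal, self-contained sketch of just those patterns follows, scripting a toy function and comparing list and tensor inputs. `to_target_tensor` and `toy_centers_loss` are hypothetical names and the toy loss is not the real TDA loss; it assumes a PyTorch version with `Union` support in TorchScript (added around 1.10).

```python
from typing import List, Tuple, Union

import torch


def to_target_tensor(
    values: Union[List[float], torch.Tensor], device: torch.device
) -> torch.Tensor:
    """Accept a list of floats or a tensor, like the target_* arguments above."""
    # isinstance() narrows the Union: Tensor in the first branch,
    # List[float] in the second, so both branches type-check under TorchScript.
    if isinstance(values, torch.Tensor):
        out = values
    else:
        out = torch.tensor(values)
    return out.to(device)


def toy_centers_loss(
    H: torch.Tensor,
    target_centers: Union[List[float], torch.Tensor],
    return_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Toy loss: squared distance of the batch mean from the first target center."""
    centers = to_target_tensor(target_centers, H.device)
    per_cv = (H.mean(0) - centers[0]).pow(2)
    loss = per_cv.sum()
    if return_terms:
        return loss, per_cv
    return loss


# Scripting the outer function also compiles to_target_tensor, which it calls.
scripted = torch.jit.script(toy_centers_loss)

H = torch.tensor([[0.0], [2.0]])
print(float(scripted(H, [1.5])))                # list input -> 0.25
print(float(scripted(H, torch.tensor([1.5]))))  # tensor input -> 0.25
```

Because the scripted function returns either a tensor or a tuple depending on `return_terms`, callers still need to unpack the result themselves, exactly as with the eager `tda_loss`.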

@EnricoTrizio
Collaborator

Thanks Jintu, it looks fine; we'll fix this.

@luigibonati added the good first issue (Good for newcomers) label on May 3, 2024
EnricoTrizio added a commit that referenced this issue May 6, 2024
@EnricoTrizio
Collaborator

Fixed in #131
