diff --git a/test_graphs/test_committor.ipynb b/test_graphs/test_committor.ipynb index 5f47389..762d34e 100644 --- a/test_graphs/test_committor.ipynb +++ b/test_graphs/test_committor.ipynb @@ -2,15 +2,15 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "DictModule(dataset -> DictDataset( \"data_list\": 1700, \"z_table\": [6, 9], \"cutoff\": 8.0, \"data_type\": graphs ),\n", - "\t\t train_loader -> DictLoader(length=1, batch_size=1700, shuffle=False))\n" + "DictModule(dataset -> DictDataset( \"data_list\": 28603, \"z_table\": [6, 7, 8], \"cutoff\": 8.0, \"data_type\": graphs ),\n", + "\t\t train_loader -> DictLoader(length=1, batch_size=28603, shuffle=False))\n" ] } ], @@ -22,16 +22,23 @@ "\n", "dataset_graph = create_dataset_from_trajectories(\n", " trajectories=[\n", - " 'data/r.dcd',\n", - " 'data/p.dcd',\n", + " '/home/etrizio@iit.local/notebooks/projects/kolmogorov/alanine_transform/unbiased_sims/state_A/traj_comp.xtc',\n", + " '/home/etrizio@iit.local/notebooks/projects/kolmogorov/alanine_transform/unbiased_sims/state_B/traj_comp.xtc',\n", + " '/home/etrizio@iit.local/notebooks/projects/kolmogorov_opes/alanine_long_train/biased_sims/iter_4/A/traj_comp.xtc',\n", + " # '/home/etrizio@iit.local/notebooks/projects/kolmogorov_opes/alanine_long_train/biased_sims/iter_4/B/traj_comp.xtc'\n", + " # 'data/r.dcd',\n", + " # 'data/p.dcd',\n", " # 'data/biased.trajectory.h5',\n", - " 'data/biased.dcd',\n", + " # 'data/biased.dcd',\n", " #'data/biased.trajectory.h5',\n", " #'data/r.dcd'\n", " ],\n", - " top=['data/r.pdb', \n", - " 'data/p.pdb',\n", - " 'data/r.pdb',\n", + " top=['/home/etrizio@iit.local/notebooks/projects/kolmogorov/alanine_transform/unbiased_sims/state_A/confout.gro',\n", + " '/home/etrizio@iit.local/notebooks/projects/kolmogorov/alanine_transform/unbiased_sims/state_A/confout.gro',\n", + " 
'/home/etrizio@iit.local/notebooks/projects/kolmogorov/alanine_transform/unbiased_sims/state_A/confout.gro'\n", + " # 'data/r.pdb', \n", + " # 'data/p.pdb',\n", + " # 'data/r.pdb',\n", " #'data/r.pdb'\n", " ],\n", " cutoff=8.0, # Ang\n", @@ -47,15 +54,20 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from mlcolvar.utils.io import load_dataframe\n", "\n", - "df = load_dataframe('data/colvar', stride=100)\n", - "weights = torch.exp(1*torch.tensor(df['opes.bias'].values))\n", + "T = 300 \n", + "# Boltzmann constant in kJ/(mol*K): must match the ENERGY UNITS of the bias columns!\n", + "kb = 0.0083144621\n", + "beta = 1/(kb*T)\n", + "\n", + "df = load_dataframe('/home/etrizio@iit.local/notebooks/projects/kolmogorov_opes/alanine_long_train/biased_sims/iter_4/A/COLVAR', stride=1)\n", + "weights = torch.exp(beta*torch.tensor((df['opes.bias'] + df['bias']).values))  # reweighting factor exp(beta*V), not exp(V/beta)\n", "weights = weights / weights.sum()\n", "weights\n", "\n", @@ -76,7 +88,7 @@ { "data": { "text/plain": [ - "1700" + "28603" ] }, "execution_count": 2, @@ -90,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -109,7 +121,7 @@ " (_radial_embedding): RadialEmbeddingBlock(\n", " (bessel_fn): GAUSSIANBASIS [ \u001b[32m16\u001b[0m\u001b[36m 󰯰 \u001b[0m| \u001b[32m8.000000\u001b[0m\u001b[36m 󰳁 \u001b[0m]\n", " )\n", - " (W_v): Linear(in_features=2, out_features=32, bias=False)\n", + " (W_v): Linear(in_features=3, out_features=32, bias=False)\n", " (layers): ModuleList(\n", " (0-1): 2 x InteractionBlock(\n", " (mlp): Sequential(\n", @@ -132,7 +144,7 @@ ")" ] }, - "execution_count": 29, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -151,16 +163,22 @@ " n_hidden_channels=32,\n", " )\n", "\n", + "# model = Committor(model=gnn_model,\n", + "# mass=torch.Tensor([12, 19]),\n", + "# alpha=1)\n", + "\n", + "\n", "model = Committor(model=gnn_model,\n", - " 
mass=torch.Tensor([12, 19]),\n", + " mass=torch.Tensor([12, 14, 16]),\n", " alpha=1)\n", "\n", + "\n", "model" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -178,7 +196,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "020e80012d2d46809f41270cdb5692e2", + "model_id": "0b16af39d1674ea68b070e8ae805baf9", "version_major": 2, "version_minor": 0 }, @@ -190,10 +208,46 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=5000` reached.\n" + "ename": "OutOfMemoryError", + "evalue": "CUDA out of memory. Tried to allocate 316.00 MiB. GPU 0 has a total capacty of 7.79 GiB of which 244.31 MiB is free. Including non-PyTorch memory, this process has 6.88 GiB memory in use. Of the allocated memory 6.61 GiB is allocated by PyTorch, and 158.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 13\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlightning\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Trainer\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 4\u001b[0m logger\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 5\u001b[0m enable_checkpointing\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m num_sanity_val_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 11\u001b[0m )\n\u001b[0;32m---> 13\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule_graph\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 544\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 545\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 546\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n", + "File 
\u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 574\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 575\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 576\u001b[0m ckpt_path,\n\u001b[1;32m 577\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 578\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 579\u001b[0m )\n\u001b[0;32m--> 580\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File 
\u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:989\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 986\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 987\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 989\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 992\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 993\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 994\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1035\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1033\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1034\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m 
torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1035\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1036\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1037\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:202\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 202\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:359\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 359\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 136\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 
237\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 238\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 239\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 240\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 241\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 242\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:187\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 180\u001b[0m closure()\n\u001b[1;32m 182\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated 
gradients\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 187\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 189\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:265\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 265\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 275\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:157\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 154\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 157\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 160\u001b[0m 
pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/core/module.py:1291\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1252\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1253\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1254\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1257\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1258\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1259\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1261\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \n\u001b[1;32m 1290\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1291\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/core/optimizer.py:151\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, 
the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 151\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:230\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 230\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:117\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 116\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/optim/optimizer.py:373\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 370\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 371\u001b[0m )\n\u001b[0;32m--> 373\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 376\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/optim/optimizer.py:76\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 75\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 76\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 78\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/optim/adam.py:143\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, 
closure)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 146\u001b[0m params_with_grad \u001b[38;5;241m=\u001b[39m []\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py:104\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 93\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 94\u001b[0m optimizer: Optimizer,\n\u001b[1;32m 95\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 96\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 97\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 99\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 102\u001b[0m \n\u001b[1;32m 103\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 104\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 
105\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:140\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:126\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 126\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. 
If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py:315\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 312\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[1;32m 314\u001b[0m \u001b[38;5;66;03m# manually capture logged metrics\u001b[39;00m\n\u001b[0;32m--> 315\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_result_cls\u001b[38;5;241m.\u001b[39mfrom_training_step_output(training_step_output, trainer\u001b[38;5;241m.\u001b[39maccumulate_grad_batches)\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:309\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 309\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 312\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:382\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 382\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/dev/mlcolvar/mlcolvar/cvs/committor/committor.py:136\u001b[0m, in \u001b[0;36mCommittor.training_step\u001b[0;34m(self, train_batch, batch_idx)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# ===================loss=====================\u001b[39;00m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining:\n\u001b[0;32m--> 136\u001b[0m loss, loss_var, loss_bound_A, loss_bound_B \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 137\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 138\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 140\u001b[0m loss, loss_var, loss_bound_A, loss_bound_B \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss_fn(\n\u001b[1;32m 141\u001b[0m x, q, labels, weights \n\u001b[1;32m 142\u001b[0m )\n", + "File 
\u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks 
\u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Bin/dev/mlcolvar/mlcolvar/core/loss/committor_loss.py:77\u001b[0m, in \u001b[0;36mCommittorLoss.forward\u001b[0;34m(self, x, q, labels, w, create_graph)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \n\u001b[1;32m 71\u001b[0m x: Union[torch\u001b[38;5;241m.\u001b[39mTensor, torch_geometric\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mBatch], \n\u001b[1;32m 72\u001b[0m q: torch\u001b[38;5;241m.\u001b[39mTensor, \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 75\u001b[0m create_graph: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 76\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcommittor_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43matomic_masses\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43matomic_masses\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m \u001b[49m\u001b[43mgamma\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgamma\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[43m \u001b[49m\u001b[43mdelta_f\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelta_f\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 85\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 86\u001b[0m \u001b[43m \u001b[49m\u001b[43mseparate_boundary_dataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mseparate_boundary_dataset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 87\u001b[0m \u001b[43m \u001b[49m\u001b[43mdescriptors_derivatives\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdescriptors_derivatives\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 88\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_var\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_var\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Bin/dev/mlcolvar/mlcolvar/core/loss/committor_loss.py:207\u001b[0m, in \u001b[0;36mcommittor_loss\u001b[0;34m(x, q, labels, w, atomic_masses, alpha, gamma, delta_f, create_graph, separate_boundary_dataset, descriptors_derivatives, log_var)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;66;03m# ============================== LOSS ==============================\u001b[39;00m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;66;03m# Each loss contribution is scaled by the number of samples\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \n\u001b[1;32m 204\u001b[0m \u001b[38;5;66;03m# 1. VARIATIONAL LOSS\u001b[39;00m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;66;03m# Compute gradients of q(x) wrt x\u001b[39;00m\n\u001b[1;32m 206\u001b[0m grad_outputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones_like(q[mask_var])\n\u001b[0;32m--> 207\u001b[0m grad \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmask_var\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrad_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_graph\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 208\u001b[0m grad \u001b[38;5;241m=\u001b[39m grad[mask_var_batches]\n\u001b[1;32m 209\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m descriptors_derivatives \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 210\u001b[0m \u001b[38;5;66;03m# we use the precomputed derivatives from descriptors to pos\u001b[39;00m\n", + "File \u001b[0;32m~/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/torch/autograd/__init__.py:394\u001b[0m, in \u001b[0;36mgrad\u001b[0;34m(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)\u001b[0m\n\u001b[1;32m 390\u001b[0m result \u001b[38;5;241m=\u001b[39m _vmap_internals\u001b[38;5;241m.\u001b[39m_vmap(vjp, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, allow_none_pass_through\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)(\n\u001b[1;32m 391\u001b[0m grad_outputs_\n\u001b[1;32m 392\u001b[0m )\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 394\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 395\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 396\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_outputs_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 397\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 398\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 399\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 400\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unused\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 401\u001b[0m \u001b[43m 
\u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 402\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 403\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m materialize_grads:\n\u001b[1;32m 404\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 405\u001b[0m output\n\u001b[1;32m 406\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mzeros_like(\u001b[38;5;28minput\u001b[39m, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (output, \u001b[38;5;28minput\u001b[39m) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(result, t_inputs)\n\u001b[1;32m 409\u001b[0m )\n", + "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 316.00 MiB. GPU 0 has a total capacty of 7.79 GiB of which 244.31 MiB is free. Including non-PyTorch memory, this process has 6.88 GiB memory in use. Of the allocated memory 6.61 GiB is allocated by PyTorch, and 158.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" ] } ], @@ -204,7 +258,7 @@ " logger=False,\n", " enable_checkpointing=False,\n", " accelerator='cuda',\n", - " max_epochs=5000,\n", + " max_epochs=500,\n", " enable_model_summary=False,\n", " limit_val_batches=0, \n", " num_sanity_val_steps=0\n", diff --git a/test_graphs/test_graph.ipynb b/test_graphs/test_graph.ipynb index 908189b..58d23a2 100644 --- a/test_graphs/test_graph.ipynb +++ b/test_graphs/test_graph.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -21,6 +21,7 @@ } ], "source": [ + "import torch\n", "from mlcolvar.data import DictModule\n", "from mlcolvar.utils.io import create_dataset_from_trajectories\n", "from mlcolvar.utils.io import create_dataset_from_files\n", @@ -52,7 +53,433 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b4a53ab5e07749d0aea380482e04a522", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00" ] @@ -686,6 +1105,403 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'epoch': 4,\n", + " 'global_step': 5,\n", + " 'pytorch-lightning_version': '2.1.3',\n", + " 'state_dict': OrderedDict([('loss_fn.target_centers', tensor([-7., 7.])),\n", + " ('loss_fn.target_sigmas', tensor([0.2000, 0.2000])),\n", + " ('nn.nn.0.weight',\n", + " tensor([[ 0.1091, -0.1932, -0.1206, -0.0681, 0.1556, 0.1799, 0.1149, -0.2043,\n", + " -0.0908, -0.0390, -0.0684, -0.0457, 0.0788, 0.1463, -0.1411, -0.1737,\n", + " 0.1678, 0.1805, 0.2164, -0.1644, -0.0258],\n", + " [ 0.1690, 0.1007, -0.0752, 0.1875, 0.1510, 0.0228, 0.1875, 0.2199,\n", + " 0.0580, 0.1580, 0.1729, 0.1864, -0.0042, 0.1831, 0.1117, 0.0253,\n", + " 0.0455, 0.2125, -0.1677, 0.0447, -0.1825],\n", + " [ 0.0161, 0.1482, -0.0858, -0.0592, 0.0975, -0.1783, 0.0376, 0.1408,\n", + " 0.1791, 0.1822, -0.1982, 0.2098, -0.0159, -0.0619, -0.0960, -0.0942,\n", + " -0.1535, 0.0209, -0.1388, -0.1435, 0.1325],\n", + " [-0.1178, 0.0437, 0.0124, -0.1600, 0.0760, 0.1467, -0.0989, 0.0063,\n", + " -0.1764, -0.1864, -0.0645, -0.0613, 0.0868, -0.2056, 0.0756, 0.0821,\n", + " -0.0335, -0.1973, -0.1012, -0.1909, 0.0805],\n", + " [-0.2068, -0.0470, 0.0194, -0.0922, -0.1448, -0.1430, 0.0428, -0.0748,\n", + " 0.1746, -0.1001, 0.0089, -0.0171, -0.0882, 0.1353, -0.1607, -0.1919,\n", + " -0.1167, -0.1318, 0.0544, 0.0299, -0.0761],\n", + " [ 0.0451, 0.1221, 0.0529, 0.1372, 0.1856, -0.1394, -0.0864, 0.1279,\n", + " -0.0794, 0.0893, -0.0387, 0.0769, -0.0881, -0.0817, 0.0155, -0.1930,\n", + " 0.0596, -0.1755, 0.0113, 0.0951, -0.0341],\n", + " [ 0.1832, 0.0194, 0.1026, -0.1198, -0.0984, 0.1239, -0.0472, -0.2238,\n", + " 0.1610, 0.1594, 0.0977, -0.1394, 0.1871, -0.1873, 0.0732, 0.0750,\n", + " 0.0214, 0.0256, -0.0510, 0.0816, 0.2049],\n", + " [-0.1987, -0.0250, -0.0757, 0.0912, -0.0272, -0.1131, -0.1551, 0.2095,\n", + " 0.0775, 0.1406, 0.1878, 0.0675, 0.1871, 
-0.1933, -0.0769, 0.1693,\n", + " 0.2012, -0.1832, 0.1627, -0.1376, -0.1235],\n", + " [-0.0159, -0.0251, -0.2143, -0.1840, -0.0242, 0.1255, -0.1588, -0.1978,\n", + " -0.0974, 0.2079, -0.0857, -0.1865, -0.0769, 0.1807, -0.0904, -0.2169,\n", + " 0.1189, -0.1440, 0.1711, 0.0855, 0.1139],\n", + " [ 0.1700, 0.0159, 0.0643, -0.1504, -0.1127, 0.0806, -0.0737, -0.1117,\n", + " -0.0987, -0.1546, 0.0140, -0.0475, 0.1823, -0.0481, 0.1926, 0.0935,\n", + " 0.1058, 0.0454, 0.0127, 0.0891, -0.0544],\n", + " [-0.0002, -0.0582, -0.0854, 0.2017, 0.0692, 0.1184, 0.1207, 0.1331,\n", + " 0.0864, -0.0704, 0.1787, 0.0275, -0.1465, 0.0042, 0.1829, -0.0312,\n", + " -0.2200, 0.0996, -0.2029, 0.1268, -0.1038],\n", + " [ 0.0136, -0.0978, 0.1791, 0.0506, -0.0822, 0.0442, -0.1463, 0.1748,\n", + " -0.1854, 0.1354, -0.0795, 0.0344, 0.2072, -0.0770, -0.0250, 0.0210,\n", + " -0.1289, -0.0499, -0.0376, 0.1949, 0.0601],\n", + " [ 0.1683, -0.1752, -0.1138, 0.1663, -0.0030, -0.0636, -0.1552, 0.1392,\n", + " 0.1685, -0.1308, -0.1843, -0.1079, -0.1469, -0.0323, -0.0278, -0.2174,\n", + " -0.1844, -0.0565, 0.0696, 0.0376, -0.0199],\n", + " [ 0.1238, -0.1107, -0.1061, -0.0396, -0.1200, -0.1036, 0.1609, -0.1744,\n", + " 0.2047, 0.0651, 0.0231, 0.0436, -0.2161, 0.1388, 0.2217, 0.0804,\n", + " -0.1238, -0.0067, -0.2067, 0.2264, -0.1849],\n", + " [-0.1626, 0.1413, -0.0193, 0.1994, -0.1003, 0.1011, 0.0298, -0.1755,\n", + " 0.0680, 0.0070, 0.2111, 0.0611, -0.1107, 0.2134, 0.0470, -0.0835,\n", + " -0.0261, 0.1392, -0.1695, 0.2148, 0.1370]])),\n", + " ('nn.nn.0.bias',\n", + " tensor([ 0.0900, 0.0308, -0.1071, -0.1334, 0.0102, 0.0913, 0.0692, 0.1084,\n", + " -0.0914, 0.1140, 0.0671, 0.1281, 0.0583, 0.0690, 0.1046])),\n", + " ('nn.nn.2.weight',\n", + " tensor([[ 0.0915, -0.0500, 0.0047, 0.2131, -0.0152, 0.2466, -0.0150, 0.2662,\n", + " -0.0901, -0.0439, 0.0211, 0.1409, -0.0744, 0.0863, 0.0169],\n", + " [ 0.2359, -0.1307, 0.1628, -0.0547, -0.2153, 0.0409, -0.1520, -0.2040,\n", + " 0.0957, -0.2015, -0.1916, 0.0876, 
-0.2271, 0.0377, 0.1216],\n", + " [-0.1357, -0.2273, -0.1212, 0.0699, 0.1139, 0.0471, 0.1090, 0.1206,\n", + " 0.0937, -0.2334, 0.2370, -0.2110, -0.2154, 0.1292, 0.2160],\n", + " [ 0.1417, -0.0689, 0.1739, 0.1697, 0.0167, 0.1179, -0.2494, 0.1554,\n", + " 0.1904, -0.1728, -0.2117, -0.1294, -0.0025, 0.0746, 0.2042],\n", + " [-0.2106, -0.0038, -0.2492, 0.1139, -0.1676, 0.1174, -0.1071, 0.0586,\n", + " -0.2236, -0.0618, 0.2575, 0.2394, -0.1630, -0.0699, -0.0600],\n", + " [-0.0289, -0.1661, 0.1261, 0.1911, -0.1126, 0.0430, -0.0270, 0.2116,\n", + " 0.0193, -0.0740, -0.0703, -0.0849, -0.0675, 0.1081, 0.1828],\n", + " [-0.0442, 0.2096, 0.0064, 0.0248, -0.1210, -0.2544, 0.1305, -0.0803,\n", + " -0.0339, -0.1462, -0.2000, -0.0745, 0.0805, -0.1077, -0.1239],\n", + " [ 0.2284, -0.2315, -0.0981, 0.1649, -0.2322, -0.0236, -0.2408, -0.1701,\n", + " 0.1063, 0.2036, 0.1024, 0.2182, -0.0476, -0.2530, 0.1508],\n", + " [-0.0808, -0.2526, 0.0910, -0.0245, -0.2045, -0.1999, -0.0636, -0.0950,\n", + " 0.0953, -0.0798, -0.2545, 0.1149, 0.0773, 0.0743, 0.1952],\n", + " [ 0.0481, 0.0103, -0.0342, -0.1739, -0.0711, 0.1968, -0.1936, -0.1705,\n", + " 0.0499, 0.0205, 0.2502, 0.0007, -0.1033, -0.2411, -0.1898]])),\n", + " ('nn.nn.2.bias',\n", + " tensor([ 0.0205, 0.1057, 0.2199, -0.1365, -0.2577, -0.1165, 0.0104, -0.1540,\n", + " 0.2311, -0.1104])),\n", + " ('nn.nn.4.weight',\n", + " tensor([[-0.0132, 0.1431, 0.2450, 0.0389, 0.1609, -0.1307, -0.0078, 0.1824,\n", + " 0.0553, 0.2282]])),\n", + " ('nn.nn.4.bias', tensor([0.1878]))]),\n", + " 'loops': {'fit_loop': {'state_dict': {},\n", + " 'epoch_loop.state_dict': {'_batches_that_stepped': 5},\n", + " 'epoch_loop.batch_progress': {'total': {'ready': 5,\n", + " 'completed': 5,\n", + " 'started': 5,\n", + " 'processed': 5},\n", + " 'current': {'ready': 1, 'completed': 1, 'started': 1, 'processed': 1},\n", + " 'is_last_batch': True},\n", + " 'epoch_loop.scheduler_progress': {'total': {'ready': 0, 'completed': 0},\n", + " 'current': {'ready': 0, 
'completed': 0}},\n", + " 'epoch_loop.automatic_optimization.state_dict': {},\n", + " 'epoch_loop.automatic_optimization.optim_progress': {'optimizer': {'step': {'total': {'ready': 5,\n", + " 'completed': 5},\n", + " 'current': {'ready': 1, 'completed': 1}},\n", + " 'zero_grad': {'total': {'ready': 5, 'completed': 5, 'started': 5},\n", + " 'current': {'ready': 1, 'completed': 1, 'started': 1}}}},\n", + " 'epoch_loop.manual_optimization.state_dict': {},\n", + " 'epoch_loop.manual_optimization.optim_step_progress': {'total': {'ready': 0,\n", + " 'completed': 0},\n", + " 'current': {'ready': 0, 'completed': 0}},\n", + " 'epoch_loop.val_loop.state_dict': {},\n", + " 'epoch_loop.val_loop.batch_progress': {'total': {'ready': 1,\n", + " 'completed': 1,\n", + " 'started': 1,\n", + " 'processed': 1},\n", + " 'current': {'ready': 1, 'completed': 1, 'started': 1, 'processed': 1},\n", + " 'is_last_batch': True},\n", + " 'epoch_progress': {'total': {'ready': 5,\n", + " 'completed': 4,\n", + " 'started': 5,\n", + " 'processed': 5},\n", + " 'current': {'ready': 5, 'completed': 4, 'started': 5, 'processed': 5}}},\n", + " 'validate_loop': {'state_dict': {},\n", + " 'batch_progress': {'total': {'ready': 0,\n", + " 'completed': 0,\n", + " 'started': 0,\n", + " 'processed': 0},\n", + " 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0},\n", + " 'is_last_batch': False}},\n", + " 'test_loop': {'state_dict': {},\n", + " 'batch_progress': {'total': {'ready': 0,\n", + " 'completed': 0,\n", + " 'started': 0,\n", + " 'processed': 0},\n", + " 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0},\n", + " 'is_last_batch': False}},\n", + " 'predict_loop': {'state_dict': {},\n", + " 'batch_progress': {'total': {'ready': 0,\n", + " 'completed': 0,\n", + " 'started': 0,\n", + " 'processed': 0},\n", + " 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}}}},\n", + " 'callbacks': {\"ModelCheckpoint{'monitor': None, 'mode': 'min', 
'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}\": {'monitor': None,\n", + " 'best_model_score': None,\n", + " 'best_model_path': '/home/etrizio@iit.local/Bin/dev/mlcolvar/test_graphs/checkpoints/epoch=4-step=5.ckpt',\n", + " 'current_score': None,\n", + " 'dirpath': '/home/etrizio@iit.local/Bin/dev/mlcolvar/test_graphs/checkpoints',\n", + " 'best_k_models': {},\n", + " 'kth_best_model_path': '',\n", + " 'kth_value': tensor(inf),\n", + " 'last_model_path': ''}},\n", + " 'optimizer_states': [{'state': {0: {'step': tensor(5.),\n", + " 'exp_avg': tensor([[-3.5984e-04, -4.0911e-05, -8.8112e-06, -2.7854e-03, 1.7126e-04,\n", + " -3.3620e-04, -4.5149e-04, -1.1946e-03, 1.8959e-03, -4.6310e-04,\n", + " -3.4691e-04, -2.7802e-04, -3.7837e-03, 1.0556e-03, -8.9433e-05,\n", + " -1.7654e-03, -1.1040e-03, -1.9409e-03, 1.0916e-03, 2.0371e-03,\n", + " -3.0389e-04],\n", + " [-8.1956e-03, -9.4709e-04, -2.1276e-04, -6.3275e-02, 3.8211e-03,\n", + " -7.6471e-03, -1.0253e-02, -2.7176e-02, 4.3028e-02, -1.0535e-02,\n", + " -7.8866e-03, -6.3373e-03, -8.5962e-02, 2.3908e-02, -1.9752e-03,\n", + " -4.0119e-02, -2.5120e-02, -4.4181e-02, 2.4700e-02, 4.6257e-02,\n", + " -6.9061e-03],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + 
" 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00],\n", + " [ 2.7277e-04, 3.1460e-05, 7.0329e-06, 2.1066e-03, -1.2749e-04,\n", + " 2.5455e-04, 3.4137e-04, 9.0460e-04, -1.4327e-03, 3.5068e-04,\n", + " 2.6254e-04, 2.1090e-04, 2.8619e-03, -7.9626e-04, 6.5984e-05,\n", + " 1.3356e-03, 8.3614e-04, 1.4705e-03, -8.2272e-04, -1.5401e-03,\n", + " 2.2991e-04],\n", + " [ 3.0178e-04, 3.5688e-05, 8.4767e-06, 2.3212e-03, -1.3647e-04,\n", + " 2.8104e-04, 3.7597e-04, 9.9900e-04, -1.5762e-03, 3.8727e-04,\n", + " 2.8963e-04, 2.3364e-04, 3.1539e-03, -8.7324e-04, 6.9448e-05,\n", + " 1.4726e-03, 9.2364e-04, 1.6255e-03, -9.0095e-04, -1.6959e-03,\n", + " 2.5348e-04],\n", + " [-5.7066e-04, -6.4770e-05, -1.3887e-05, -4.4184e-03, 2.7217e-04,\n", + " -5.3324e-04, -7.1621e-04, -1.8946e-03, 3.0078e-03, -7.3451e-04,\n", + " -5.5025e-04, -4.4086e-04, -6.0020e-03, 1.6750e-03, -1.4227e-04,\n", + " -2.8003e-03, -1.7510e-03, -3.0783e-03, 1.7322e-03, 3.2315e-03,\n", + " -4.8204e-04],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00],\n", + " [-8.4292e-03, -9.7425e-04, -2.1896e-04, -6.5077e-02, 3.9291e-03,\n", + " -7.8649e-03, -1.0545e-02, -2.7950e-02, 4.4253e-02, -1.0835e-02,\n", + " -8.1112e-03, -6.5180e-03, -8.8410e-02, 2.4588e-02, -2.0308e-03,\n", + " -4.1262e-02, -2.5835e-02, -4.5440e-02, 2.5402e-02, 4.7574e-02,\n", + " -7.1027e-03],\n", + " [ 1.7536e-03, 2.0391e-04, 4.6520e-05, 1.3525e-02, -8.1100e-04,\n", + " 1.6354e-03, 2.1914e-03, 5.8120e-03, -9.1938e-03, 2.2531e-03,\n", + " 1.6863e-03, 1.3564e-03, 1.8375e-02, -5.1044e-03, 4.1752e-04,\n", + " 8.5767e-03, 5.3726e-03, 9.4510e-03, -5.2716e-03, -9.8859e-03,\n", + " 1.4764e-03],\n", + " [-3.1876e-03, -3.6807e-04, -8.2524e-05, -2.4613e-02, 1.4876e-03,\n", + " -2.9744e-03, 
-3.9884e-03, -1.0570e-02, 1.6738e-02, -4.0977e-03,\n", + " -3.0677e-03, -2.4647e-03, -3.3438e-02, 9.3013e-03, -7.6937e-04,\n", + " -1.5605e-02, -9.7704e-03, -1.7184e-02, 9.6098e-03, 1.7994e-02,\n", + " -2.6863e-03],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00],\n", + " [ 1.5152e-04, -2.9981e-03, -3.3844e-03, 2.2486e-02, -7.5181e-03,\n", + " -1.9135e-03, -1.2430e-03, 6.0865e-03, -2.3053e-02, 1.0649e-03,\n", + " 3.3099e-04, -3.1508e-03, 2.9949e-02, -1.8428e-02, -5.0736e-03,\n", + " 1.0704e-02, 3.3760e-03, 1.1371e-02, -1.7308e-02, -2.6344e-02,\n", + " -2.4798e-03],\n", + " [ 7.5502e-03, 8.7246e-04, 1.9596e-04, 5.8293e-02, -3.5205e-03,\n", + " 7.0449e-03, 9.4460e-03, 2.5036e-02, -3.9640e-02, 9.7055e-03,\n", + " 7.2656e-03, 5.8383e-03, 7.9194e-02, -2.2026e-02, 1.8198e-03,\n", + " 3.6960e-02, 2.3142e-02, 4.0702e-02, -2.2756e-02, -4.2615e-02,\n", + " 6.3623e-03]]),\n", + " 'exp_avg_sq': tensor([[3.7763e-09, 4.8915e-11, 2.4611e-12, 2.2653e-07, 8.7240e-10, 3.2982e-09,\n", + " 5.9526e-09, 4.1636e-08, 1.0501e-07, 6.2576e-09, 3.5125e-09, 2.2536e-09,\n", + " 4.1799e-07, 3.2594e-08, 2.4217e-10, 9.0979e-08, 3.5560e-08, 1.0989e-07,\n", + " 3.4875e-08, 1.2118e-07, 2.6959e-09],\n", + " [1.9973e-06, 2.7624e-08, 1.6194e-09, 1.1857e-04, 4.2549e-07, 1.7351e-06,\n", + " 3.1119e-06, 2.1919e-05, 5.4749e-05, 3.2940e-06, 1.8441e-06, 1.1959e-06,\n", + " 2.1886e-04, 1.6856e-05, 1.1384e-07, 4.7692e-05, 1.8732e-05, 5.7986e-05,\n", + " 1.7971e-05, 6.3328e-05, 1.4132e-06],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 
0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [2.2000e-09, 3.0197e-11, 1.7421e-12, 1.3077e-07, 4.7342e-10, 1.9124e-09,\n", + " 3.4326e-09, 2.4157e-08, 6.0411e-08, 3.6304e-09, 2.0330e-09, 1.3168e-09,\n", + " 2.4137e-07, 1.8618e-08, 1.2727e-10, 5.2589e-08, 2.0644e-08, 6.3890e-08,\n", + " 1.9857e-08, 6.9857e-08, 1.5583e-09],\n", + " [3.0865e-09, 4.6301e-11, 3.1668e-12, 1.8069e-07, 5.8446e-10, 2.6620e-09,\n", + " 4.7344e-09, 3.3663e-08, 8.2994e-08, 5.0582e-09, 2.8217e-09, 1.8561e-09,\n", + " 3.3370e-07, 2.5266e-08, 1.4693e-10, 7.2829e-08, 2.8793e-08, 8.9334e-08,\n", + " 2.6800e-08, 9.6300e-08, 2.1577e-09],\n", + " [9.5133e-09, 1.2241e-10, 6.0525e-12, 5.7125e-07, 2.2145e-09, 8.3133e-09,\n", + " 1.5013e-08, 1.0494e-07, 2.6490e-07, 1.5772e-08, 8.8550e-09, 5.6754e-09,\n", + " 1.0540e-06, 8.2289e-08, 6.1673e-10, 2.2939e-07, 8.9618e-08, 2.7691e-07,\n", + " 8.8079e-08, 3.0564e-07, 6.7975e-09],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [2.1139e-06, 2.9258e-08, 1.7177e-09, 1.2548e-04, 4.4993e-07, 1.8363e-06,\n", + " 3.2932e-06, 2.3198e-05, 5.7938e-05, 3.4862e-06, 1.9516e-06, 1.2658e-06,\n", + " 2.3162e-04, 1.7837e-05, 1.2032e-07, 5.0473e-05, 1.9826e-05, 6.1372e-05,\n", + " 1.9015e-05, 
6.7018e-05, 1.4956e-06],\n", + " [9.3667e-08, 1.3272e-09, 8.1737e-11, 5.5382e-06, 1.9311e-08, 8.1203e-08,\n", + " 1.4529e-07, 1.0261e-06, 2.5534e-06, 1.5420e-07, 8.6235e-08, 5.6155e-08,\n", + " 1.0224e-05, 7.8365e-07, 5.0832e-09, 2.2290e-06, 8.7714e-07, 2.7170e-06,\n", + " 8.3424e-07, 2.9562e-06, 6.6045e-08],\n", + " [3.0141e-07, 4.1560e-09, 2.4205e-10, 1.7902e-05, 6.4473e-08, 2.6192e-07,\n", + " 4.6989e-07, 3.3086e-06, 8.2681e-06, 4.9722e-07, 2.7839e-07, 1.8045e-07,\n", + " 3.3045e-05, 2.5466e-06, 1.7283e-08, 7.2004e-06, 2.8275e-06, 8.7519e-06,\n", + " 2.7154e-06, 9.5626e-06, 2.1336e-07],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [2.1475e-07, 4.0900e-07, 4.7883e-07, 1.6241e-05, 2.0259e-06, 5.3709e-07,\n", + " 5.6838e-07, 1.9295e-06, 1.5508e-05, 2.8227e-07, 1.7979e-07, 7.9329e-07,\n", + " 2.9158e-05, 1.0315e-05, 1.1692e-06, 4.6807e-06, 1.5095e-06, 5.6349e-06,\n", + " 8.9746e-06, 2.0288e-05, 6.1525e-07],\n", + " [1.6948e-06, 2.3435e-08, 1.3731e-09, 1.0061e-04, 3.6116e-07, 1.4723e-06,\n", + " 2.6407e-06, 1.8599e-05, 4.6460e-05, 2.7951e-06, 1.5648e-06, 1.0148e-06,\n", + " 1.8572e-04, 1.4305e-05, 9.6640e-08, 4.0470e-05, 1.5896e-05, 4.9205e-05,\n", + " 1.5250e-05, 5.3739e-05, 1.1992e-06]])},\n", + " 1: {'step': tensor(5.),\n", + " 'exp_avg': tensor([-0.0014, -0.0330, 0.0000, 0.0000, 0.0000, 0.0011, 0.0012, -0.0023,\n", + " 0.0000, -0.0339, 0.0071, -0.0128, 0.0000, -0.0087, 0.0304]),\n", + " 'exp_avg_sq': tensor([6.1061e-08, 3.2349e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.5625e-08,\n", + " 5.0100e-08, 1.5380e-07, 0.0000e+00, 3.4239e-05, 1.5180e-06, 4.8814e-06,\n", + " 0.0000e+00, 1.0426e-05, 2.7449e-05])},\n", + " 2: {'step': tensor(5.),\n", + " 'exp_avg': tensor([[-0.0009, -0.0037, 0.0000, 0.0000, 0.0000, 0.0007, 
-0.0014, -0.0044,\n", + " 0.0000, -0.0018, -0.0003, -0.0032, 0.0000, 0.0046, 0.0006],\n", + " [ 0.0120, 0.0495, 0.0000, 0.0000, 0.0000, -0.0091, 0.0187, 0.0585,\n", + " 0.0000, 0.0243, 0.0044, 0.0434, 0.0000, -0.0618, -0.0082],\n", + " [ 0.0207, 0.0854, 0.0000, 0.0000, 0.0000, -0.0157, 0.0323, 0.1011,\n", + " 0.0000, 0.0419, 0.0076, 0.0749, 0.0000, -0.1066, -0.0142],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", + " [ 0.0045, 0.0185, 0.0000, 0.0000, 0.0000, -0.0034, 0.0070, 0.0219,\n", + " 0.0000, 0.0091, 0.0016, 0.0163, 0.0000, -0.0231, -0.0031],\n", + " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]),\n", + " 'exp_avg_sq': tensor([[2.3564e-08, 3.9634e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.4035e-08,\n", + " 5.5842e-08, 5.5526e-07, 0.0000e+00, 9.6596e-08, 3.1520e-09, 3.0598e-07,\n", + " 0.0000e+00, 6.1716e-07, 1.1620e-08],\n", + " [4.2216e-06, 7.2331e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4081e-06,\n", + " 1.0581e-05, 1.0120e-04, 0.0000e+00, 1.7315e-05, 6.5603e-07, 5.5476e-05,\n", + " 0.0000e+00, 1.1279e-04, 1.9772e-06],\n", + " [1.2592e-05, 2.1589e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.1721e-06,\n", + " 3.1621e-05, 3.0205e-04, 0.0000e+00, 5.1649e-05, 1.9663e-06, 1.6554e-04,\n", + " 0.0000e+00, 3.3667e-04, 5.8869e-06],\n", + " 
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [5.9102e-07, 1.0101e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3919e-07,\n", + " 1.4702e-06, 1.4135e-05, 0.0000e+00, 2.4240e-06, 9.0072e-08, 7.7540e-06,\n", + " 0.0000e+00, 1.5748e-05, 2.7883e-07],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00]])},\n", + " 3: {'step': tensor(5.),\n", + " 'exp_avg': tensor([-0.0040, 0.0545, 0.0941, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.0204, 0.0000]),\n", + " 'exp_avg_sq': tensor([4.7714e-07, 8.8522e-05, 2.6436e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 1.2334e-05, 0.0000e+00])},\n", + " 4: {'step': tensor(5.),\n", + " 'exp_avg': tensor([[ 0.0821, -0.1237, -0.1054, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", + " -0.0650, 0.0000]]),\n", + " 'exp_avg_sq': tensor([[0.0002, 0.0005, 0.0003, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 
0.0001,\n", + " 0.0000]])},\n", + " 5: {'step': tensor(5.),\n", + " 'exp_avg': tensor([0.3885]),\n", + " 'exp_avg_sq': tensor([0.0045])}},\n", + " 'param_groups': [{'lr': 0.001,\n", + " 'betas': (0.9, 0.999),\n", + " 'eps': 1e-08,\n", + " 'weight_decay': 0,\n", + " 'amsgrad': False,\n", + " 'maximize': False,\n", + " 'foreach': None,\n", + " 'capturable': False,\n", + " 'differentiable': False,\n", + " 'fused': None,\n", + " 'params': [0, 1, 2, 3, 4, 5]}]}],\n", + " 'lr_schedulers': [],\n", + " 'hparams_name': 'kwargs',\n", + " 'hyper_parameters': {'model': FeedForward(\n", + " (nn): Sequential(\n", + " (0): Linear(in_features=21, out_features=15, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=15, out_features=10, bias=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Linear(in_features=10, out_features=1, bias=True)\n", + " )\n", + " ),\n", + " 'preprocessing': None,\n", + " 'postprocessing': None,\n", + " 'n_states': 2,\n", + " 'n_cvs': 1,\n", + " 'target_centers': [-7, 7],\n", + " 'target_sigmas': [0.2, 0.2],\n", + " 'options': None}}" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.load('checkpoints/epoch=4-step=5.ckpt')" + ] + }, { "cell_type": "code", "execution_count": null,