Skip to content

Commit

Permalink
Temporary fix with new signature of BaseCV
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Dec 10, 2024
1 parent 73745d2 commit 72e1804
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions docs/notebooks/tutorials/adv_newcv_scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -72,7 +72,7 @@
"from mlcolvar.cvs import BaseCV\n",
"\n",
"class AutoEncoderCV(BaseCV, lightning.LightningModule):\n",
" BLOCKS = ['norm_in','encoder','decoder'] "
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] "
]
},
{
Expand All @@ -87,7 +87,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -104,7 +104,7 @@
" with the input 'data'.\n",
" \"\"\"\n",
" \n",
" BLOCKS = ['norm_in','encoder','decoder'] "
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] "
]
},
{
Expand Down Expand Up @@ -136,12 +136,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AutoEncoderCV(BaseCV, lightning.LightningModule):\n",
" BLOCKS = ['norm_in','encoder','decoder'] \n",
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n",
"\n",
" def __init__(self,\n",
"# ================================================ LOOK HERE 0.0 ================================================ \n",
Expand All @@ -165,7 +165,7 @@
" Available blocks: ['norm_in', 'encoder','decoder'].\n",
" Set 'block_name' = None or False to turn off that block\n",
" \"\"\"\n",
" super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n",
" super().__init__(model=encoder_layers, **kwargs)\n",
" \n",
"# ================================================ LOOK HERE 0.0 ================================================ \n"
]
Expand All @@ -185,19 +185,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AutoEncoderCV(BaseCV, lightning.LightningModule):\n",
" BLOCKS = ['norm_in','encoder','decoder'] \n",
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n",
" \n",
" def __init__(self,\n",
" encoder_layers : list, \n",
" decoder_layers : list = None, \n",
" options : dict = None, \n",
" **kwargs):\n",
" super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n",
" super().__init__(model=encoder_layers, **kwargs)\n",
"\n",
"# ================================================ LOOK HERE 0.0 ================================================ \n",
" \n",
Expand All @@ -224,21 +224,21 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mlcolvar.core.loss import MSELoss\n",
"\n",
"class AutoEncoderCV(BaseCV, lightning.LightningModule):\n",
" BLOCKS = ['norm_in','encoder','decoder'] \n",
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n",
" \n",
" def __init__(self,\n",
" encoder_layers : list, \n",
" decoder_layers : list = None, \n",
" options : dict = None, \n",
" **kwargs):\n",
" super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n",
" super().__init__(model=encoder_layers, **kwargs)\n",
"\n",
" # ======= OPTIONS ======= \n",
" # parse and sanitize\n",
Expand Down Expand Up @@ -283,22 +283,22 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mlcolvar.core.nn import FeedForward\n",
"from mlcolvar.core.transform import Normalization\n",
"\n",
"class AutoEncoderCV(BaseCV, lightning.LightningModule):\n",
" BLOCKS = ['norm_in','encoder','decoder'] \n",
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n",
" \n",
" def __init__(self,\n",
" encoder_layers : list, \n",
" decoder_layers : list = None, \n",
" options : dict = None, \n",
" **kwargs):\n",
" super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n",
" super().__init__(model=encoder_layers, **kwargs)\n",
"\n",
" # ======= OPTIONS ======= \n",
" # parse and sanitize\n",
Expand Down Expand Up @@ -425,7 +425,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -442,7 +442,7 @@
" with the input 'data'.\n",
" \"\"\"\n",
" \n",
" BLOCKS = ['norm_in','encoder','decoder'] \n",
" DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n",
" \n",
" def __init__(self,\n",
" encoder_layers : list, \n",
Expand All @@ -465,7 +465,7 @@
" Available blocks: ['norm_in', 'encoder','decoder'].\n",
" Set 'block_name' = None or False to turn off that block\n",
" \"\"\"\n",
" super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n",
" super().__init__(model=encoder_layers, **kwargs)\n",
"\n",
" # ======= OPTIONS ======= \n",
" # parse and sanitize\n",
Expand Down Expand Up @@ -625,7 +625,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"display_name": "graph_mlcolvar_test",
"language": "python",
"name": "python3"
},
Expand All @@ -639,14 +639,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.18"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "1cbeac1d7079eaeba64f3210ccac5ee24400128e300a45ae35eee837885b08b3"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down

0 comments on commit 72e1804

Please sign in to comment.