diff --git a/docs/tutorial_toolbox/saving_and_loading.ipynb b/docs/tutorial_toolbox/saving_and_loading.ipynb index 9f5bd81b4..ce3f427ea 100644 --- a/docs/tutorial_toolbox/saving_and_loading.ipynb +++ b/docs/tutorial_toolbox/saving_and_loading.ipynb @@ -245,9 +245,48 @@ } }, "source": [ - "You can make your own saving and loading functions easily. Beacause all variables in the model can be easily collected through ``.vars()``. Therefore, saving variables is just transforming these variables to numpy.ndarray and then storing them into the disk. Similarly, to load variables, you just need read the numpy arrays from the disk and then transform these arrays as instances of [Variables](../tutorial_math/variables.ipynb). \n", + "You can make your own saving and loading functions easily.\n", "\n", - "The only gotcha to pay attention to is to avoid saving duplicated variables. " + "For customizing the saving and loading, users can overwrite ``__save_state__`` and ``__load_state__`` functions.\n", + "\n", + "Here is an example to customize:\n", + "```python\n", + "class YourClass(bp.DynamicSystem):\n", + " def __init__(self):\n", + " self.a = 1\n", + " self.b = bm.random.rand(10)\n", + " self.c = bm.Variable(bm.random.rand(3))\n", + " self.d = bm.var_list([bm.Variable(bm.random.rand(3)),\n", + " bm.Variable(bm.random.rand(3))])\n", + "\n", + " def __save_state__(self) -> dict:\n", + " state_dict = {'a': self.a,\n", + " 'b': self.b,\n", + " 'c': self.c}\n", + " for i, elem in enumerate(self.d):\n", + " state_dict[f'd{i}'] = elem.value\n", + "\n", + " return state_dict\n", + "\n", + " def __load_state__(self, state_dict):\n", + " self.a = state_dict['a']\n", + " self.b = bm.asarray(state_dict['b'])\n", + " self.c = bm.asarray(state_dict['c'])\n", + "\n", + " for i in range(len(self.d)):\n", + " self.d[i].value = bm.asarray(state_dict[f'd{i}'])\n", + "```\n", + "\n", + "\n", + "- ``__save_state__(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n", + "\n", + "- ``__load_state__(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n", + "At firstly it gets the current variables of the object.\n", + "Then, it determines the intersection of keys from the provided state_dict and the object's variables.\n", + "For each intersecting key, it updates the value of the object's variable with the value from state_dict.\n", + "Finally, returns A tuple containing two lists:\n", + " - ``unexpected_keys``: Keys in state_dict that were not found in the object's variables.\n", + " - ``missing_keys``: Keys that are in the object's variables but were not found in state_dict." ] } ],