
Revert "Fix"
This reverts commit 211e974.
Routhleck committed May 12, 2024
1 parent 211e974 commit c7d04ab
Showing 16 changed files with 1,880 additions and 407 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
@@ -329,7 +329,7 @@ def find_fps_with_gd_method(
     """
     # optimization settings
     if optimizer is None:
-      optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
+      optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
                              beta1=0.9, beta2=0.999, eps=1e-8)
     else:
       if not isinstance(optimizer, optim.Optimizer):
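Context for this hunk: when no optimizer is passed, find_fps_with_gd_method builds an Adam optimizer with an exponentially decaying learning rate. A minimal sketch of constructing the same default by hand, assuming the post-revert ExponentialDecay name (the pre-revert code spells it ExponentialDecayLR):

    from brainpy import optim

    # Learning rate starts at 0.2 and is multiplied by 0.9999 every step;
    # the scheduler name follows this revert (other releases expose the
    # same arguments under optim.ExponentialDecayLR).
    lr = optim.ExponentialDecay(0.2, 1, 0.9999)
    opt = optim.Adam(lr=lr, beta1=0.9, beta2=0.999, eps=1e-8)

Passing optimizer=opt explicitly sidesteps the naming difference between releases.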
10 changes: 3 additions & 7 deletions brainpy/_src/dependency_check.py
@@ -57,13 +57,9 @@ def import_taichi(error_if_not_found=True):
 
   if taichi is None:
     return None
-  taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2]
-  minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \
-                           _minimal_taichi_version[2]
-  if taichi_version >= minimal_taichi_version:
-    return taichi
-  else:
-    raise ModuleNotFoundError(taichi_install_info)
+  if taichi.__version__ != _minimal_taichi_version:
+    raise RuntimeError(taichi_install_info)
+  return taichi
 
 
 def raise_taichi_not_found(*args, **kwargs):
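This hunk swaps a minimum-version check for an exact-match check. A self-contained sketch of both strategies (illustrative only; taichi exposes __version__ as a (major, minor, patch) tuple, and the minimal version below is an assumed value):

    _minimal_taichi_version = (1, 7, 0)  # assumed value for illustration

    def version_at_least(installed, minimal):
        # Pre-revert logic: pack (major, minor, patch) into one integer
        # and accept any installed version at or above the minimum.
        pack = lambda v: v[0] * 10000 + v[1] * 100 + v[2]
        return pack(installed) >= pack(minimal)

    def version_exact(installed, minimal):
        # Post-revert logic: accept only the pinned version.
        return installed == minimal

    print(version_at_least((1, 7, 2), _minimal_taichi_version))  # True
    print(version_exact((1, 7, 2), _minimal_taichi_version))     # False

Since Python tuples compare lexicographically, installed >= minimal would give the same result as the integer packing whenever minor and patch numbers stay below 100.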
2 changes: 1 addition & 1 deletion brainpy/_src/train/back_propagation.py
@@ -111,7 +111,7 @@ def __init__(
 
     # optimizer
     if optimizer is None:
-      lr = optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975)
+      lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
       optimizer = optim.Adam(lr=lr)
     self.optimizer: optim.Optimizer = optimizer
     if len(self.optimizer.vars_to_train) == 0:
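As with the fixed-point finder above, the trainer only builds this default when no optimizer is supplied. A hedged sketch of passing one explicitly to bp.BPTT (the model is a placeholder, and the scheduler name again depends on the release):

    import brainpy as bp
    import brainpy.math as bm
    from brainpy import optim

    bm.set(mode=bm.training_mode)
    model = bp.dnn.Linear(10, 2)  # placeholder trainable module
    lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
    trainer = bp.BPTT(model,
                      loss_fun=bp.losses.cross_entropy_loss,
                      optimizer=optim.Adam(lr=lr))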
296 changes: 243 additions & 53 deletions docs/quickstart/analysis.ipynb

Large diffs are not rendered by default.

144 changes: 120 additions & 24 deletions docs/tutorial_advanced/advanced_lowdim_analysis.ipynb

Large diffs are not rendered by default.

48 changes: 24 additions & 24 deletions docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
@@ -21,6 +21,7 @@
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
+   "outputs": [],
    "source": [
     "import jax\n",
     "import jax.numpy as jnp\n",
@@ -34,33 +35,33 @@
     "\n",
     "import brainpy as bp\n",
     "import brainpy.math as bm"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
+   "outputs": [],
    "source": [
     "bm.set(mode=bm.training_mode, dt=1.)\n",
     "\n",
     "bp.__version__"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
+   "outputs": [],
    "source": [
     "num_time = 10"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# the recurrent cell with trainable parameters\n",
     "cell1 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((28, 28),\n",
@@ -71,13 +72,13 @@
     "                             in_channels=32,\n",
     "                             out_channels=64,\n",
     "                             kernel_size=(3, 3)))"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
+   "outputs": [],
    "source": [
     "class CNN(nn.Module):\n",
     "  \"\"\"A simple CNN model.\"\"\"\n",
@@ -93,13 +94,13 @@
     "    x = nn.relu(x)\n",
     "    x = nn.Dense(features=10)(x)\n",
     "    return x"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
+   "outputs": [],
    "source": [
     "@jax.jit\n",
     "def apply_model(state, images, labels):\n",
@@ -118,24 +119,24 @@
     "  (loss, logits), grads = grad_fn(state.params)\n",
     "  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n",
     "  return grads, loss, accuracy"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
+   "outputs": [],
    "source": [
     "@jax.jit\n",
     "def update_model(state, grads):\n",
     "  return state.apply_gradients(grads=grads)"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 8,
    "metadata": {},
+   "outputs": [],
    "source": [
     "def train_epoch(state, train_ds, batch_size, rng):\n",
     "  \"\"\"Train for a single epoch.\"\"\"\n",
@@ -159,13 +160,13 @@
     "  train_loss = np.mean(epoch_loss)\n",
     "  train_accuracy = np.mean(epoch_accuracy)\n",
     "  return state, train_loss, train_accuracy"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 9,
    "metadata": {},
+   "outputs": [],
    "source": [
     "def get_datasets():\n",
     "  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n",
@@ -176,27 +177,27 @@
     "  train_ds['image'] = jnp.asarray(train_ds['image']) / 255.\n",
     "  test_ds['image'] = jnp.asarray(test_ds['image']) / 255.\n",
     "  return train_ds, test_ds"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 10,
    "metadata": {},
+   "outputs": [],
    "source": [
     "def create_train_state(rng, config):\n",
     "  \"\"\"Creates initial `TrainState`.\"\"\"\n",
     "  cnn = CNN()\n",
     "  params = cnn.init(rng, jnp.ones([1, num_time, 28, 28, 1]))['params']\n",
     "  tx = optax.sgd(config.learning_rate, config.momentum)\n",
     "  return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 11,
    "metadata": {},
+   "outputs": [],
    "source": [
     "def train_and_evaluate(config: ml_collections.ConfigDict,\n",
     "                       workdir: str) -> train_state.TrainState:\n",
@@ -246,13 +247,13 @@
     "\n",
     "  summary_writer.flush()\n",
     "  return state"
-   ],
-   "outputs": []
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": 12,
    "metadata": {},
+   "outputs": [],
    "source": [
     "config = ml_collections.ConfigDict()\n",
     "\n",
@@ -262,8 +263,7 @@
     "config.num_epochs = 10\n",
     "\n",
     "train_and_evaluate(config, './ckpt')"
-   ],
-   "outputs": []
+   ]
   }
  ],
  "metadata": {
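Every notebook hunk above makes the same mechanical edit: each code cell's empty "outputs" list moves from after "source" to directly after "metadata". JSON key order carries no meaning for the notebook format, but keeping it stable reduces diff noise. A sketch of a script that applies the same reordering (path assumed):

    import json

    path = 'docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb'
    with open(path) as f:
        nb = json.load(f)

    for cell in nb['cells']:
        if cell['cell_type'] == 'code':
            # Re-inserting the keys puts "outputs" before "source"
            # (dicts preserve insertion order in Python 3.7+).
            outputs = cell.pop('outputs', [])
            source = cell.pop('source')
            cell['outputs'] = outputs
            cell['source'] = source

    with open(path, 'w') as f:
        json.dump(nb, f, indent=1)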
