Multi-GPU simulation and training of one dynamical system in brainpy #641
Comments
I think I might have found a way to shard a bm.array, based on JAX's tutorial https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html :
Maybe we could just shard the input/output bm.array tensors along the batch axis, and then let JAX automatically compute on multiple GPUs?
The printout of the array before and after sharding is something like the device-layout visualization shown in that tutorial.
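For instance, a minimal JAX-only sketch of this batch-axis idea might look like the following (the mesh axis name 'batch' and the array shape are illustrative assumptions, not brainpy API):

```python
# Minimal JAX-only sketch: shard an array along its batch axis across all GPUs,
# then let jit-compiled computation run on the sharded pieces automatically.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ('batch',))   # 1-D device mesh over all visible GPUs
batch_sharding = NamedSharding(mesh, P('batch'))     # partition axis 0 along the 'batch' mesh axis

x = jnp.zeros((8, 512))                              # (batch, features); batch must be divisible by the device count
jax.debug.visualize_array_sharding(x)                # before: the array lives on a single device
x = jax.device_put(x, batch_sharding)                # after: the batch axis is split across the GPUs
jax.debug.visualize_array_sharding(x)

y = jax.jit(jnp.tanh)(x)                             # jit-compiled ops follow the sharding of their inputs
```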
Thanks for the question. Sorry for the slow response. I will check it later.
Hi, Chaoming @chaoming0625. Maybe this issue is a bit hard to achieve, with too much engineering work involved. 🫡 I just have an idea for a quick and cheap solution to this issue. As in #663, if any built-in or customized brainpy dynamical system class could be automatically transformed into a Flax RNN cell using bp.dnn.ToFlaxRNNCell(), then we could simply do multi-GPU parallel training with Flax (https://flax.readthedocs.io/en/latest/guides/parallel_training/index.html). 🤖 best,
Yes, the idea is simple. I will give you the solution soon.
Here is my example of using multiple GPUs. I marked the key code with the [KEY] comments.

```python
import os
import jax
import brainpy as bp
import brainpy.math as bm

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # specify which GPU(s) to be used

bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print(jax.local_devices())
# bp version: 2.4.6.post5
# [cuda(id=0), cuda(id=1)]


# %%
class RNN(bp.DynamicalSystemNS):
    def __init__(self, num_in, num_hid, num_out, batch_size=1):
        super(RNN, self).__init__()
        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in = num_in
        self.num_hid = num_hid
        self.num_out = num_out

        # define variables [KEY]
        self.state = bp.init.variable(bm.zeros, num_hid, batch_size, axis_names=['hidden'])
        # self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

        # define weights [KEY]
        self.win = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_in, num_hid), axis_names=[None, 'hidden']))
        self.wrec = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_hid), axis_names=[None, 'hidden']))
        self.wout = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_out), axis_names=['hidden', None]))
        # self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))
        # self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
        # self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

    def reset_state(self, batch_size):  # this function defines how to reset the model states
        self.state.value = bp.init.variable_(bm.zeros, (self.num_hid,), batch_size, axis_names=['hidden'])

    def update(self, x):  # this function defines how the model updates its state and produces its output
        self.state.value = bm.tanh(bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec))
        return bm.matmul(self.state, self.wout)


with bm.sharding.device_mesh(jax.devices(), ['hidden']):  # [KEY]
    # initialize model
    bm.random.seed(123)
    dim_in = 1
    dim_hid = 10
    dim_out = 1
    batch_size = 1
    model = RNN(dim_in, dim_hid, dim_out, batch_size)

    # %%
    # generate some data
    Nsample = 500
    X_train = bm.random.normal(0., 1., size=(batch_size, Nsample, dim_in))  # (Batch, Time, dim)
    Y_train = bm.random.normal(10., 1., size=(batch_size, Nsample, dim_out))

    # training
    def loss_fun(inputs, targets):
        runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
        predicts = runner.predict(inputs)
        loss = bp.losses.mean_squared_error(predicts, targets)
        return loss

    grad_fun = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), return_value=True)
    opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

    @bm.jit
    def train(xs, ys):
        grads, loss = grad_fun(xs, ys)
        opt.update(grads)
        return loss

    losses = []
    for _ in range(1000):
        losses.append(train(X_train, Y_train))
```
The concept is very simple. `with bm.sharding.device_mesh(devices, ['hidden']): ...` means that any array axis declared with the axis name 'hidden' will be partitioned across the given devices. Note that the device mesh can also be multi-dimensional, for example `with bm.sharding.device_mesh(np.asarray(jax.devices()).reshape(2, 2), ['input', 'hidden']): ...`, which partitions the 'input' and 'hidden' axes over a 2 x 2 device mesh.
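For clarity, here is a self-contained sketch of the two-dimensional mesh variant, assuming four visible devices and reusing the variable_/axis_names pattern from the example above (the sizes are illustrative):

```python
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

num_in, num_hid = 4, 8  # illustrative sizes

# Arrange 4 devices as a 2 x 2 mesh whose axes are named 'input' and 'hidden'.
devices = np.asarray(jax.devices()).reshape(2, 2)
with bm.sharding.device_mesh(devices, ['input', 'hidden']):
    # The input weight is split along both dimensions: rows across the 'input'
    # mesh axis and columns across the 'hidden' mesh axis.
    win = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_in, num_hid),
                                        axis_names=['input', 'hidden']))
```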
Please tell me whether the above code works. Please also see the TPU multi-device partitioning example of the COBA-HH network model.
By the way, I apologize for the very late response!
Thanks for the feedback!
One more question about the details: it seems that you partition the model (the hidden states of this RNN) across the two GPUs. Why not partition along the batch axis? That seems more natural for users.
This is a good idea. However, if the batch size is what hinders training the model on one GPU, we can simply decrease the batch size rather than partition it across multiple devices. A more difficult situation is when the model itself is too big to fit on one device. For such cases, we can partition the model across multiple devices, for example when simulating a very large-scale SNN model (where there is usually no batch dimension).
Partitioning the hidden states and their interaction matrix is a simple model-parallelization method.
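To make the idea concrete, here is a small JAX-only sketch of what sharding the hidden dimension of the recurrent matrix means for the state-update matmul (this is not brainpy's internal implementation, just the underlying mechanism; the sizes are illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh whose single axis is named 'hidden'.
mesh = Mesh(np.asarray(jax.devices()), ('hidden',))

num_hid = 4096  # illustrative; must be divisible by the number of devices
# Shard the recurrent matrix along its output ('hidden') columns; replicate the state.
w_rec = jax.device_put(jnp.ones((num_hid, num_hid)), NamedSharding(mesh, P(None, 'hidden')))
state = jax.device_put(jnp.ones((1, num_hid)), NamedSharding(mesh, P()))

@jax.jit
def step(state, w_rec):
    # XLA partitions this matmul across the mesh: each device multiplies the
    # (replicated) state by its own column slice of w_rec, producing its slice
    # of the new hidden state.
    return jnp.tanh(state @ w_rec)

new_state = step(state, w_rec)
jax.debug.visualize_array_sharding(new_state)  # typically shows the result sharded along 'hidden'
```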
Okay, I see. By the way, I found that in the model-definition code, changing only one line, the model state variable, is enough for parallelization. There is no need to change the weight TrainVars:

```python
# define variables
self.state = bp.init.variable(bm.zeros, batch_size, num_hid,
                              axis_names=['hidden'], batch_axis_name=['batch'])  # <<< key point
# define weights
self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))  # no need to change
self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))
```

Thanks again for the help 👍👍👍
Hi, Chaoming:
I am trying to do simulation and training of a dynamical system (a self-customized RNN based on brainpy, https://github.com/Dr-Chen-Xiaoyu/DecoModel) with a very large dimension and many time steps. The memory usage exceeds a single GPU device.
I believe this could be solved by running brainpy on multiple GPU devices with its own sharding method, just like JAX's sharding or PyTorch's torch.nn.DataParallel. A simplified case of RNN training is provided below; change the dimension of the RNN to something very large (maybe >1000), as well as the input/output tensor (maybe >1000^3). Maybe you could modify this code with brainpy's sharding and make it part of brainpy's tutorial, if this is a general demand from users.

best,
Xiaoyu
The example code: