PPO returns nan with multiple GPU #332

Open
Daffan opened this issue Apr 10, 2023 · 5 comments

Daffan commented Apr 10, 2023

PPO training returns nan when using multiple GPUs. Forcing it to use one GPU works fine. I just ran exactly the same code as the training code in Brax Training. Can somebody help to try it? Thanks!
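
One way to force a single GPU for comparison (just a sketch, assuming a CUDA setup; the device index is arbitrary) is to hide the other devices from JAX before it is imported:

# Workaround sketch: restrict JAX to one GPU. This must run before JAX is imported.
# Assumption: CUDA backend; "0" is an arbitrary device index.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax
print(jax.devices())  # should now list a single GPU device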

btaba self-assigned this Apr 13, 2023

btaba (Collaborator) commented Apr 13, 2023

@Daffan Thanks for reporting! Which environment are you using and which backend?

Yunkai-Yu commented:

Hello, @btaba. I'm following up to inquire about any progress regarding the issue we discussed earlier. In my recent experiments, I've encountered an unexpected problem with PPO returning NaN values after several iterations on my GPUs.

I've included the code snippet below for reference. The code aims to drive a humanoid model to move gradually to different positions one by one. While I understand that optimization convergence might not be achieved immediately, encountering NaN values seems peculiar and warrants investigation.

#!/usr/bin/env python
import time
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/lib/cuda/'
from datetime import datetime
import functools
import jax
from jax import numpy as jp
import numpy as np
from brax import base
from brax import envs
from brax import actuator
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.io import mjcf
from etils import epath

import mujoco
from jax import config
config.update("jax_debug_nans", True)

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']=".95"
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

print("xla setted")
np.set_printoptions(precision=3, suppress=True, linewidth=100)


class Humanoid(PipelineEnv):

  # pyformat: enable
  def __init__(
      self,
      terminate_when_unhealthy=True,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      backend='generalized',
      **kwargs,
  ):
    path = epath.resource_path('brax') / 'envs/assets/humanoid.xml'
    sys = mjcf.load(path)
    n_frames = 5

    if backend == 'mjx':
      sys = sys.tree_replace({
          'opt.solver': mujoco.mjtSolver.mjSOL_NEWTON,
          'opt.disableflags': mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
          'opt.iterations': 1,
          'opt.ls_iterations': 4,
      })

    kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

    super().__init__(sys=sys, backend=backend, **kwargs)
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )
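    # ref_traj: target root x-positions to reach one after another;
    # max_err: squared-error threshold below which a target counts as reached.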
    self.ref_traj = jp.asarray([1.0,2.0,3.0,4.0])
    self.max_err = jp.array(0.5)

  def reset(self, rng: jax.Array) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.init_q + jax.random.uniform(
        rng1, (self.sys.q_size(),), minval=low, maxval=hi
    )
    
    qvel = jax.random.uniform(
        rng2, (self.sys.qd_size(),), minval=low, maxval=hi
    )

    pipeline_state = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'mse_loss': zero,
        'idx': 0,
        'is_converge': zero,
        'done': zero
    }
    return State(pipeline_state, obs, reward, done, metrics)

  def step(self, state: State, action: jax.Array) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    idx = state.metrics["idx"].astype(int)
    target = self.ref_traj[idx].copy()
    
    data = self.pipeline_step(data0, action)
    mse_loss = jp.sum(jp.square(data.q[0]-target))
    
    is_converge = jp.where(mse_loss < self.max_err, 1.0, 0.0)
    idx_new = jp.where(is_converge==1.0, idx+1, idx)

    obs = self._get_obs(data,action)
    done = is_converge

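    # dense penalty on squared distance to the current target, plus a +100 bonus when converged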
    reward = -mse_loss + is_converge * 100
    state.metrics.update(
        mse_loss = mse_loss,
        idx = idx_new,
        is_converge = is_converge,
        done = done
    )
    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, pipeline_state: base.State, action: jax.Array
  ) -> jax.Array:
    """Observes humanoid body position, velocities, and angles."""
    position = pipeline_state.q
    velocity = pipeline_state.qd

    if self._exclude_current_positions_from_observation:
      position = position[2:]

    com, inertia, mass_sum, x_i = self._com(pipeline_state)
    cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
    com_inertia = jp.hstack(
        [cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
    )

    xd_i = (
        base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
        .vmap()
        .do(pipeline_state.xd)
    )
    com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
    com_ang = xd_i.ang
    com_velocity = jp.hstack([com_vel, com_ang])

    qfrc_actuator = actuator.to_tau(
        self.sys, action, pipeline_state.q, pipeline_state.qd)

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        velocity,
        com_inertia.ravel(),
        com_velocity.ravel(),
        qfrc_actuator,
    ])

  def _com(self, pipeline_state: base.State) -> jax.Array:
    inertia = self.sys.link.inertia

    mass_sum = jp.sum(inertia.mass)
    x_i = pipeline_state.x.vmap().do(inertia.transform)
    com = (
        jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum
    )
    return com, inertia, mass_sum, x_i  # pytype: disable=bad-return-type  # jax-ndarray

envs.register_environment('humanoid', Humanoid)
print("env registered")
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

max_y, min_y = 13000, 0
def progress(num_steps, metrics):
  times.append(datetime.now())
  print(times[-2:],"len: ",len(times))
  print(num_steps)
  print(metrics['eval/episode_reward'])

train_fn = functools.partial(
    ppo.train, num_timesteps=10_00, num_evals=1200, reward_scaling=0.1,
    episode_length=10, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-6, entropy_cost=1e-3, num_envs=512,
    batch_size=256, seed=0)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# train the policy
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

nan
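
As a sanity check (just a sketch, using the jit_reset/jit_step defined above and random actions), stepping the environment outside of training can help rule out the environment itself as the source of the NaNs:

# Sanity-check sketch: step the env with random actions (no training involved)
# to see whether observations or rewards ever become NaN on their own.
rng = jax.random.PRNGKey(1)
check_state = jit_reset(rng)
for i in range(100):
    rng, act_rng = jax.random.split(rng)
    action = jax.random.uniform(act_rng, (env.action_size,), minval=-1.0, maxval=1.0)
    check_state = jit_step(check_state, action)
    assert not jp.isnan(check_state.obs).any(), f"NaN in obs at step {i}"
    assert not jp.isnan(check_state.reward), f"NaN in reward at step {i}"

If this loop stays clean, the NaN more likely originates in the PPO update rather than in the environment dynamics.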

Thank you for your time.

i1Cps commented Aug 2, 2024

@Yunkai-Yu Any progress made with this?

RuiningLi commented:

I encountered the same issue in a customized environment. Any progress on this?

jadkins99 commented:

I have also encountered this issue.
