Skip to content

Commit

Permalink
Add the viz example
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Feb 26, 2024
1 parent fdaa461 commit f4519fa
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions examples/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import equinox as eqx
import jax
import pytest
from jax import numpy as jnp, random as jr
from tqdm import tqdm as tqdm

from cotix._colliders import RandomizedCollider
from cotix._constraint_solvers import SimpleConstraintSolver
from cotix._lunar_lander import LunarLander
from cotix._physics_solvers import ExplicitEulerPhysics
from cotix._robocup import RoboCupEnv
from cotix._viz import Painter


@pytest.mark.skip
def test_lunar_lander():
jax.config.update("jax_log_compiles", True)
env = LunarLander()

physics = ExplicitEulerPhysics()
collider = RandomizedCollider()
painter = Painter()

@eqx.filter_jit
def f(env, key):
new_bodies, aux = physics.step(env.bodies, dt=1e-2)
new_bodies = eqx.tree_at(
lambda x: x[0].velocity,
new_bodies,
new_bodies[0].velocity + jnp.array([0.0, -0.002]),
)

def draw_log(log):
pos = log.contact_point
painter.draw_circle(pos, 0.1, (200, 0, 0))

new_bodies = collider.resolve(new_bodies, key, draw_log)
# new_bodies = constraintSolver.solve(new_bodies, env.constraints)
key, next_key = jr.split(key)
env = eqx.tree_at(lambda x: x.bodies, env, new_bodies)
env = env.step()

env.draw(painter)
return env, key

key = jr.PRNGKey(0)
for i in range(10000):
env, key = f(env, key)


@pytest.mark.skip
def test_robocup_env():
jax.config.update("jax_log_compiles", True)
env = RoboCupEnv()

physics = ExplicitEulerPhysics()
collider = RandomizedCollider()
constraintSolver = SimpleConstraintSolver(loops=1)
painter = Painter()

@eqx.filter_jit
def f(env, key):
new_bodies, aux = physics.step(env.bodies, dt=1e-2)
new_bodies = collider.resolve(new_bodies, key)
new_bodies = constraintSolver.solve(new_bodies, env.constraints)
key, next_key = jr.split(key)
env = eqx.tree_at(lambda x: x.bodies, env, new_bodies)
painter.draw_env(env)
return env, key

key = jr.PRNGKey(0)
for i in tqdm(range(3000)):
env, key = f(env, key)


if __name__ == "__main__":
test_lunar_lander()

0 comments on commit f4519fa

Please sign in to comment.