-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
88 lines (68 loc) · 3.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# %% main.py
# parabellum main
# by: Noah Syrkis
# Imports ###################################################################
from typing import Tuple
import equinox as eqx
import jax.numpy as jnp
from chex import dataclass
from jax import lax, random
from jaxtyping import Array
import parabellum as pb
# %% Types ####################################################################
# Observation produced by obs_fn: for each agent, one feature row per selected
# nearest neighbour. With 2-D positions this is an array of shape
# (num_agents, cfg.knn, 5): [distance, health, type, dx, dy].
Obs = Array
# %% Dataclasses ################################################################
@dataclass
class State:
    """Per-agent simulation state; every field holds one entry per agent.

    Shapes below are as constructed by ``init_fn`` for N agents.
    """

    pos: Array  # (N, 2) positions; init_fn samples them uniformly in [0, cfg.size)
    types: Array  # (N,) unit-type index into the Env.type_* lookup arrays
    teams: Array  # (N,) 0 for allies (first Env.num_allies agents), 1 for rivals
    health: Array  # (N,) current health; starts at Env.type_health[types]
@dataclass
class Conf:
    """Run configuration shared by init_fn and obs_fn."""

    # Place name; presumably meant for the terrain/geography lookup — it is
    # not read by init_fn/obs_fn/step_fn.
    place: str = "Copenhagen, Denmark"
    knn: int = 10  # neighbours kept per agent in obs_fn (k for lax.approx_min_k)
    size: int = 100  # upper bound (exclusive) of the uniform position sampling in init_fn
@dataclass
class Env:
    """Simulation environment: configuration, terrain and per-unit-type stats.

    Only ``cfg`` and ``geo`` are dataclass fields (constructor arguments).
    The remaining attributes are deliberately un-annotated class attributes,
    shared by every instance. Each ``type_*`` array is indexed by unit type,
    and its length is the number of unit types.
    """

    cfg: Conf
    geo: pb.tps.Terrain

    # Class-level constants: left un-annotated so they do NOT become fields.
    num_allies = 10  # agents with index < num_allies are team 0
    num_rivals = 10  # the remaining agents are team 1
    type_health = jnp.array([100, 100, 100])  # starting health per unit type
    type_damage = jnp.array([10, 10, 10])  # health subtracted per attack in step_fn
    type_ranges = jnp.array([10, 10, 10])  # distance cutoff for obs_fn's visibility mask
    type_sights = jnp.array([10, 10, 10])  # not currently read by init_fn/obs_fn/step_fn
    type_speeds = jnp.array([1, 1, 1])  # scales movement direction in step_fn
    type_reload = jnp.array([10, 10, 10])  # not currently read by init_fn/obs_fn/step_fn

    def reset(self, rng: Array) -> Tuple[Obs, State]:
        """Return the initial ``(observation, state)`` for PRNG key ``rng``."""
        return init_fn(rng, self.cfg, self)

    def step(self, rng: Array, state: State, action) -> Tuple[Obs, State]:
        """Advance the simulation one tick.

        Returns:
            ``(observation, next_state)``, where the observation is computed
            from ``next_state`` so the two always describe the same moment.
        """
        next_state = step_fn(rng, self, state, action)
        return obs_fn(self.cfg, self, next_state), next_state
# %% Functions
@eqx.filter_jit
def init_fn(rng: Array, cfg: Conf, env: Env) -> Tuple[Obs, State]:
    """Sample a fresh random state and return it with its observation.

    Args:
        rng: JAX PRNG key; split into one subkey for unit types and one for
            positions.
        cfg: run configuration (``cfg.size`` bounds the spawn area).
        env: environment providing team sizes and per-type stats.

    Returns:
        ``(obs_fn(cfg, env, state), state)``.
    """
    keys = random.split(rng)  # two subkeys: [0] unit types, [1] positions
    num_agents = env.num_allies + env.num_rivals
    # jnp.arange needs a scalar count; passing the per-type array itself
    # raises. The number of unit types is the length of the type tables.
    num_types = env.type_health.shape[0]
    types = random.choice(keys[0], jnp.arange(num_types), (num_agents,))
    pos = random.uniform(keys[1], (num_agents, 2), minval=0, maxval=cfg.size)
    # The first num_allies agents are team 0, the rest team 1.
    teams = jnp.where(jnp.arange(num_agents) < env.num_allies, 0, 1)
    health = jnp.take(env.type_health, types)  # full starting health per type
    state = State(pos=pos, health=health, types=types, teams=teams)
    return obs_fn(cfg, env, state), state
@eqx.filter_jit
def obs_fn(cfg: Conf, env: Env, state: State) -> Obs:
    """Describe, for every agent, its ``cfg.knn`` nearest agents.

    For observer i and each selected neighbour j, the feature row is
    ``[distance_ij, health_j, type_j, dx, dy]``, where ``(dx, dy)`` is
    ``pos_j - pos_i``. Rows whose distance is not strictly below the
    observer's ``env.type_ranges`` value are zeroed out.

    Returns:
        Array of shape ``(num_agents, cfg.knn, 5)`` for 2-D positions.
    """
    # Pairwise Euclidean distances, shape (N, N); the diagonal is 0, so an
    # agent normally appears among its own nearest neighbours.
    gaps = jnp.linalg.norm(state.pos[:, None] - state.pos, axis=-1)
    # approx_min_k may be approximate on some backends (see JAX docs).
    near_dist, near_idx = lax.approx_min_k(gaps, cfg.knn)
    offsets = jnp.take(state.pos, near_idx, axis=0) - state.pos[:, None]
    neighbour_health = state.health[near_idx]
    neighbour_type = state.types[near_idx]
    features = jnp.stack([near_dist, neighbour_health, neighbour_type], axis=-1)
    # NOTE(review): visibility uses type_ranges although Env also defines
    # type_sights; confirm which is intended (the values are equal today).
    reach = env.type_ranges[state.types]
    visible = near_dist < reach[..., None]
    combined = jnp.concat([features, offsets], axis=-1)
    return combined * visible[..., None]
@eqx.filter_jit
def step_fn(rng: Array, env: Env, state: State, action) -> State:
    """Apply one tick of movement and damage and return the new state.

    Args:
        rng: PRNG key (currently unused; kept for interface stability).
        env: environment supplying per-type speed and damage tables.
        state: current agent state.
        action: object with ``direction`` (assumed broadcastable to
            ``state.pos``, e.g. (N, 2) — TODO confirm) and ``attack``
            (assumed (N,) flags/amounts — TODO confirm).

    Returns:
        New ``State``; ``types`` and ``teams`` are carried over unchanged.
    """
    # Per-agent speed is (N,); add a trailing axis so it scales each agent's
    # 2-D direction row instead of failing to broadcast against (N, 2).
    speed = env.type_speeds[state.types][:, None]
    pos = state.pos + action.direction * speed
    # NOTE(review): as written, an agent's own health is reduced by its own
    # type damage when action.attack is set. If attacks are meant to hit a
    # target agent, this needs target indices — confirm intended semantics.
    hp = state.health - action.attack * env.type_damage[state.types]
    return State(pos=pos, health=hp, types=state.types, teams=state.teams)
# %% Main #####################################################################
# Build one configuration and use it everywhere: the terrain lookup reads
# cfg.place and the environment receives the same cfg object, so changing
# the config in one place cannot silently diverge.
cfg = Conf()
geo = pb.geo.geography_fn(cfg.place)
env = Env(cfg=cfg, geo=geo)
rng = random.PRNGKey(0)
obs, state = env.reset(rng)