Skip to content

Commit

Permalink
Merge branch 'master' into staging
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay authored Nov 21, 2024
2 parents 491505f + 3d27180 commit a66519e
Show file tree
Hide file tree
Showing 15 changed files with 14 additions and 177 deletions.
5 changes: 0 additions & 5 deletions benchmarks/diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ def models(accelerator, args: Arguments):
unet = UNet2DConditionModel.from_pretrained(
args.model, subfolder="unet", revision=args.revision, variant=args.variant
)

from benchmate.models import model_size
print(model_size(unet))
print(model_size(encoder))
print(model_size(vae))

vae.requires_grad_(False)
encoder.requires_grad_(False)
Expand Down
1 change: 1 addition & 0 deletions benchmarks/flops/benchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class FlopsBenchmarch(Package):

def build_run_plan(self) -> "Command":
import milabench.commands as cmd

main = self.dirs.code / self.main_script
pack = cmd.PackCommand(self, *self.argv, lazy=True)

Expand Down
56 changes: 0 additions & 56 deletions benchmarks/flops/dev.yaml

This file was deleted.

3 changes: 1 addition & 2 deletions benchmarks/flops/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,11 @@ def main():

log, monitor = setupvoir()

# FIXME
#with monitor:
f(args.number, args.repeat, args.m, args.n, TERA, dtypes[args.dtype], log)

monitor.stop()


if __name__ == "__main__":
main()
print("done")
5 changes: 0 additions & 5 deletions benchmarks/flops/requirements.cpu.txt

This file was deleted.

13 changes: 0 additions & 13 deletions benchmarks/flops/simple.sh

This file was deleted.

36 changes: 0 additions & 36 deletions benchmarks/geo_gnn/modelsize.py

This file was deleted.

5 changes: 0 additions & 5 deletions benchmarks/purejaxrl/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,6 @@ def train(rng):
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)


param_count = sum(x.size for x in jax.tree.leaves(network_params))
print("PARAM COUNT", param_count)


def linear_schedule(count):
frac = 1.0 - (count / config["NUM_UPDATES"])
return config["LR"] * frac
Expand Down
6 changes: 0 additions & 6 deletions benchmarks/purejaxrl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,9 @@ def train(rng):
network = ActorCritic(
env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
)


rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)

param_count = sum(x.size for x in jax.tree.leaves(network_params))
print("PARAM COUNT", param_count)

if config["ANNEAL_LR"]:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
Expand Down
4 changes: 1 addition & 3 deletions benchmarks/recursiongfn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,11 @@ def __init__(
self.num_cond_dim = self.temperature_conditional.encoding_size()

def _load_task_models(self):
xdg_cache = os.environ.get("XDG_CACHE_HOME")
xdg_cache = os.environ["XDG_CACHE_HOME"]
model = bengio2021flow.load_original_model(
cache=True,
location=Path(os.path.join(xdg_cache, "bengio2021flow_proxy.pkl.gz")),
)
from benchmate.models import model_size
print(model_size(model))
model.to(get_worker_device())
model = self._wrap_model(model)
return {"seh": model}
Expand Down
3 changes: 0 additions & 3 deletions benchmarks/torchatari/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,7 @@ def main():
envs = RecordEpisodeStatistics(envs)
assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"


from benchmate.models import model_size
agent = Agent(envs).to(device)
print(model_size(agent))
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# ALGO Logic: Storage setup
Expand Down
36 changes: 0 additions & 36 deletions benchmate/benchmate/models.py

This file was deleted.

15 changes: 11 additions & 4 deletions benchmate/benchmate/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from voir.instruments.io import io_monitor
from voir.instruments.network import network_monitor
from voir.instruments.monitor import monitor
from voir.helpers import current_overseer


from .metrics import sumggle_push, give_push, file_push

Expand Down Expand Up @@ -64,10 +64,17 @@ def monitor_node(ov, poll_interval=1, arch=None):


def _smuggle_monitor(poll_interval=10, worker_init=None, **monitors):
log = auto_push()

# USE auto push
data_file = SmuggleWriter(sys.stdout)
def mblog(data):
log(**data)
nonlocal data_file

if data_file is not None:
try:
print(json.dumps(data), file=data_file)
except ValueError:
pass
# print("Is bench ending?, ignoring ValueError")

def get():
t = time.time()
Expand Down
1 change: 0 additions & 1 deletion milabench/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
__tag__ = "v1.0.0_RC1-18-g784b38e"
__commit__ = "784b38e77b90116047e3de893c22c2f7d3225179"
__date__ = "2024-10-18 15:58:46 +0000"

2 changes: 0 additions & 2 deletions milabench/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,6 @@ def make_env(self):
f"MILABENCH_DIR_{name.upper()}": path
for name, path in self.config["dirs"].items()
}

env["MILABENCH_MANAGED"] = "1"

env["OMP_NUM_THREADS"] = resolve_placeholder(self, "{cpu_per_gpu}")

Expand Down

0 comments on commit a66519e

Please sign in to comment.