Skip to content

Commit

Permalink
log acc and f1 to wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Jun 16, 2024
1 parent c36716e commit 9abccee
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
base: 16
n: 1024 # 16384
n: 2048 # 16384
emb: 128
lr: 0.001
depth: 2
heads: 4
epochs: 1000
epochs: 2000
block: vaswani
dropout: 0.5
l2: 0.3 # weight decay
l2: 1.0 # weight decay
gamma: 2
18 changes: 9 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# functions
def log_run(conf, metrics):
def log_run(cfg, metrics):
# long loss and epoch
log_fn = lambda x: {
"train_loss": x[0],
Expand All @@ -21,26 +21,26 @@ def log_run(conf, metrics):
"train_f1": x[4],
"valid_f1": x[5],
}
with wandb.init(project="miiii", config=conf):
with wandb.init(project="miiii", config=cfg):
for epoch, metric in enumerate(metrics[:-1]):
wandb.log(log_fn(metric), step=epoch, commit=False)
wandb.log(log_fn(metrics[-1]), step=conf.epochs)
wandb.log(log_fn(metrics[-1]), step=cfg.epochs)
# TODO: log images and model, and maybe more


def main():
# config and init
conf, (rng, key) = miiii.get_conf(), random.split(random.PRNGKey(0))
data = miiii.prime_fn(conf.n, partial(miiii.base_n, conf.base))
params = miiii.init_fn(key, conf)
cfg, (rng, key) = miiii.get_conf(), random.split(random.PRNGKey(0))
data = miiii.prime_fn(cfg.n, partial(miiii.base_n, cfg.base))
params = miiii.init_fn(key, cfg)

# train
apply_fn = miiii.make_apply_fn(miiii.vaswani_fn)
train_fn, opt_state = miiii.init_train(apply_fn, params, conf, *data)
(params, opt_state), metrics = train_fn(conf.epochs, rng, (params, opt_state))
train_fn, opt_state = miiii.init_train(apply_fn, params, cfg, *data)
(params, opt_state), metrics = train_fn(cfg.epochs, rng, (params, opt_state))

# evaluate
log_run(conf, metrics) # log run
log_run(cfg, metrics) # log run


if __name__ == "__main__":
Expand Down

0 comments on commit 9abccee

Please sign in to comment.