-
Notifications
You must be signed in to change notification settings - Fork 2
/
propose.py
50 lines (46 loc) · 1.34 KB
/
propose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import hydra
import logging
import os
import openai
from zero_hero.core import EnvNode, TaskDatabase, CenteralizedTask
@hydra.main(config_path="cfg", config_name="config", version_base="1.1")
def main(cfg):
    """Propose tasks for one environment until the task database says to stop.

    Builds three objects from the Hydra config:
      * a ``TaskDatabase`` keyed by the lowercased environment name and a
        seed-derived environment index (e.g. seed 3 -> ``"E03"``);
      * a ``CenteralizedTask`` configured from ``cfg.design``;
      * an ``EnvNode`` that proposes tasks with ``cfg.proposal.model``.

    It then runs proposal rounds while neither ``tdb.met_target()`` nor
    ``tdb.should_wait()`` is true. Each round adds the proposed tasks to the
    database, then renders and saves it.

    Args:
        cfg: Config composed by Hydra from the ``config`` config in the
            ``cfg`` directory. Fields read here: ``env.env_name``, ``seed``,
            ``resume``, ``temperature``,
            ``proposal.{model,target_num_skills,failed_tolerance,proposal_batch}``
            and ``design.{model,temperature}``.
    """
    # Pass the key to the openai module explicitly; this is None when the
    # environment variable is unset.
    openai.api_key = os.getenv("OPENAI_API_KEY")

    # Read the proposal model up front so a missing config key fails before
    # any database object is constructed.
    model = cfg.proposal.model
    env_name = cfg.env.env_name.lower()
    # Zero-padded, seed-derived environment id, e.g. seed 3 -> "E03".
    env_idx = f"E{cfg.seed:02d}"

    tdb = TaskDatabase(
        env_name=env_name,
        env_idx=env_idx,
        target_num_skills=cfg.proposal.target_num_skills,
        failed_tolerance=cfg.proposal.failed_tolerance,
        proposal_batch=cfg.proposal.proposal_batch,
    )
    ct = CenteralizedTask(
        env_name=env_name,
        model=cfg.design.model,
        temperature=cfg.design.temperature,
    )
    env_node = EnvNode(
        task_database=tdb,
        centralized_task=ct,
        idx=env_idx,
        env_name=env_name,
        resume=cfg.resume,
        model=model,
        n_samples=1,
        # NOTE(review): this reads the top-level cfg.temperature, whereas the
        # other settings come from cfg.proposal.* / cfg.design.*; confirm that
        # a cfg.proposal.temperature was not intended.
        temperature=cfg.temperature,
        skills=[],
        impossibles=[],
    )

    # Termination depends entirely on TaskDatabase.met_target()/should_wait();
    # presumably should_wait() trips after failed_tolerance unproductive
    # rounds -- TODO confirm.
    while not tdb.met_target() and not tdb.should_wait():
        tasks = env_node.propose()
        tdb.add_tasks(tasks)
        tdb.render()
        tdb.save()
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logging.info(
            "Updated task database %s with %d new tasks.",
            tdb.store_path,
            len(tasks),
        )
    logging.info("Finished!")
if __name__ == "__main__":
    # The @hydra.main decorator builds `cfg` from the config files and any
    # command-line overrides, so main() is called with no arguments.
    main()