# train.yaml
# Training params
use_cuda: true
batch_size: 64
num_workers: 16
lr_init: 1.0e-3
lr_decay_rate: 0.1
# NOTE(review): the two step counts below use float notation (1.0e+6), so a
# YAML loader yields floats rather than ints. Presumably the consumer casts
# them where an integer is required -- confirm before changing the notation.
lr_decay_steps: 1.0e+6
training_steps: 1.0e+6
warmup_steps: 5

# Evaluation
loss_report_step: 2000
log_grad: true
log_grad_step: 10000
save_model_step: 2000
eval_step: 2000
rollout_steps: 50
run_validate: true
num_eval_rollout: 10
save_video: true

# Dataset
data_path: "datasets/rigidFall_epLen_100_inter_10_totalSteps_1M.npz"
test_data_path: "datasets/rigidFall_epLen_100_inter_10_totalSteps_100K.npz"
# The keys below are nested under data_config; without the indentation,
# data_config would parse as null and its settings would become top-level keys.
data_config:
  noise_std: 3.0e-5
  connectivity_radius: 0.01
  input_seq_length: 3

# Logging
logging_folder: "log"
log_level: "info"
# continue_log_from: "2024-08-29-17:05"

# Resume training
# Uncomment both entries together; they reference checkpoints from the same
# iteration (482000).
# model_file: "log/sim-pc/models/weights_itr_482000.ckpt"
# train_state_file: "log/sim-pc/models/train_state_itr_482000.ckpt"

# Simulator params
leave_out_mm: true
latent_dim: 128
message_passing_steps: 10
mlp_layers: 2