# config.py
from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference, placeholder
from octo.data.utils.text_processing import MuseEmbedding
from octo.model.components.action_heads import MSEActionHead
from octo.model.components.tokenizers import ImageTokenizer
from octo.model.components.transformer import common_transformer_sizes
from octo.model.components.vit_encoders import SmallStem16
from octo.utils.spec import ModuleSpec


def get_config(
    transformer_size="vit_s",
):
    print("Creating config with: ", locals())
    window_size = FieldReference(default=1)
    return ConfigDict(
        dict(
            seed=42,
            num_steps=int(2e6),
            save_dir=placeholder(str),
            model=get_model_config(transformer_size),
            window_size=window_size,
            dataset_kwargs=get_dataset_config(window_size),
            optimizer=dict(
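                # "rsqrt": linear warmup to peak_value, then (presumably, going by
                # the name) reciprocal-square-root decay on the given timescale.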
                learning_rate=dict(
                    name="rsqrt",
                    init_value=0.0,
                    peak_value=3e-4,
                    warmup_steps=2000,
                    timescale=10000,
                ),
                weight_decay=0.1,
                clip_gradient=1.0,
                frozen_keys=tuple(),
            ),
            prefetch_num_batches=0,
            start_step=placeholder(int),
            log_interval=100,
            eval_interval=5000,
            viz_interval=20000,
            save_interval=10000,
            val_kwargs=dict(
                val_shuffle_buffer_size=1000,
                num_val_batches=16,
            ),
            viz_kwargs=dict(
                eval_batch_size=128,
                trajs_for_metrics=100,
                trajs_for_viz=8,
                samples_per_state=8,
            ),
            resume_path=placeholder(str),
            text_processor=ModuleSpec.create(MuseEmbedding),
            pretrained_loaders=tuple(),
            wandb=dict(
                project="octo",
                group=placeholder(str),
                entity=placeholder(str),
            ),
            wandb_resume_id=placeholder(str),
            eval_datasets=(),
        )
    )


def get_model_config(transformer_size):
    """
    transformer_size is one of ["dummy", "vanilla", "vit_t", "vit_s", "vit_b", "vit_l", "vit_h"].
    This model stacks the images from all cameras together and passes them through
    a small convolutional stem before entering the transformer.
    The action head pools all the observation token embeddings and passes them
    through a small MLP before predicting the action with an MSE loss.
    """
    token_embedding_size, transformer_kwargs = common_transformer_sizes(
        transformer_size
    )
    return dict(
        observation_tokenizers=dict(
            image=ModuleSpec.create(
                ImageTokenizer,
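                # Regex keys: every observation/goal image matching "image_.*"
                # (i.e. all camera views) is stacked before the shared conv stem.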
                obs_stack_keys=["image_.*"],
                task_stack_keys=["image_.*"],
                task_film_keys=["language_instruction"],
                encoder=ModuleSpec.create(SmallStem16, use_film=True),
            ),
        ),
        task_tokenizers=dict(),
        heads=dict(
            action=ModuleSpec.create(
                MSEActionHead,
                action_horizon=1,
                action_dim=7,
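                # The head pools the observation token embeddings ("obs") rather
                # than a dedicated readout token (readouts is empty below).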
readout_key="obs",
),
),
readouts=dict(),
token_embedding_size=token_embedding_size,
transformer_kwargs=transformer_kwargs,
max_horizon=10,
use_correct_attention=True,
)


def get_dataset_config(window_size=1):
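    # "delete_task_conditioning": per trajectory, keep the goal image as the task
    # with probability keep_image_prob, otherwise keep only the language
    # instruction, so the policy learns both conditioning modes.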
    task_augmentation = dict(
        task_augment_strategy="delete_task_conditioning",
        task_augment_kwargs=dict(
            keep_image_prob=0.5,
        ),
    )
    return dict(
        # oxe_kwargs will generate dataset_kwargs_list and sampling weights
        oxe_kwargs=dict(
            data_mix=placeholder(str),
            data_dir=placeholder(str),
            load_camera_views=("primary", "wrist"),
            load_depth=False,
        ),
        traj_transform_kwargs=dict(
            window_size=window_size,
            action_horizon=1,
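            # "uniform": relabel the goal as a uniformly sampled future frame
            # from the same trajectory (hindsight goal relabeling).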
            goal_relabeling_strategy="uniform",
            subsample_length=100,
            **task_augmentation,
        ),
        frame_transform_kwargs=dict(
            resize_size=dict(primary=(256, 256)),
            image_dropout_prob=0.0,
            image_augment_kwargs=dict(
                primary=dict(
                    random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
                    random_brightness=[0.2],
                    random_contrast=[0.8, 1.2],
                    random_saturation=[0.8, 1.2],
                    random_hue=[0.1],
                    augment_order=[
                        "random_resized_crop",
                        "random_brightness",
                        "random_contrast",
                        "random_saturation",
                        "random_hue",
                    ],
                )
            ),
            num_parallel_calls=200,
        ),
        traj_transform_threads=48,  # shared between all datasets
        traj_read_threads=48,  # shared between all datasets
        shuffle_buffer_size=100000,  # shared between all datasets
        batch_size=512,
        balance_weights=True,
    )
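

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming this file is run directly rather than
    # through a training script's ml_collections config_flags entry point. It
    # builds the default config and checks that overriding the top-level
    # window_size propagates into dataset_kwargs through the FieldReference
    # created in get_config.
    config = get_config("vit_s")
    config.window_size = 2
    assert config.dataset_kwargs.traj_transform_kwargs.window_size == 2
    print("token_embedding_size:", config.model.token_embedding_size)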