jadechoghari HF Staff commited on
Commit
014f19a
·
verified ·
1 Parent(s): 3a00b61

add initial files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. README.md +63 -0
  3. config.json +122 -0
  4. config.yaml +235 -0
  5. model.safetensors +3 -0
  6. replay.mp4 +3 -0
  7. train_config.json +280 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: lerobot
3
+ tags:
4
+ - model_hub_mixin
5
+ - pytorch_model_hub_mixin
6
+ - robotics
7
+ - dot
8
+ license: apache-2.0
9
+ datasets:
10
+ - lerobot/pusht_keypoints
11
+ pipeline_tag: robotics
12
+ ---
13
+
14
+ # Model Card for "Decoder Only Transformer (DOT) Policy" for PushT keypoints dataset
15
+
16
+ Read more about the model and implementation details in the [DOT Policy repository](https://github.com/IliaLarchenko/dot_policy).
17
+
18
+ This model is trained using the [LeRobot library](https://huggingface.co/lerobot) and achieves state-of-the-art results on behavior cloning on the PushT keypoints dataset. It achieves 94% success rate (and 0.985 average max reward) vs. ~78% for the previous state-of-the-art model or 69% that I managed to reproduce using VQ-BET implementation in LeRobot.
19
+
20
+ This is the best checkpoint for the model. These results are achievable assuming we have reliable validation and can select the best checkpoint based on the validation results (not always the case in robotics). If you are interested in more stable and reproducible results achievable without checkpoint selection, please refer to https://huggingface.co/IliaLarchenko/dot_pusht_keypoints
21
+
22
+ You can use this model by installing LeRobot from [this branch](https://github.com/IliaLarchenko/lerobot/tree/dot_new_config)
23
+
24
+ To train the model:
25
+
26
+ ```bash
27
+ python lerobot/scripts/train.py \
28
+ --policy.type=dot \
29
+ --dataset.repo_id=lerobot/pusht_keypoints \
30
+ --env.type=pusht \
31
+ --env.task=PushT-v0 \
32
+ --output_dir=outputs/train/pusht_keyponts \
33
+ --batch_size=24 \
34
+ --log_freq=1000 \
35
+ --eval_freq=10000 \
36
+ --save_freq=50000 \
37
+ --offline.steps=1000000 \
38
+ --seed=100000 \
39
+ --wandb.enable=true \
40
+ --num_workers=24 \
41
+ --use_amp=true \
42
+ --device=cuda \
43
+ --policy.return_every_n=2 \
44
+ --policy.train_horizon=30 \
45
+ --policy.inference_horizon=30
46
+ ```
47
+
48
+ To evaluate the model:
49
+
50
+ ```bash
51
+ python lerobot/scripts/eval.py \
52
+ --policy.path=IliaLarchenko/dot_pusht_keypoints_best \
53
+ --env.type=pusht \
54
+ --env.task=PushT-v0 \
55
+ --eval.n_episodes=1000 \
56
+ --eval.batch_size=100 \
57
+ --env.obs_type=environment_state_agent_pos \
58
+ --seed=1000000
59
+ ```
60
+
61
+ Model size:
62
+ - Total parameters: 2.1m
63
+ - Trainable parameters: 2.1m
config.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "dot",
3
+ "n_obs_steps": 3,
4
+ "normalization_mapping": {
5
+ "VISUAL": "MEAN_STD",
6
+ "STATE": "MIN_MAX",
7
+ "ENV": "MIN_MAX",
8
+ "ACTION": "MIN_MAX"
9
+ },
10
+ "input_features": {
11
+ "observation.state": {
12
+ "type": "STATE",
13
+ "shape": [
14
+ 2
15
+ ]
16
+ },
17
+ "observation.environment_state": {
18
+ "type": "ENV",
19
+ "shape": [
20
+ 16
21
+ ]
22
+ }
23
+ },
24
+ "output_features": {
25
+ "action": {
26
+ "type": "ACTION",
27
+ "shape": [
28
+ 2
29
+ ]
30
+ }
31
+ },
32
+ "train_horizon": 30,
33
+ "inference_horizon": 30,
34
+ "lookback_obs_steps": 10,
35
+ "lookback_aug": 5,
36
+ "override_dataset_stats": false,
37
+ "new_dataset_stats": {
38
+ "action": {
39
+ "max": [
40
+ 512.0,
41
+ 512.0
42
+ ],
43
+ "min": [
44
+ 0.0,
45
+ 0.0
46
+ ]
47
+ },
48
+ "observation.environment_state": {
49
+ "max": [
50
+ 512.0,
51
+ 512.0,
52
+ 512.0,
53
+ 512.0,
54
+ 512.0,
55
+ 512.0,
56
+ 512.0,
57
+ 512.0,
58
+ 512.0,
59
+ 512.0,
60
+ 512.0,
61
+ 512.0,
62
+ 512.0,
63
+ 512.0,
64
+ 512.0,
65
+ 512.0
66
+ ],
67
+ "min": [
68
+ 0.0,
69
+ 0.0,
70
+ 0.0,
71
+ 0.0,
72
+ 0.0,
73
+ 0.0,
74
+ 0.0,
75
+ 0.0,
76
+ 0.0,
77
+ 0.0,
78
+ 0.0,
79
+ 0.0,
80
+ 0.0,
81
+ 0.0,
82
+ 0.0,
83
+ 0.0
84
+ ]
85
+ },
86
+ "observation.state": {
87
+ "max": [
88
+ 512.0,
89
+ 512.0
90
+ ],
91
+ "min": [
92
+ 0.0,
93
+ 0.0
94
+ ]
95
+ }
96
+ },
97
+ "vision_backbone": "resnet18",
98
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
99
+ "pre_norm": true,
100
+ "lora_rank": 20,
101
+ "merge_lora": false,
102
+ "dim_model": 128,
103
+ "n_heads": 8,
104
+ "dim_feedforward": 512,
105
+ "n_decoder_layers": 8,
106
+ "rescale_shape": [
107
+ 96,
108
+ 96
109
+ ],
110
+ "crop_scale": 1.0,
111
+ "state_noise": 0.01,
112
+ "noise_decay": 0.999995,
113
+ "dropout": 0.1,
114
+ "alpha": 0.75,
115
+ "train_alpha": 0.9,
116
+ "predict_every_n": 1,
117
+ "return_every_n": 2,
118
+ "optimizer_lr": 0.0001,
119
+ "optimizer_min_lr": 0.0001,
120
+ "optimizer_lr_cycle_steps": 300000,
121
+ "optimizer_weight_decay": 1e-05
122
+ }
config.yaml ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ resume: false
2
+ device: cuda
3
+ use_amp: true
4
+ seed: 100000
5
+ dataset_repo_id: lerobot/pusht_keypoints
6
+ video_backend: pyav
7
+ training:
8
+ offline_steps: 1000000
9
+ num_workers: 24
10
+ batch_size: 24
11
+ eval_freq: 10000
12
+ log_freq: 1000
13
+ save_checkpoint: true
14
+ save_freq: 50000
15
+ online_steps: 0
16
+ online_rollout_n_episodes: 1
17
+ online_rollout_batch_size: 1
18
+ online_steps_between_rollouts: 1
19
+ online_sampling_ratio: 0.5
20
+ online_env_seed: null
21
+ online_buffer_capacity: null
22
+ online_buffer_seed_size: 0
23
+ do_online_rollout_async: false
24
+ image_transforms:
25
+ enable: false
26
+ max_num_transforms: 3
27
+ random_order: false
28
+ brightness:
29
+ weight: 1
30
+ min_max:
31
+ - 0.8
32
+ - 1.2
33
+ contrast:
34
+ weight: 1
35
+ min_max:
36
+ - 0.8
37
+ - 1.2
38
+ saturation:
39
+ weight: 1
40
+ min_max:
41
+ - 0.5
42
+ - 1.5
43
+ hue:
44
+ weight: 1
45
+ min_max:
46
+ - -0.05
47
+ - 0.05
48
+ sharpness:
49
+ weight: 1
50
+ min_max:
51
+ - 0.8
52
+ - 1.2
53
+ save_model: true
54
+ grad_clip_norm: 50
55
+ lr: 0.0001
56
+ min_lr: 0.0001
57
+ lr_cycle_steps: 300000
58
+ weight_decay: 1.0e-05
59
+ delta_timestamps:
60
+ observation.environment_state:
61
+ - -1.5
62
+ - -1.4
63
+ - -1.3
64
+ - -1.2
65
+ - -1.1
66
+ - -1.0
67
+ - -0.9
68
+ - -0.8
69
+ - -0.7
70
+ - -0.6
71
+ - -0.5
72
+ - -0.1
73
+ - 0.0
74
+ observation.state:
75
+ - -1.5
76
+ - -1.4
77
+ - -1.3
78
+ - -1.2
79
+ - -1.1
80
+ - -1.0
81
+ - -0.9
82
+ - -0.8
83
+ - -0.7
84
+ - -0.6
85
+ - -0.5
86
+ - -0.1
87
+ - 0.0
88
+ action:
89
+ - -1.5
90
+ - -1.4
91
+ - -1.3
92
+ - -1.2
93
+ - -1.1
94
+ - -1.0
95
+ - -0.9
96
+ - -0.8
97
+ - -0.7
98
+ - -0.6
99
+ - -0.5
100
+ - -0.1
101
+ - 0.0
102
+ - 0.1
103
+ - 0.2
104
+ - 0.3
105
+ - 0.4
106
+ - 0.5
107
+ - 0.6
108
+ - 0.7
109
+ - 0.8
110
+ - 0.9
111
+ - 1.0
112
+ - 1.1
113
+ - 1.2
114
+ - 1.3
115
+ - 1.4
116
+ - 1.5
117
+ - 1.6
118
+ - 1.7
119
+ - 1.8
120
+ - 1.9
121
+ - 2.0
122
+ - 2.1
123
+ - 2.2
124
+ - 2.3
125
+ - 2.4
126
+ - 2.5
127
+ - 2.6
128
+ - 2.7
129
+ - 2.8
130
+ - 2.9
131
+ eval:
132
+ n_episodes: 100
133
+ batch_size: 100
134
+ use_async_envs: false
135
+ wandb:
136
+ enable: true
137
+ disable_artifact: false
138
+ project: pusht
139
+ notes: ''
140
+ fps: 10
141
+ env:
142
+ name: pusht
143
+ task: PushT-v0
144
+ image_size: 96
145
+ state_dim: 2
146
+ action_dim: 2
147
+ fps: ${fps}
148
+ episode_length: 300
149
+ gym:
150
+ obs_type: environment_state_agent_pos
151
+ render_mode: rgb_array
152
+ visualization_width: 384
153
+ visualization_height: 384
154
+ override_dataset_stats:
155
+ observation.environment_state:
156
+ min:
157
+ - 0.0
158
+ - 0.0
159
+ - 0.0
160
+ - 0.0
161
+ - 0.0
162
+ - 0.0
163
+ - 0.0
164
+ - 0.0
165
+ - 0.0
166
+ - 0.0
167
+ - 0.0
168
+ - 0.0
169
+ - 0.0
170
+ - 0.0
171
+ - 0.0
172
+ - 0.0
173
+ max:
174
+ - 512.0
175
+ - 512.0
176
+ - 512.0
177
+ - 512.0
178
+ - 512.0
179
+ - 512.0
180
+ - 512.0
181
+ - 512.0
182
+ - 512.0
183
+ - 512.0
184
+ - 512.0
185
+ - 512.0
186
+ - 512.0
187
+ - 512.0
188
+ - 512.0
189
+ - 512.0
190
+ observation.state:
191
+ min:
192
+ - 0.0
193
+ - 0.0
194
+ max:
195
+ - 512.0
196
+ - 512.0
197
+ action:
198
+ min:
199
+ - 0.0
200
+ - 0.0
201
+ max:
202
+ - 512.0
203
+ - 512.0
204
+ policy:
205
+ name: dot
206
+ n_obs_steps: 3
207
+ train_horizon: 30
208
+ inference_horizon: 30
209
+ lookback_obs_steps: 10
210
+ lookback_aug: 5
211
+ input_shapes:
212
+ observation.environment_state:
213
+ - 16
214
+ observation.state:
215
+ - ${env.state_dim}
216
+ output_shapes:
217
+ action:
218
+ - ${env.action_dim}
219
+ input_normalization_modes:
220
+ observation.environment_state: min_max
221
+ observation.state: min_max
222
+ output_normalization_modes:
223
+ action: min_max
224
+ state_noise: 0.01
225
+ noise_decay: 0.999995
226
+ pre_norm: true
227
+ dim_model: 128
228
+ n_heads: 8
229
+ dim_feedforward: 512
230
+ n_decoder_layers: 8
231
+ dropout: 0.1
232
+ alpha: 0.75
233
+ train_alpha: 0.9
234
+ predict_every_n: 1
235
+ return_every_n: 2
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d5e02a6c29abeaf8b44b1c78dc953b7bb8ce8983ed9491e7fe19eb24cfd6c94
3
+ size 8534312
replay.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e32ed981aca2d1f85511b1f310421dc766c26aad6427ff45c32439e8975187e
3
+ size 135178
train_config.json ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": {
3
+ "repo_id": "lerobot/pusht_keypoints",
4
+ "episodes": null,
5
+ "image_transforms": {
6
+ "enable": false,
7
+ "max_num_transforms": 3,
8
+ "random_order": false,
9
+ "tfs": {
10
+ "brightness": {
11
+ "weight": 1.0,
12
+ "type": "ColorJitter",
13
+ "kwargs": {
14
+ "brightness": [
15
+ 0.8,
16
+ 1.2
17
+ ]
18
+ }
19
+ },
20
+ "contrast": {
21
+ "weight": 1.0,
22
+ "type": "ColorJitter",
23
+ "kwargs": {
24
+ "contrast": [
25
+ 0.8,
26
+ 1.2
27
+ ]
28
+ }
29
+ },
30
+ "saturation": {
31
+ "weight": 1.0,
32
+ "type": "ColorJitter",
33
+ "kwargs": {
34
+ "saturation": [
35
+ 0.5,
36
+ 1.5
37
+ ]
38
+ }
39
+ },
40
+ "hue": {
41
+ "weight": 1.0,
42
+ "type": "ColorJitter",
43
+ "kwargs": {
44
+ "hue": [
45
+ -0.05,
46
+ 0.05
47
+ ]
48
+ }
49
+ },
50
+ "sharpness": {
51
+ "weight": 1.0,
52
+ "type": "SharpnessJitter",
53
+ "kwargs": {
54
+ "sharpness": [
55
+ 0.5,
56
+ 1.5
57
+ ]
58
+ }
59
+ }
60
+ }
61
+ },
62
+ "local_files_only": false,
63
+ "use_imagenet_stats": true,
64
+ "video_backend": "pyav"
65
+ },
66
+ "env": {
67
+ "type": "pusht",
68
+ "task": "PushT-v0",
69
+ "fps": 10,
70
+ "features": {
71
+ "action": {
72
+ "type": "ACTION",
73
+ "shape": [
74
+ 2
75
+ ]
76
+ },
77
+ "agent_pos": {
78
+ "type": "STATE",
79
+ "shape": [
80
+ 2
81
+ ]
82
+ },
83
+ "environment_state": {
84
+ "type": "ENV",
85
+ "shape": [
86
+ 16
87
+ ]
88
+ }
89
+ },
90
+ "features_map": {
91
+ "action": "action",
92
+ "agent_pos": "observation.state",
93
+ "environment_state": "observation.environment_state",
94
+ "pixels": "observation.image"
95
+ },
96
+ "episode_length": 300,
97
+ "obs_type": "environment_state_agent_pos",
98
+ "render_mode": "rgb_array",
99
+ "visualization_width": 384,
100
+ "visualization_height": 384
101
+ },
102
+ "policy": {
103
+ "type": "dot",
104
+ "n_obs_steps": 3,
105
+ "normalization_mapping": {
106
+ "VISUAL": "MEAN_STD",
107
+ "STATE": "MIN_MAX",
108
+ "ENV": "MIN_MAX",
109
+ "ACTION": "MIN_MAX"
110
+ },
111
+ "input_features": {
112
+ "observation.state": {
113
+ "type": "STATE",
114
+ "shape": [
115
+ 2
116
+ ]
117
+ },
118
+ "observation.environment_state": {
119
+ "type": "ENV",
120
+ "shape": [
121
+ 16
122
+ ]
123
+ }
124
+ },
125
+ "output_features": {
126
+ "action": {
127
+ "type": "ACTION",
128
+ "shape": [
129
+ 2
130
+ ]
131
+ }
132
+ },
133
+ "train_horizon": 30,
134
+ "inference_horizon": 30,
135
+ "lookback_obs_steps": 10,
136
+ "lookback_aug": 5,
137
+ "override_dataset_stats": false,
138
+ "new_dataset_stats": {
139
+ "action": {
140
+ "max": [
141
+ 512.0,
142
+ 512.0
143
+ ],
144
+ "min": [
145
+ 0.0,
146
+ 0.0
147
+ ]
148
+ },
149
+ "observation.environment_state": {
150
+ "max": [
151
+ 512.0,
152
+ 512.0,
153
+ 512.0,
154
+ 512.0,
155
+ 512.0,
156
+ 512.0,
157
+ 512.0,
158
+ 512.0,
159
+ 512.0,
160
+ 512.0,
161
+ 512.0,
162
+ 512.0,
163
+ 512.0,
164
+ 512.0,
165
+ 512.0,
166
+ 512.0
167
+ ],
168
+ "min": [
169
+ 0.0,
170
+ 0.0,
171
+ 0.0,
172
+ 0.0,
173
+ 0.0,
174
+ 0.0,
175
+ 0.0,
176
+ 0.0,
177
+ 0.0,
178
+ 0.0,
179
+ 0.0,
180
+ 0.0,
181
+ 0.0,
182
+ 0.0,
183
+ 0.0,
184
+ 0.0
185
+ ]
186
+ },
187
+ "observation.state": {
188
+ "max": [
189
+ 512.0,
190
+ 512.0
191
+ ],
192
+ "min": [
193
+ 0.0,
194
+ 0.0
195
+ ]
196
+ }
197
+ },
198
+ "vision_backbone": "resnet18",
199
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
200
+ "pre_norm": true,
201
+ "lora_rank": 20,
202
+ "merge_lora": false,
203
+ "dim_model": 128,
204
+ "n_heads": 8,
205
+ "dim_feedforward": 512,
206
+ "n_decoder_layers": 8,
207
+ "rescale_shape": [
208
+ 96,
209
+ 96
210
+ ],
211
+ "crop_scale": 1.0,
212
+ "state_noise": 0.01,
213
+ "noise_decay": 0.999995,
214
+ "dropout": 0.1,
215
+ "alpha": 0.75,
216
+ "train_alpha": 0.9,
217
+ "predict_every_n": 1,
218
+ "return_every_n": 2,
219
+ "optimizer_lr": 0.0001,
220
+ "optimizer_min_lr": 0.0001,
221
+ "optimizer_lr_cycle_steps": 300000,
222
+ "optimizer_weight_decay": 1e-05
223
+ },
224
+ "output_dir": "outputs/train/pusht_keypoints",
225
+ "job_name": "pusht_dot",
226
+ "resume": false,
227
+ "device": "cuda",
228
+ "use_amp": true,
229
+ "seed": 100000,
230
+ "num_workers": 24,
231
+ "batch_size": 24,
232
+ "eval_freq": 10000,
233
+ "log_freq": 1000,
234
+ "save_checkpoint": true,
235
+ "save_freq": 50000,
236
+ "offline": {
237
+ "steps": 1000000
238
+ },
239
+ "online": {
240
+ "steps": 0,
241
+ "rollout_n_episodes": 1,
242
+ "rollout_batch_size": 1,
243
+ "steps_between_rollouts": null,
244
+ "sampling_ratio": 0.5,
245
+ "env_seed": null,
246
+ "buffer_capacity": null,
247
+ "buffer_seed_size": 0,
248
+ "do_rollout_async": false
249
+ },
250
+ "use_policy_training_preset": true,
251
+ "optimizer": {
252
+ "type": "adamw",
253
+ "lr": 0.0001,
254
+ "weight_decay": 1e-05,
255
+ "grad_clip_norm": 10.0,
256
+ "betas": [
257
+ 0.9,
258
+ 0.999
259
+ ],
260
+ "eps": 1e-08
261
+ },
262
+ "scheduler": {
263
+ "type": "cosine_annealing",
264
+ "num_warmup_steps": 0,
265
+ "min_lr": 0.0001,
266
+ "T_max": 300000
267
+ },
268
+ "eval": {
269
+ "n_episodes": 50,
270
+ "batch_size": 50,
271
+ "use_async_envs": false
272
+ },
273
+ "wandb": {
274
+ "enable": true,
275
+ "disable_artifact": false,
276
+ "project": "pusht",
277
+ "entity": null,
278
+ "notes": null
279
+ }
280
+ }