pmukhop committed on
Commit
6dfd56f
·
0 Parent(s):

Initial commit pdearena_ins

Browse files
Files changed (2) hide show
  1. coalesced.pth +3 -0
  2. extended_config.yaml +266 -0
coalesced.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97b0ddf1f0bbf06e461b50c155640065ed31f3f585968407b2ef9e7d476f0f8e
3
+ size 5151218135
extended_config.yaml ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_workers: 10
2
+ name: Walrus_ft_pdearena_ins_cond_realglobalnorm-PDEA--delta-Isotr[Space-Adapt-]-AdamW-0.0001
3
+ finetune: true
4
+ automatic_setup: true
5
+ trainer:
6
+ _target_: walrus.trainer.Trainer
7
+ max_epoch: 50
8
+ val_frequency: 5
9
+ rollout_val_frequency: 5
10
+ short_validation_length: 20
11
+ max_rollout_steps: 200
12
+ num_time_intervals: 5
13
+ enable_amp: false
14
+ loss_fn:
15
+ _target_: the_well.benchmark.metrics.MAE
16
+ formatter:
17
+ _target_: hydra.utils.get_class
18
+ path: walrus.data.well_to_multi_transformer.ChannelsFirstWithTimeFormatter
19
+ revin:
20
+ _target_: walrus.trainer.normalization_strat.GlobalRevNormalization
21
+ _partial_: true
22
+ prediction_type: delta
23
+ grad_acc_steps: 1
24
+ image_validation: true
25
+ video_validation: true
26
+ gradient_log_level: 0
27
+ clip_gradient: 10
28
+ log_interval: 200
29
+ loss_multiplier: 100.0
30
+ lr_scheduler_per_step: false
31
+ validation_suite:
32
+ - _target_: the_well.benchmark.metrics.NRMSE
33
+ - _target_: the_well.benchmark.metrics.VRMSE
34
+ - _target_: the_well.benchmark.metrics.PearsonR
35
+ validation_trajectory_metrics:
36
+ - _target_: the_well.benchmark.metrics.HistogramW1
37
+ - _target_: the_well.benchmark.metrics.WindowedDTW
38
+ batch_aggregation_fns:
39
+ - torch.mean
40
+ - torch.median
41
+ - torch.std
42
+ skip_spectral_metrics: true
43
+ optimizer:
44
+ _target_: torch.optim.AdamW
45
+ weight_decay: 0.0001
46
+ eps: 1.0e-10
47
+ lr: 0.0001
48
+ lr_scheduler:
49
+ _target_: walrus.optim.schedulers.InverseSqrtLinearWarmupSqrtCooldown
50
+ warmup_epochs: 10
51
+ cooldown_epochs: 10
52
+ warmup_lr_factor: 0.1
53
+ cooldown_lr_factor: 0.001
54
+ model:
55
+ encoder:
56
+ _partial_: true
57
+ _target_: walrus.models.encoders.vstride_encoder.SpaceBagAdaptiveDVstrideEncoder
58
+ learned_pad: true
59
+ base_kernel_size1d:
60
+ - - 4
61
+ - 4
62
+ base_kernel_size2d:
63
+ - - 8
64
+ - 4
65
+ - - 8
66
+ - 4
67
+ base_kernel_size3d:
68
+ - - 8
69
+ - 4
70
+ - - 8
71
+ - 4
72
+ - - 8
73
+ - 4
74
+ groups: 12
75
+ kernel_scales_seq:
76
+ - - 2
77
+ - 2
78
+ - - 4
79
+ - 2
80
+ - - 4
81
+ - 4
82
+ - - 8
83
+ - 4
84
+ variable_downsample: true
85
+ variable_deterministic_ds: true
86
+ activation:
87
+ _partial_: true
88
+ _target_: torch.nn.SiLU
89
+ decoder:
90
+ _partial_: true
91
+ _target_: walrus.models.decoders.vstride_decoder.AdaptiveDVstrideDecoder
92
+ learned_pad: true
93
+ base_kernel_size1d:
94
+ - - 4
95
+ - 4
96
+ base_kernel_size2d:
97
+ - - 8
98
+ - 4
99
+ - - 8
100
+ - 4
101
+ base_kernel_size3d:
102
+ - - 8
103
+ - 4
104
+ - - 8
105
+ - 4
106
+ - - 8
107
+ - 4
108
+ groups: 12
109
+ activation:
110
+ _partial_: true
111
+ _target_: torch.nn.SiLU
112
+ processor:
113
+ space_mixing:
114
+ _partial_: true
115
+ _target_: walrus.models.spatial_blocks.full_attention.FullAttention
116
+ num_heads: 16
117
+ mlp_dim: null
118
+ time_mixing:
119
+ _partial_: true
120
+ _target_: walrus.models.temporal_blocks.axial_time_attention.AxialTimeAttention
121
+ num_heads: 16
122
+ bias_type: rel
123
+ channel_mixing:
124
+ _partial_: true
125
+ _target_: torch.nn.Identity
126
+ _partial_: true
127
+ _target_: walrus.models.spatiotemporal_blocks.space_time_split.SpaceTimeSplitBlock
128
+ norm_layer:
129
+ _partial_: true
130
+ _target_: walrus.models.shared_utils.normalization.RMSGroupNorm
131
+ _target_: walrus.models.IsotropicModel
132
+ hidden_dim: 1408
133
+ projection_dim: 48
134
+ intermediate_dim: 352
135
+ processor_blocks: 40
136
+ drop_path: 0.0
137
+ groups: 16
138
+ max_d: 3
139
+ static_axes: true
140
+ weight_tied_axes: false
141
+ causal_in_time: true
142
+ include_d:
143
+ - 2
144
+ - 3
145
+ override_dimensionality: 0
146
+ jitter_patches: true
147
+ gradient_checkpointing_freq: 0
148
+ use_periodic_fixed_jitter: true
149
+ input_field_drop: 0
150
+ data:
151
+ field_index_map_override:
152
+ closed_boundary: 0
153
+ open_boundary: 1
154
+ bias_correction: 2
155
+ pressure: 3
156
+ velocity_x: 4
157
+ velocity_y: 5
158
+ velocity_z: 6
159
+ zeros_like_density: 7
160
+ speed_of_sound: 8
161
+ concentration: 9
162
+ D_xx: 10
163
+ D_xy: 11
164
+ D_xz: 12
165
+ D_yx: 13
166
+ D_yy: 14
167
+ D_yz: 15
168
+ D_zx: 16
169
+ D_zy: 17
170
+ D_zz: 18
171
+ E_xx: 19
172
+ E_xy: 20
173
+ E_xz: 21
174
+ E_yx: 22
175
+ E_yy: 23
176
+ E_yz: 24
177
+ E_zx: 25
178
+ E_zy: 26
179
+ E_zz: 27
180
+ density: 28
181
+ energy: 29
182
+ velocity_r: 30
183
+ velocity_theta: 31
184
+ velocity_phi: 32
185
+ momentum_x: 33
186
+ momentum_y: 34
187
+ momentum_z: 35
188
+ pressure_re: 36
189
+ pressure_im: 37
190
+ mask: 38
191
+ magnetic_field_x: 39
192
+ magnetic_field_y: 40
193
+ magnetic_field_z: 41
194
+ A: 42
195
+ B: 43
196
+ height: 44
197
+ internal_energy: 45
198
+ temperature: 46
199
+ electron_fraction: 47
200
+ entropy: 48
201
+ magnetic_field_log_r: 49
202
+ magnetic_field_theta: 50
203
+ magnetic_field_phi: 51
204
+ velocity_log_r: 52
205
+ buoyancy: 53
206
+ tracer: 54
207
+ log10_density: 55
208
+ log10_temperature: 56
209
+ c_zz: 57
210
+ C_xx: 58
211
+ C_xy: 59
212
+ C_xz: 60
213
+ C_yx: 61
214
+ C_yy: 62
215
+ C_yz: 63
216
+ C_zx: 64
217
+ C_zy: 65
218
+ C_zz: 66
219
+ well_base_path: /mnt/gpuxl/polymathic/the_well/datasets/
220
+ wandb_data_name: PDEA-INS
221
+ module_parameters:
222
+ _target_: walrus.data.MixedWellDataModule
223
+ batch_size: 1
224
+ n_steps_input: 6
225
+ n_steps_output: 1
226
+ min_dt_stride: 1
227
+ max_dt_stride: 1
228
+ max_samples: 2000
229
+ max_rollout_steps: 200
230
+ well_dataset_info:
231
+ pdea_ins:
232
+ include_filters: []
233
+ exclude_filters: []
234
+ path: /mnt/gpuxl/polymathic/WellFormattedExternalData/PDEArena/buoyant_ins_cond/
235
+ normalization_path: stats.yaml
236
+ start_rollout_valid_output_at_t: 11
237
+ auto_resume: true
238
+ folder_override: ''
239
+ checkpoint_override: ''
240
+ config_override: /mnt/home/polymathic/ceph/walrus_logging/platinum_checkpoints/extended_config.yaml
241
+ validation_mode: false
242
+ frozen_components:
243
+ - model
244
+ distribution:
245
+ distribution_type: fsdp
246
+ local_size: null
247
+ logger:
248
+ wandb: true
249
+ wandb_project_name: walrus_Finetuning_Runs
250
+ checkpoint:
251
+ _target_: walrus.trainer.checkpoints.CheckPointer
252
+ save_dir: /mnt/home/polymathic/ceph/walrus_logging/runs/Walrus_ft_pdearena_ins_cond_realglobalnorm-PDEA--delta-Isotr[Space-Adapt-]-AdamW-0.0001/finetune/0/checkpoints
253
+ load_checkpoint_path: null
254
+ coalesced_checkpoint_path: /mnt/home/polymathic/ceph/walrus_logging/platinum_checkpoints/final_base_model/walrus.pt
255
+ save_best: true
256
+ checkpoint_frequency: 20
257
+ align_fields: true
258
+ load_chkpt_after_finetuning_expansion: false
259
+ finetuning_mods:
260
+ learnable_rope: true
261
+ rope_per_axis: true
262
+ ape_shape:
263
+ - 33
264
+ - 33
265
+ - 1
266
+ experiment_dir: /mnt/home/polymathic/ceph/walrus_logging/runs