-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_config.yaml
39 lines (36 loc) · 2.16 KB
/
training_config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# model parameters
model_name: AutoregressiveLMU  # name of the model
load_state: ""  # path to a saved model state, for loading pretrained models
dt: 0.005  # nengo time step
t_delay: 0.02  # how far into the future to predict (initial guess)
neurons_per_dim: 50  # number of neurons representing each dimension
radius: 1.5  # radius of the nengo ensemble
lmu_theta: 0.02  # duration of the LMU delay
lmu_q: 1  # number of factorizations per dim in the LMU
# decide if the model should use the current state as a prior for the
# next state prediction
predict_delta: true
seed: 4  # fixed so neuron properties are reproducible across runs
# training data parameters
data_dir: data/Mixed/Train/  # folder containing the individual training files
bound: 0.19  # if the cart ever leaves these bounds, the data is ignored
sample_freq: 50  # cartpole data is recorded at ~50Hz
shuffle: true  # shuffle the training data
skiprows: 0  # initial rows to skip in the data .csv files
action_vars:  # names of the action variables in the training data
  - Q
state_vars:  # names of the state variables in the training data
  - angle_sin
  - angle_cos
  - angleD
  - position
  - positionD
# result parameters
results_dir: results/  # folder the plots and weights are saved to
experiment_name: training  # sub-folder of the results folder
# training process parameters
epochs: 1  # number of epochs for training
learning_rate: 0.00005  # learning rate used in online learning
save_state_every: 200  # how often to save the model weights during an epoch
plot_prediction_every: 200  # how often to plot a prediction curve during learning
# reduces the training set for quick debugging
# NOTE(review): -1 presumably means "no limit" - confirm against the data loader
max_samples: -1
t_switch: 5.0  # after how many seconds to switch to autoregressive mode
datapoints_per_file: 1000  # divide by the sampling rate to get the duration in s