xesn.Driver.run_training
- Driver.run_training()
Perform ESN or LazyESN training to learn the readout matrix weights. The resulting readout weights are stored in the output_directory in the zarr store esn-weights.zarr or lazyesn-weights.zarr, depending on which model is used. See ESN.to_xds() for what is stored in this dataset.
- Required Parameter Sections:
  - xdata – all options are used to create an XData object. If subsampling is provided, it must have slicing options labelled “training”.
  - esn or lazyesn – all options are used to create an ESN or LazyESN object, depending on the name of the section; case does not matter.
  - training – all options are passed to ESN.train() or LazyESN.train().
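The section requirements listed above can be sketched as a small pre-flight check on a config that has already been loaded into a dictionary. This is only an illustration: `check_training_config` is a hypothetical helper, not part of xesn, and it assumes that `subsampling` is laid out as dimension name, then label, then slice bounds, as in the example config.

```python
def check_training_config(config):
    """Hypothetical pre-flight check for the sections run_training needs.

    Returns the lowercased model section name ("esn" or "lazyesn").
    """
    for name in ("xdata", "training"):
        if name not in config:
            raise KeyError(f"missing required section: {name!r}")

    # The model section name is matched case-insensitively.
    model_keys = [key for key in config if key.lower() in ("esn", "lazyesn")]
    if len(model_keys) != 1:
        raise KeyError(
            f"expected exactly one esn or lazyesn section, found {model_keys}"
        )

    # Assumed layout: subsampling -> dimension name -> labelled slice bounds.
    subsampling = config["xdata"].get("subsampling") or {}
    for dim, labelled in subsampling.items():
        if "training" not in labelled:
            raise KeyError(f"subsampling for {dim!r} has no 'training' entry")

    return model_keys[0].lower()


config = {
    "xdata": {"subsampling": {"time": {"training": [2_000, 42_001, None]}}},
    "LazyESN": {"n_reservoir": 1_000},  # mixed case is accepted
    "training": {"n_spinup": 0, "batch_size": 20_000},
}
print(check_training_config(config))
```

Running a check like this before training surfaces a missing "training" label early, rather than partway through a long run.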
Example Config YAML File
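In the rendered documentation, the regions used by this routine are highlighted; per the required sections listed above, these are xdata, esn, and training. Other options are ignored.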
xdata:
    dimensions          : ['x', 'time']
    #
    zstore_path         : lorenz96-12d.zarr
    field_name          : 'trajectory'
    subsampling:
        time:
            training        : [ 2_000, 42_001, null]
            macro_training  : [ 42_000, 52_001, null]
            testing         : [ 57_000, null, null]
    normalization:
        bias            : 2.38565488 # computed from trainer.mean
        scale           : 3.65852722 # computed from trainer.std

esn:
    n_input             : 12
    n_output            : 12
    n_reservoir         : 1_000
    leak_rate           : 0.5
    tikhonov_parameter  : 1.e-6
    #
    input_kwargs:
        factor          : 0.5
        distribution    : uniform
        normalization   : svd
        is_sparse       : False
        random_seed     : 0
    #
    adjacency_kwargs:
        factor          : 0.9
        distribution    : uniform
        normalization   : svd
        is_sparse       : True
        connectedness   : 5
        random_seed     : 1
    #
    bias_kwargs:
        factor          : 0.
        distribution    : uniform
        random_seed     : 2

training:
    n_spinup            : 0
    batch_size          : 20_000

esn_weights:
    store               : output-guess/esn-weights.zarr

testing:
    n_spinup            : 100
    n_samples           : 20
    n_steps             : 500
    random_seed         : 10
    cost_terms:
        nrmse           : 1.
        psd_nrmse       : 1.

macro_training:
    parameters:
        input_factor        : [0., 2.]
        adjacency_factor    : [0., 2.]
        bias_factor         : [0., 2.]
        leak_rate           : [0., 1.]
        tikhonov_parameter  : [1.e-12, 1.]
    transformations:
        tikhonov_parameter  : log10
    forecast:
        n_spinup            : 100
        n_samples           : 5
        n_steps             : 500
        random_seed         : 10
    ego:
        n_iter              : 5
        n_doe               : 10
        n_parallel          : 4
        random_state        : 5
    cost_upper_bound        : 1.e9
    cost_terms:
        nrmse               : 1.
        psd_nrmse           : 1.
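To make the subsampling entries in the config above more concrete, the following sketch counts how many time steps each labelled range would select. It assumes each triple is interpreted as Python `slice(start, stop, step)` bounds, with `null` loaded as `None`; that reading is an assumption, not something this page states. The trajectory length `n_time` is hypothetical and chosen only for illustration.

```python
def to_slice(bounds):
    """Turn a [start, stop, step] triple into a slice (assumed interpretation)."""
    start, stop, step = bounds
    return slice(start, stop, step)


# Values copied from the subsampling section of the example config.
subsampling_time = {
    "training": [2_000, 42_001, None],
    "macro_training": [42_000, 52_001, None],
    "testing": [57_000, None, None],
}

n_time = 60_000  # hypothetical length of the time dimension

# Slicing a range gives the selected indices without allocating any data.
counts = {
    label: len(range(n_time)[to_slice(bounds)])
    for label, bounds in subsampling_time.items()
}

for label, count in counts.items():
    print(f"{label}: {count} time steps")
```

Under these assumptions, only the "training" range matters for this routine; the other labels belong to sections that run_training ignores.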