Skip to content

Commit

Permalink
Add checkpointing and hot start (#197)
Browse files Browse the repository at this point in the history
* Implement hot_start using argmin checkpointing

* Make it work and validate hot_start

* Improve py documentaiton

* Restore trego default config
  • Loading branch information
relf authored Sep 30, 2024
1 parent d5e0617 commit ebdde79
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ input.txt
output.txt
mopta08.exe
mopta08_elf64.bin
**/.checkpoints

# JOSS
joss/paper.jats
joss/paper.pdf

1 change: 1 addition & 0 deletions ego/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ nlopt = { version = "0.7.0", optional = true }

rand_xoshiro = { version = "0.6", features = ["serde1"] }
argmin = { version = "0.10.0", features = ["serde1", "ctrlc"] }
bincode = { version = "1.3.0" }
web-time = "1.1.0"
libm = "0.2.6"
finitediff = { version = "0.1", features = ["ndarray"] }
Expand Down
56 changes: 54 additions & 2 deletions ego/src/egor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,18 @@ use crate::types::*;
use crate::EgorConfig;
use crate::EgorState;
use crate::{to_xtypes, EgorSolver};
use crate::{CheckpointingFrequency, HotStartCheckpoint};

use argmin::core::observers::ObserverMode;

use egobox_moe::GpMixtureParams;
use log::info;
use ndarray::{concatenate, Array2, ArrayBase, Axis, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;

use argmin::core::{observers::Observe, Error, Executor, State, KV};
use serde::de::DeserializeOwned;

/// Json filename for configuration
pub const CONFIG_FILE: &str = "egor_config.json";
Expand Down Expand Up @@ -191,12 +194,12 @@ impl<O: GroupFunc> EgorBuilder<O> {
/// Egor optimizer structure used to parameterize the underlying `argmin::Solver`
/// and trigger the optimization using `argmin::Executor`.
#[derive(Clone)]
pub struct Egor<O: GroupFunc, SB: SurrogateBuilder> {
pub struct Egor<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> {
fobj: ObjFunc<O>,
solver: EgorSolver<SB>,
}

impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
impl<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> Egor<O, SB> {
/// Runs the (constrained) optimization of the objective function.
pub fn run(&self) -> Result<OptimResult<f64>> {
let xtypes = self.solver.config.xtypes.clone();
Expand All @@ -209,12 +212,26 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
}

let exec = Executor::new(self.fobj.clone(), self.solver.clone());

let exec = if let Some(ext_iters) = self.solver.config.hot_start {
let checkpoint = HotStartCheckpoint::new(
".checkpoints",
"egor",
CheckpointingFrequency::Always,
ext_iters,
);
exec.checkpointing(checkpoint)
} else {
exec
};

let result = if let Some(outdir) = self.solver.config.outdir.as_ref() {
let hist = OptimizationObserver::new(outdir.clone());
exec.add_observer(hist, ObserverMode::Always).run()?
} else {
exec.run()?
};

info!("{}", result);
let (x_data, y_data) = result.state().clone().take_data().unwrap();

Expand Down Expand Up @@ -399,6 +416,41 @@ mod tests {
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
}

#[test]
#[serial]
fn test_xsinx_checkpoint_egor() {
let _ = std::fs::remove_file(".checkpoints/egor.arg");
let n_iter = 1;
let res = EgorBuilder::optimize(xsinx)
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(Some(0)))
.min_within(&array![[0.0, 25.0]])
.run()
.expect("Egor should minimize");
let expected = array![19.1];
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);

// without hostart we reach the same point
let res = EgorBuilder::optimize(xsinx)
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(None))
.min_within(&array![[0.0, 25.0]])
.run()
.expect("Egor should minimize");
let expected = array![19.1];
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);

// with hot start we continue
let ext_iters = 3;
let res = EgorBuilder::optimize(xsinx)
.configure(|config| config.seed(42).hot_start(Some(ext_iters)))
.min_within(&array![[0.0, 25.0]])
.run()
.expect("Egor should minimize");
let expected = array![18.9];
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
assert_eq!(n_iter as u64 + ext_iters, res.state.get_iter());
let _ = std::fs::remove_file(".checkpoints/egor.arg");
}

#[test]
#[serial]
fn test_xsinx_auto_clustering_egor_builder() {
Expand Down
4 changes: 3 additions & 1 deletion ego/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ pub use crate::errors::*;
pub use crate::gpmix::spec::{CorrelationSpec, RegressionSpec};
pub use crate::solver::*;
pub use crate::types::*;
pub use crate::utils::find_best_result_index;
pub use crate::utils::{
find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint,
};

mod optimizers;
mod utils;
9 changes: 9 additions & 0 deletions ego/src/solver/egor_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub struct EgorConfig {
pub(crate) outdir: Option<String>,
/// If true use `outdir` to retrieve and start from previous results
pub(crate) warm_start: bool,
/// If some enable checkpointing allowing to restart for given ext_iters number of iteration from last checkpointed iteration
pub(crate) hot_start: Option<u64>,
/// List of x types allowing the handling of discrete input variables
pub(crate) xtypes: Vec<XType>,
/// A random generator seed used to get reproductible results.
Expand Down Expand Up @@ -109,6 +111,7 @@ impl Default for EgorConfig {
target: f64::NEG_INFINITY,
outdir: None,
warm_start: false,
hot_start: None,
xtypes: vec![],
seed: None,
trego: TregoConfig::default(),
Expand Down Expand Up @@ -265,6 +268,12 @@ impl EgorConfig {
self
}

/// Whether checkpointing is enabled allowing hot start from previous checkpointed iteration if any
pub fn hot_start(mut self, hot_start: Option<u64>) -> Self {
self.hot_start = hot_start;
self
}

/// Allow to specify a seed for random number generator to allow
/// reproducible runs.
pub fn seed(mut self, seed: u64) -> Self {
Expand Down
5 changes: 3 additions & 2 deletions ego/src/solver/egor_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use ndarray::{
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
use rayon::prelude::*;
use serde::de::DeserializeOwned;

impl<SB: SurrogateBuilder> EgorSolver<SB> {
impl<SB: SurrogateBuilder + DeserializeOwned> EgorSolver<SB> {
/// Constructor of the optimization of the function `f` with specified random generator
/// to get reproducibility.
///
Expand Down Expand Up @@ -80,7 +81,7 @@ impl<SB: SurrogateBuilder> EgorSolver<SB> {

impl<SB> EgorSolver<SB>
where
SB: SurrogateBuilder,
SB: SurrogateBuilder + DeserializeOwned,
{
pub fn have_to_recluster(&self, added: usize, prev_added: usize) -> bool {
self.config.n_clusters == 0 && (added != 0 && added % 10 == 0 && added - prev_added > 0)
Expand Down
5 changes: 3 additions & 2 deletions ego/src/solver/egor_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use egobox_moe::GpMixtureParams;
use ndarray::{Array2, ArrayBase, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::de::DeserializeOwned;

/// EGO optimizer service builder allowing to use Egor optimizer
/// as a service.
Expand Down Expand Up @@ -114,11 +115,11 @@ impl EgorServiceBuilder {

/// Egor optimizer service.
#[derive(Clone)]
pub struct EgorService<SB: SurrogateBuilder> {
pub struct EgorService<SB: SurrogateBuilder + DeserializeOwned> {
solver: EgorSolver<SB>,
}

impl<SB: SurrogateBuilder> EgorService<SB> {
impl<SB: SurrogateBuilder + DeserializeOwned> EgorService<SB> {
/// Given an evaluated doe (x, y) data, return the next promising x point
/// where optimum may be located with regard to the infill criterion.
/// This function inverses the control of the optimization and can be used
Expand Down
6 changes: 3 additions & 3 deletions ego/src/solver/egor_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ use argmin::core::{
};

use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::time::Instant;

/// Numpy filename for initial DOE dump
Expand Down Expand Up @@ -161,7 +161,7 @@ pub fn to_xtypes(xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Vec<XType>
impl<O, SB> Solver<O, EgorState<f64>> for EgorSolver<SB>
where
O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>,
SB: SurrogateBuilder,
SB: SurrogateBuilder + DeserializeOwned,
{
const NAME: &'static str = "Egor";

Expand Down Expand Up @@ -304,7 +304,7 @@ where

impl<SB> EgorSolver<SB>
where
SB: SurrogateBuilder,
SB: SurrogateBuilder + DeserializeOwned,
{
/// Iteration of EGO algorithm
fn ego_iteration<O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>>(
Expand Down
10 changes: 10 additions & 0 deletions ego/src/solver/egor_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ where
}
}

impl<F> EgorState<F>
where
F: Float + ArgminFloat,
{
/// Allow hot start feature by extending current max_iters
pub fn extend_max_iters(&mut self, ext_iters: u64) {
self.max_iters += ext_iters;
}
}

impl<F> State for EgorState<F>
where
F: Float + ArgminFloat,
Expand Down
3 changes: 2 additions & 1 deletion ego/src/solver/trego.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use ndarray::Zip;
use ndarray::{s, Array, Array1, Array2, ArrayView1, Axis};

use rayon::prelude::*;
use serde::de::DeserializeOwned;

impl<SB: SurrogateBuilder> EgorSolver<SB> {
impl<SB: SurrogateBuilder + DeserializeOwned> EgorSolver<SB> {
/// Local step where infill criterion is optimized within trust region
pub fn trego_step<O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>>(
&mut self,
Expand Down
94 changes: 94 additions & 0 deletions ego/src/utils/hot_start.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
use argmin::core::Error;
use serde::{de::DeserializeOwned, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::PathBuf;

use crate::EgorState;

/// Handles saving a checkpoint to disk as a binary file.
#[derive(Clone, Eq, PartialEq, Debug, Hash)]
pub struct HotStartCheckpoint {
/// Extended iteration number
pub extension_iters: u64,
/// Indicates how often a checkpoint is created
pub frequency: CheckpointingFrequency,
/// Directory where the checkpoints are saved to
pub directory: PathBuf,
/// Name of the checkpoint files
pub filename: PathBuf,
}

impl Default for HotStartCheckpoint {
/// Create a default `HotStartCheckpoint` instance.
fn default() -> HotStartCheckpoint {
HotStartCheckpoint {
extension_iters: 0,
frequency: CheckpointingFrequency::default(),
directory: PathBuf::from(".checkpoints"),
filename: PathBuf::from("egor.arg"),
}
}
}

impl HotStartCheckpoint {
/// Create a new `HotStartCheckpoint` instance
pub fn new<N: AsRef<str>>(
directory: N,
name: N,
frequency: CheckpointingFrequency,
ext_iters: u64,
) -> Self {
HotStartCheckpoint {
extension_iters: ext_iters,
frequency,
directory: PathBuf::from(directory.as_ref()),
filename: PathBuf::from(format!("{}.arg", name.as_ref())),
}
}
}

impl<S> Checkpoint<S, EgorState<f64>> for HotStartCheckpoint
where
S: Serialize + DeserializeOwned,
{
/// Writes checkpoint to disk.
///
/// If the directory does not exist already, it will be created. It uses `bincode` to serialize
/// the data.
/// It will return an error if creating the directory or file or serialization failed.
fn save(&self, solver: &S, state: &EgorState<f64>) -> Result<(), Error> {
if !self.directory.exists() {
std::fs::create_dir_all(&self.directory)?
}
let fname = self.directory.join(&self.filename);
let f = BufWriter::new(File::create(fname)?);
bincode::serialize_into(f, &(solver, state))?;
Ok(())
}

/// Load a checkpoint from disk.
///
///
/// If there is no checkpoint on disk, it will return `Ok(None)`.
/// Returns an error if opening the file or deserialization failed.
fn load(&self) -> Result<Option<(S, EgorState<f64>)>, Error> {
let path = &self.directory.join(&self.filename);
if !path.exists() {
return Ok(None);
}
let file = File::open(path)?;
let reader = BufReader::new(file);
let (solver, mut state): (_, EgorState<_>) = bincode::deserialize_from(reader)?;
state.extend_max_iters(self.extension_iters);
Ok(Some((solver, state)))
}

/// Returns the how often a checkpoint is to be saved.
///
/// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
fn frequency(&self) -> CheckpointingFrequency {
self.frequency
}
}
2 changes: 2 additions & 0 deletions ego/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod find_result;
mod hot_start;
mod misc;
mod sort_axis;

pub use find_result::*;
pub use hot_start::*;
pub use misc::*;
Loading

0 comments on commit ebdde79

Please sign in to comment.