Add LRSchedulers #2603

Open · wants to merge 4 commits into main
10 changes: 9 additions & 1 deletion candle-nn/examples/basic_optimizer.rs
@@ -5,7 +5,10 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use candle_nn::{
linear, AdamW, FnLRScheduler, LRScheduler, Linear, Module, Optimizer, ParamsAdamW, VarBuilder,
VarMap,
};

fn gen_data() -> Result<(Tensor, Tensor)> {
// Generate some sample linear data.
@@ -29,7 +32,12 @@ fn main() -> Result<()> {
..Default::default()
};
let mut opt = AdamW::new(varmap.all_vars(), params)?;
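// Step-decay schedule: lr = 0.2 * 0.9^floor(step / 1000), i.e. 0.2 for the
// first 1000 steps, then 0.18, then 0.162, and so on.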
let mut scheduler = FnLRScheduler::<usize>::new(Box::new(|step| {
Ok(0.2 * 0.9f64.powi((step as f64 / 1000f64).floor() as i32))
}));

for step in 0..10000 {
opt.set_learning_rate(scheduler.step(step)?);
let ys = model.forward(&sample_xs)?;
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
opt.backward_step(&loss)?;
2 changes: 2 additions & 0 deletions candle-nn/src/lib.rs
@@ -12,6 +12,7 @@ pub mod linear;
pub mod loss;
pub mod ops;
pub mod optim;
pub mod scheduler;
pub mod rnn;
pub mod rotary_emb;
pub mod sequential;
@@ -33,6 +34,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use scheduler::{FnLRScheduler, LRScheduler};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
145 changes: 145 additions & 0 deletions candle-nn/src/scheduler.rs
@@ -0,0 +1,145 @@
use candle::Result;

/// The interface LR Schedulers should implement.
pub trait LRScheduler<T> {
/// Step the scheduler and return the new learning rate.
fn step(&mut self, params: T) -> Result<f64>;

/// Get the current learning rate.
fn get_lr(&self) -> f64;
}

/// A learning rate scheduler that uses a function to determine the learning rate.
/// The function should take a parameter of type `T` and return a `f64`.
pub struct FnLRScheduler<T> {
pub func: Box<dyn Fn(T) -> Result<f64>>,
pub lr: f64,
}

impl<T> FnLRScheduler<T> {
pub fn new(func: Box<dyn Fn(T) -> Result<f64>>) -> Self {
Self { func, lr: 0.0 }
}
}

impl<T> LRScheduler<T> for FnLRScheduler<T> {
fn step(&mut self, params: T) -> Result<f64> {
self.lr = (self.func)(params)?;
Ok(self.lr)
}

fn get_lr(&self) -> f64 {
self.lr
}
}
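
// Usage sketch (illustrative): the closure can encode arbitrary schedules, e.g. a
// linear warmup over the first 100 steps followed by a constant rate:
//
//     let mut warmup = FnLRScheduler::<usize>::new(Box::new(|step| {
//         Ok(1e-3 * (step.min(100) as f64 / 100.0))
//     }));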

/// Decays the learning rate of each parameter group by gamma every step_size epochs.
// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html#torch.optim.lr_scheduler.StepLR
pub struct StepLR {
step_size: usize,
last_epoch: usize,
gamma: f64,
lr: f64,
}

impl StepLR {
pub fn new(step_size: usize, gamma: f64, lr: f64) -> Self {
Self {
step_size,
last_epoch: 0,
gamma,
lr,
}
}
}

impl LRScheduler<()> for StepLR {
fn step(&mut self, _params: ()) -> Result<f64> {
self.last_epoch += 1;
if self.last_epoch % self.step_size == 0 {
self.lr *= self.gamma;
}
Ok(self.lr)
}

fn get_lr(&self) -> f64 {
self.lr
}
}
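
// Illustrative example: `StepLR::new(2, 0.5, 1.0)` halves the rate every second
// call, so successive calls to `step(())` return 1.0, 0.5, 0.5, 0.25, 0.25, ...
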
/// Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones.
// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html#torch.optim.lr_scheduler.MultiStepLR
pub struct MultiStepLR {
milestones: Vec<usize>,
gamma: f64,
last_epoch: usize,
lr: f64,
}

impl MultiStepLR {
pub fn new(milestones: Vec<usize>, gamma: f64, lr: f64) -> Result<Self> {
// Ensure milestones are sorted.
if !milestones.is_sorted() {
candle::bail!("milestones should be sorted")
}

Ok(Self {
milestones,
gamma,
last_epoch: 0,
lr,
})
}
}

impl LRScheduler<()> for MultiStepLR {
fn step(&mut self, _params: ()) -> Result<f64> {
self.last_epoch += 1;
if let Some(step) = self.milestones.first() {
if self.last_epoch == *step {
self.milestones.remove(0);
self.lr *= self.gamma;
}
}
Ok(self.lr)
}

fn get_lr(&self) -> f64 {
self.lr
}
}
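
// Illustrative example: `MultiStepLR::new(vec![2, 4], 0.1, 1.0)?` returns 1.0 on the
// first call to `step(())`, 0.1 on calls 2 and 3, and 0.01 from call 4 onward.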

/// Set the learning rate of each parameter group using a cosine annealing schedule.
// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR
pub struct CosineAnnealingLR {
t_max: usize,
last_epoch: usize,
eta_min: f64,
lr: f64,
}

impl CosineAnnealingLR {
pub fn new(t_max: usize, eta_min: f64, lr: f64) -> Self {
Self {
t_max,
last_epoch: 0,
eta_min,
lr,
}
}
}

impl LRScheduler<()> for CosineAnnealingLR {
fn step(&mut self, _params: ()) -> Result<f64> {
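// Recursive update: lr <- eta_min + 0.5 * (lr - eta_min) * (1 + cos(pi * last_epoch / t_max)).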
self.lr = self.eta_min
+ 0.5
* (self.lr - self.eta_min)
* (1. + (self.last_epoch as f64 / self.t_max as f64 * std::f64::consts::PI).cos());
self.last_epoch += 1;
self.last_epoch = self.last_epoch.min(self.t_max);
Ok(self.lr)
}

fn get_lr(&self) -> f64 {
self.lr
}
}
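
// Behavior sketch: with the recursive update above, the first call (last_epoch = 0)
// leaves the rate unchanged, later calls shrink the gap to `eta_min`, and the rate
// reaches `eta_min` once `last_epoch` hits `t_max`, where it stays because
// `last_epoch` is clamped. A sketch of a possible unit test (illustrative only):
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Result;

    #[test]
    fn cosine_annealing_decays_to_eta_min() -> Result<()> {
        let mut scheduler = CosineAnnealingLR::new(2, 0.0, 1.0);
        let lrs = (0..4)
            .map(|_| scheduler.step(()))
            .collect::<Result<Vec<_>>>()?;
        // Expected values, approximately: [1.0, 0.5, 0.0, 0.0].
        for (got, want) in lrs.iter().zip([1.0, 0.5, 0.0, 0.0]) {
            assert!((*got - want).abs() < 1e-9);
        }
        Ok(())
    }
}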