-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
590 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[workspace] | ||
resolver = "2" | ||
members = [ | ||
members = [ "integration", | ||
"samples", | ||
] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[package] | ||
name = "integration" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#[macro_export] | ||
macro_rules! test { | ||
($m:ident; $($func:item)*) => { | ||
mod $m { | ||
$($func)* | ||
#[test] | ||
fn run() { main() } | ||
} | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#![feature(autodiff)] | ||
|
||
#[autodiff(d_array, Reverse, Active, Duplicated)] | ||
fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { | ||
arr[0][0][0] * arr[1][1][1] | ||
} | ||
|
||
fn main() { | ||
let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]; | ||
let mut d_arr = [[[0.0; 2]; 2]; 2]; | ||
|
||
d_array(&arr, &mut d_arr, 1.0); | ||
|
||
dbg!(&d_arr); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn main() { | ||
super::main() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#![feature(autodiff)] | ||
|
||
pub fn add(a: i32, b: i32) -> i32 { | ||
a + b | ||
} | ||
|
||
// This is a really bad adding function, its purpose is to fail in this | ||
// example. | ||
#[allow(dead_code)] | ||
fn bad_add(a: i32, b: i32) -> i32 { | ||
a - b | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
// Note this useful idiom: importing names from outer (for mod tests) scope. | ||
use super::*; | ||
|
||
#[test] | ||
fn test_add() { | ||
assert_eq!(add(1, 2), 3); | ||
} | ||
|
||
#[test] | ||
fn test_bad_add() { | ||
// This assert would fire and test will fail. | ||
// Please note, that private functions can be tested too! | ||
assert_eq!(bad_add(1, 2), 3); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#![feature(autodiff)] | ||
|
||
//#[autodiff(cos_box, Reverse, Active, Duplicated)] | ||
#[autodiff(cos_box, Reverse, Duplicated, Active)] | ||
fn sin(x: &Box<f32>) -> f32 { | ||
f32::sin(**x) | ||
} | ||
|
||
fn main() { | ||
let x = Box::<f32>::new(3.14); | ||
let mut df_dx = Box::<f32>::new(0.0); | ||
cos_box(&x, &mut df_dx, 1.0); | ||
|
||
dbg!(&df_dx); | ||
|
||
assert!(*df_dx == f32::cos(*x)); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn main() { | ||
super::main() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#![feature(autodiff)] | ||
|
||
|
||
type Matrix = Vec<Vec<f32>>; | ||
type Vector = Vec<f32>; | ||
|
||
#[autodiff(d_matvec, Forward, Dual, Const, Dual)] | ||
fn matvec(mat: &Matrix, vec: &Vector, out: &mut Vector) { | ||
for i in 0..mat.len() - 1 { | ||
for j in 0..mat[0].len() - 1 { | ||
out[i] += mat[i][j] * vec[j]; | ||
} | ||
} | ||
} | ||
|
||
fn main() { | ||
let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]]; | ||
let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]]; | ||
let inp = vec![1.0, 1.0]; | ||
let mut out = vec![0.0, 0.0]; | ||
let mut out_tang = vec![0.0, 1.0]; | ||
|
||
//matvec(&mat, &inp, &mut out); | ||
d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang); | ||
|
||
dbg!(&out); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn main() { | ||
super::main() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#![feature(autodiff)] | ||
|
||
enum Foo { | ||
A(f32), | ||
B(i32), | ||
} | ||
|
||
#[autodiff(d_bar, Reverse, Duplicated, Active)] | ||
fn bar(x: &f32) -> f32 { | ||
let val: Foo = | ||
if *x > 0.0 { | ||
Foo::A(*x) | ||
} else { | ||
Foo::B(12) | ||
}; | ||
|
||
std::hint::black_box(&val); | ||
match val { | ||
Foo::A(f) => f * f, | ||
Foo::B(_) => 4.0, | ||
} | ||
} | ||
|
||
fn main() { | ||
let x = 1.0; | ||
let x2 = -1.0; | ||
let mut dx = 0.0; | ||
let mut dx2 = 0.0; | ||
let out = bar(&x); | ||
let dout = d_bar(&x, &mut dx, 1.0); | ||
let dout2 = d_bar(&x2, &mut dx2, 1.0); | ||
println!("x: {out}"); | ||
println!("dx: {dout}"); | ||
println!("dx2: {dout2}"); | ||
assert_eq!(dx, 2.0); | ||
assert_eq!(dx2, 0.0); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#![feature(autodiff)] | ||
|
||
#[derive(Debug, PartialEq)] | ||
enum Foo { | ||
A(f32), | ||
B(i32), | ||
} | ||
|
||
#[autodiff(d_bar, Reverse, Duplicated, Active)] | ||
fn bar(x: &Foo) -> f32 { | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
match x { | ||
Foo::A(f) => f * f, | ||
Foo::B(_) => 4.0, | ||
} | ||
} | ||
|
||
fn main() { | ||
let x = Foo::A(1.0); | ||
let x2 = Foo::A(-1.0); | ||
let x3 = Foo::B(1); | ||
let x4 = Foo::B(1); | ||
let x5 = Foo::A(1.0); | ||
let mut dx = Foo::A(0.0); | ||
let mut dx2 = Foo::A(0.0); | ||
let mut dx3 = Foo::B(0); | ||
let mut dx4 = Foo::A(0.0); | ||
let mut dx5 = Foo::B(0); | ||
let out = bar(&x); | ||
let dout = d_bar(&x, &mut dx, 1.0); | ||
let dout2 = d_bar(&x2, &mut dx2, 1.0); | ||
let dout3 = d_bar(&x3, &mut dx3, 1.0); | ||
let dout4 = d_bar(&x4, &mut dx4, 1.0); | ||
let dout5 = d_bar(&x5, &mut dx5, 1.0); | ||
println!("x: {out}"); | ||
println!("dx: {dout}"); | ||
println!("dx2: {dout2}"); | ||
println!("dx3: {dout3}"); | ||
println!("dx4: {dout4}"); | ||
println!("dx5: {dout5}"); | ||
assert_eq!(dx, Foo::A(2.0)); | ||
assert_eq!(dx2, Foo::A(-2.0)); | ||
assert_eq!(dx3, Foo::B(0)); | ||
assert_eq!(dx4, Foo::A(0.0)); | ||
assert_eq!(dx5, Foo::B(0)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#![feature(autodiff)] | ||
|
||
struct Bar { | ||
x: f32, | ||
y: bool, | ||
} | ||
|
||
#[autodiff(df, Reverse, Duplicated, Active)] | ||
fn f(x: &Bar) -> f32 { | ||
if x.y { | ||
x.x * x.x | ||
} else { | ||
4.0 | ||
} | ||
} | ||
|
||
fn main() { | ||
let a = Bar { x: 3.0, y: true }; | ||
let mut da_good = Bar { x: 0.0, y: true }; | ||
let mut da_bad = Bar { x: 0.0, y: false }; | ||
let dx = df(&a, &mut da_good, 1.0); | ||
let dx2 = df(&a, &mut da_bad, 1.0); | ||
println!("good: {:?}", da_good.x); | ||
println!("bad: {:?}", da_bad.x); | ||
println!("bool values good/bad: {:?} {:?}", da_good.y, da_bad.y); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#![feature(autodiff)] | ||
|
||
struct Bar { | ||
x: f32, | ||
y: bool, | ||
} | ||
|
||
#[autodiff(df, Reverse, Duplicated, Duplicated, Active)] | ||
fn f(x: &Bar, val: bool) -> f32 { | ||
if val { | ||
x.x | ||
} else { | ||
4.0 | ||
} | ||
} | ||
|
||
fn main() { | ||
//let a = Bar { x: 1.0, y: true }; | ||
//let mut da_good = Bar { x: 0.0, y: true }; | ||
//let mut da_bad = Bar { x: 0.0, y: false }; | ||
//let dx = df(&a, &mut da_good, 1.0); | ||
//let dx2 = df(&a, &mut da_bad, 1.0); | ||
//println!("good: {:?}", da_good.x); | ||
//println!("bad: {:?}", da_bad.x); | ||
//println!("bool values good/bad: {:?} {:?}", da_good.y, da_bad.y); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#![feature(autodiff)] | ||
|
||
#[autodiff(df, Forward, Dual, Dual)] | ||
fn f(x: &[f32]) -> f32 { x[0] * x[0] + x[1] * x[0] } | ||
|
||
fn main() { | ||
let x = [2.0, 2.0]; | ||
let dx = [1.0, 0.0]; | ||
let (y, dy) = df(&x, &dx); | ||
assert_eq!(dy, 2.0 * x[0] + x[1]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#![feature(rustc_attrs)] | ||
#![feature(autodiff)] | ||
|
||
fn sin(x: &Vec<f32>, y: &mut f32) { | ||
*y = x.into_iter().map(|x| f32::sin(*x)).sum() | ||
} | ||
|
||
#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)] | ||
fn jac(x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32); | ||
|
||
#[autodiff(jac, Forward, Const, Dual, Const, Const, Const)] | ||
fn hessian(x: &Vec<f32>, y_x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32); | ||
|
||
fn main() { | ||
let inp = vec![3.1415 / 2., 1.0, 0.5]; | ||
let mut d_inp = vec![0.0, 0.0, 0.0]; | ||
let mut y = 0.0; | ||
let tang = vec![1.0, 0.0, 0.0]; | ||
hessian(&inp, &tang, &mut d_inp, &mut y, &1.0); | ||
dbg!(&d_inp); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn main() { | ||
super::main() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#![feature(autodiff)] | ||
|
||
#[autodiff(d_square, Reverse, Duplicated, Active)] | ||
fn square(x: &f64) -> f64 { | ||
x.powi(2) | ||
} | ||
|
||
fn main() { | ||
let x = 3.0; | ||
let output = square(&x); | ||
println!("{output}"); | ||
|
||
let mut df_dx = 0.0; | ||
d_square(&x, &mut df_dx, 1.0); | ||
println!("df_dx: {:?}", df_dx); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#![feature(autodiff)] | ||
|
||
use ndarray::Array1; | ||
|
||
//#[autodiff(d_collect, Reverse, Active)] | ||
#[autodiff(d_collect, Reverse, Duplicated, Active)] | ||
fn collect(x: &Array1<f32>) -> f32 { | ||
x[0] | ||
} | ||
|
||
fn main() { | ||
let a = Array1::zeros(19); | ||
let mut d_a = Array1::zeros(19); | ||
|
||
d_collect(&a, &mut d_a, 1.0); | ||
|
||
dbg!(&d_a); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn main() { | ||
super::main() | ||
} | ||
} |
Oops, something went wrong.