Skip to content

Commit

Permalink
add more tests for CI
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Mar 31, 2024
1 parent 5bf6664 commit 08352cd
Show file tree
Hide file tree
Showing 23 changed files with 590 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = [
members = [ "integration",
"samples",
]

Expand Down
8 changes: 8 additions & 0 deletions integration/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "integration"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
10 changes: 10 additions & 0 deletions integration/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#[macro_export]
macro_rules! test {
($m:ident; $($func:item)*) => {
mod $m {
$($func)*
#[test]
fn run() { main() }
}
};
}
23 changes: 23 additions & 0 deletions integration/tests/examples/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#![feature(autodiff)]

#[autodiff(d_array, Reverse, Active, Duplicated)]
fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 {
arr[0][0][0] * arr[1][1][1]
}

fn main() {
let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]];
let mut d_arr = [[[0.0; 2]; 2]; 2];

d_array(&arr, &mut d_arr, 1.0);

dbg!(&d_arr);
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
}
}
30 changes: 30 additions & 0 deletions integration/tests/examples/bad_add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#![feature(autodiff)]

pub fn add(a: i32, b: i32) -> i32 {
a + b
}

// This is a really bad adding function, its purpose is to fail in this
// example.
#[allow(dead_code)]
fn bad_add(a: i32, b: i32) -> i32 {
a - b
}

#[cfg(test)]
mod tests {
// Note this useful idiom: importing names from outer (for mod tests) scope.
use super::*;

#[test]
fn test_add() {
assert_eq!(add(1, 2), 3);
}

#[test]
fn test_bad_add() {
// This assert would fire and test will fail.
// Please note, that private functions can be tested too!
assert_eq!(bad_add(1, 2), 3);
}
}
25 changes: 25 additions & 0 deletions integration/tests/examples/box.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#![feature(autodiff)]

//#[autodiff(cos_box, Reverse, Active, Duplicated)]
#[autodiff(cos_box, Reverse, Duplicated, Active)]
fn sin(x: &Box<f32>) -> f32 {
f32::sin(**x)
}

fn main() {
let x = Box::<f32>::new(3.14);
let mut df_dx = Box::<f32>::new(0.0);
cos_box(&x, &mut df_dx, 1.0);

dbg!(&df_dx);

assert!(*df_dx == f32::cos(*x));
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
}
}
35 changes: 35 additions & 0 deletions integration/tests/examples/broken_matvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#![feature(autodiff)]


type Matrix = Vec<Vec<f32>>;
type Vector = Vec<f32>;

#[autodiff(d_matvec, Forward, Dual, Const, Dual)]
fn matvec(mat: &Matrix, vec: &Vector, out: &mut Vector) {
for i in 0..mat.len() - 1 {
for j in 0..mat[0].len() - 1 {
out[i] += mat[i][j] * vec[j];
}
}
}

fn main() {
let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
let inp = vec![1.0, 1.0];
let mut out = vec![0.0, 0.0];
let mut out_tang = vec![0.0, 1.0];

//matvec(&mat, &inp, &mut out);
d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang);

dbg!(&out);
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
}
}
38 changes: 38 additions & 0 deletions integration/tests/examples/enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#![feature(autodiff)]

enum Foo {
A(f32),
B(i32),
}

#[autodiff(d_bar, Reverse, Duplicated, Active)]
fn bar(x: &f32) -> f32 {
let val: Foo =
if *x > 0.0 {
Foo::A(*x)
} else {
Foo::B(12)
};

std::hint::black_box(&val);
match val {
Foo::A(f) => f * f,
Foo::B(_) => 4.0,
}
}

fn main() {
let x = 1.0;
let x2 = -1.0;
let mut dx = 0.0;
let mut dx2 = 0.0;
let out = bar(&x);
let dout = d_bar(&x, &mut dx, 1.0);
let dout2 = d_bar(&x2, &mut dx2, 1.0);
println!("x: {out}");
println!("dx: {dout}");
println!("dx2: {dout2}");
assert_eq!(dx, 2.0);
assert_eq!(dx2, 0.0);
}

54 changes: 54 additions & 0 deletions integration/tests/examples/enum2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#![feature(autodiff)]

#[derive(Debug, PartialEq)]
enum Foo {
A(f32),
B(i32),
}

#[autodiff(d_bar, Reverse, Duplicated, Active)]
fn bar(x: &Foo) -> f32 {








match x {
Foo::A(f) => f * f,
Foo::B(_) => 4.0,
}
}

fn main() {
let x = Foo::A(1.0);
let x2 = Foo::A(-1.0);
let x3 = Foo::B(1);
let x4 = Foo::B(1);
let x5 = Foo::A(1.0);
let mut dx = Foo::A(0.0);
let mut dx2 = Foo::A(0.0);
let mut dx3 = Foo::B(0);
let mut dx4 = Foo::A(0.0);
let mut dx5 = Foo::B(0);
let out = bar(&x);
let dout = d_bar(&x, &mut dx, 1.0);
let dout2 = d_bar(&x2, &mut dx2, 1.0);
let dout3 = d_bar(&x3, &mut dx3, 1.0);
let dout4 = d_bar(&x4, &mut dx4, 1.0);
let dout5 = d_bar(&x5, &mut dx5, 1.0);
println!("x: {out}");
println!("dx: {dout}");
println!("dx2: {dout2}");
println!("dx3: {dout3}");
println!("dx4: {dout4}");
println!("dx5: {dout5}");
assert_eq!(dx, Foo::A(2.0));
assert_eq!(dx2, Foo::A(-2.0));
assert_eq!(dx3, Foo::B(0));
assert_eq!(dx4, Foo::A(0.0));
assert_eq!(dx5, Foo::B(0));
}

26 changes: 26 additions & 0 deletions integration/tests/examples/enum3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![feature(autodiff)]

struct Bar {
x: f32,
y: bool,
}

#[autodiff(df, Reverse, Duplicated, Active)]
fn f(x: &Bar) -> f32 {
if x.y {
x.x * x.x
} else {
4.0
}
}

fn main() {
let a = Bar { x: 3.0, y: true };
let mut da_good = Bar { x: 0.0, y: true };
let mut da_bad = Bar { x: 0.0, y: false };
let dx = df(&a, &mut da_good, 1.0);
let dx2 = df(&a, &mut da_bad, 1.0);
println!("good: {:?}", da_good.x);
println!("bad: {:?}", da_bad.x);
println!("bool values good/bad: {:?} {:?}", da_good.y, da_bad.y);
}
26 changes: 26 additions & 0 deletions integration/tests/examples/enum4.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![feature(autodiff)]

struct Bar {
x: f32,
y: bool,
}

#[autodiff(df, Reverse, Duplicated, Duplicated, Active)]
fn f(x: &Bar, val: bool) -> f32 {
if val {
x.x
} else {
4.0
}
}

fn main() {
//let a = Bar { x: 1.0, y: true };
//let mut da_good = Bar { x: 0.0, y: true };
//let mut da_bad = Bar { x: 0.0, y: false };
//let dx = df(&a, &mut da_good, 1.0);
//let dx2 = df(&a, &mut da_bad, 1.0);
//println!("good: {:?}", da_good.x);
//println!("bad: {:?}", da_bad.x);
//println!("bool values good/bad: {:?} {:?}", da_good.y, da_bad.y);
}
11 changes: 11 additions & 0 deletions integration/tests/examples/foo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#![feature(autodiff)]

#[autodiff(df, Forward, Dual, Dual)]
fn f(x: &[f32]) -> f32 { x[0] * x[0] + x[1] * x[0] }

fn main() {
let x = [2.0, 2.0];
let dx = [1.0, 0.0];
let (y, dy) = df(&x, &dx);
assert_eq!(dy, 2.0 * x[0] + x[1]);
}
29 changes: 29 additions & 0 deletions integration/tests/examples/hessian_sin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#![feature(rustc_attrs)]
#![feature(autodiff)]

fn sin(x: &Vec<f32>, y: &mut f32) {
*y = x.into_iter().map(|x| f32::sin(*x)).sum()
}

#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)]
fn jac(x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32);

#[autodiff(jac, Forward, Const, Dual, Const, Const, Const)]
fn hessian(x: &Vec<f32>, y_x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32);

fn main() {
let inp = vec![3.1415 / 2., 1.0, 0.5];
let mut d_inp = vec![0.0, 0.0, 0.0];
let mut y = 0.0;
let tang = vec![1.0, 0.0, 0.0];
hessian(&inp, &tang, &mut d_inp, &mut y, &1.0);
dbg!(&d_inp);
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
}
}
17 changes: 17 additions & 0 deletions integration/tests/examples/jed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#![feature(autodiff)]

#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x.powi(2)
}

fn main() {
let x = 3.0;
let output = square(&x);
println!("{output}");

let mut df_dx = 0.0;
d_square(&x, &mut df_dx, 1.0);
println!("df_dx: {:?}", df_dx);
}

26 changes: 26 additions & 0 deletions integration/tests/examples/ndarray.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![feature(autodiff)]

use ndarray::Array1;

//#[autodiff(d_collect, Reverse, Active)]
#[autodiff(d_collect, Reverse, Duplicated, Active)]
fn collect(x: &Array1<f32>) -> f32 {
x[0]
}

fn main() {
let a = Array1::zeros(19);
let mut d_a = Array1::zeros(19);

d_collect(&a, &mut d_a, 1.0);

dbg!(&d_a);
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
}
}
Loading

0 comments on commit 08352cd

Please sign in to comment.