-
Notifications
You must be signed in to change notification settings - Fork 13.5k
adding run-make test to autodiff #142444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
adding run-make test to autodiff #142444
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Autodiff Type-Trees Type Analysis Tests | ||
|
||
This directory contains run-make tests for the autodiff type-trees type analysis functionality. These tests verify that the autodiff compiler correctly analyzes and tracks type information for different Rust types during automatic differentiation. | ||
|
||
## What These Tests Do | ||
|
||
Each test compiles a simple Rust function with the `#[autodiff_reverse]` attribute and verifies that the compiler: | ||
|
||
1. **Correctly identifies type information** in the generated LLVM IR | ||
2. **Tracks type annotations** for variables and operations | ||
3. **Preserves type context** through the autodiff transformation process | ||
|
||
The tests capture the stdout from the autodiff compiler (which contains type analysis information) and verify it matches expected patterns using FileCheck. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff_reverse; | ||
|
||
#[autodiff_reverse(d_square, Duplicated, Active)] | ||
#[no_mangle] | ||
fn callee(x: &[f32; 3]) -> f32 { | ||
x[0] * x[0] + x[1] * x[1] + x[2] * x[2] | ||
} | ||
|
||
fn main() { | ||
let x = [1.0f32, 2.0, 3.0]; | ||
let mut df_dx = [0.0f32; 3]; | ||
let out = callee(&x); | ||
let out_ = d_square(&x, &mut df_dx, 1.0); | ||
assert_eq!(out, out_); | ||
assert_eq!(2.0, df_dx[0]); | ||
assert_eq!(4.0, df_dx[1]); | ||
assert_eq!(6.0, df_dx[2]); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
//@ needs-enzyme | ||
//@ ignore-cross-compile | ||
|
||
use std::fs; | ||
|
||
use run_make_support::{llvm_filecheck, rfs, rustc}; | ||
|
||
fn main() { | ||
// Compile the Rust file with the required flags, capturing both stdout and stderr | ||
let output = rustc() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: by default, I've not looked at the exact tests, but you may also wish to restrict which platforms these run-make tests run against. For instance, apple and windows Tier 1 targets? Or i686 msvc, etc. |
||
.input("array.rs") | ||
.arg("-Zautodiff=Enable,PrintTAFn=callee") | ||
.arg("-Zautodiff=NoPostopt") | ||
.opt_level("3") | ||
.arg("-Clto=fat") | ||
.arg("-g") | ||
.run(); | ||
|
||
let stdout = output.stdout_utf8(); | ||
let stderr = output.stderr_utf8(); | ||
|
||
// Write the outputs to files | ||
rfs::write("array.stdout", stdout); | ||
rfs::write("array.stderr", stderr); | ||
|
||
// Run FileCheck on the stdout using the check file | ||
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("array.stdout")).run(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer} | ||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff_reverse; | ||
|
||
#[autodiff_reverse(d_square, Duplicated, Active)] | ||
#[no_mangle] | ||
fn callee(x: &[[[f32; 2]; 2]; 2]) -> f32 { | ||
let mut sum = 0.0; | ||
for i in 0..2 { | ||
for j in 0..2 { | ||
for k in 0..2 { | ||
sum += x[i][j][k] * x[i][j][k]; | ||
} | ||
} | ||
} | ||
sum | ||
} | ||
|
||
fn main() { | ||
let x = [[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]; | ||
let mut df_dx = [[[0.0f32; 2]; 2]; 2]; | ||
let out = callee(&x); | ||
let out_ = d_square(&x, &mut df_dx, 1.0); | ||
assert_eq!(out, out_); | ||
for i in 0..2 { | ||
for j in 0..2 { | ||
for k in 0..2 { | ||
assert_eq!(df_dx[i][j][k], 2.0 * x[i][j][k]); | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
//@ needs-enzyme | ||
//@ ignore-cross-compile | ||
|
||
use std::fs; | ||
|
||
use run_make_support::{llvm_filecheck, rfs, rustc}; | ||
|
||
fn main() { | ||
// Compile the Rust file with the required flags, capturing both stdout and stderr | ||
let output = rustc() | ||
.input("array3d.rs") | ||
.arg("-Zautodiff=Enable,PrintTAFn=callee") | ||
.arg("-Zautodiff=NoPostopt") | ||
.opt_level("3") | ||
.arg("-Clto=fat") | ||
.arg("-g") | ||
.run(); | ||
|
||
let stdout = output.stdout_utf8(); | ||
let stderr = output.stderr_utf8(); | ||
|
||
// Write the outputs to files | ||
rfs::write("array3d.stdout", stdout); | ||
rfs::write("array3d.stderr", stderr); | ||
|
||
// Run FileCheck on the stdout using the check file | ||
llvm_filecheck().patterns("array3d.check").stdin_buf(rfs::read("array3d.stdout")).run(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff_reverse; | ||
|
||
#[autodiff_reverse(d_square, Duplicated, Active)] | ||
#[no_mangle] | ||
fn callee(x: &Box<f32>) -> f32 { | ||
**x * **x | ||
} | ||
|
||
fn main() { | ||
let x = Box::new(7.0f32); | ||
let mut df_dx = Box::new(0.0f32); | ||
let out = callee(&x); | ||
let out_ = d_square(&x, &mut df_dx, 1.0); | ||
assert_eq!(out, out_); | ||
assert_eq!(14.0, *df_dx); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
//@ needs-enzyme | ||
//@ ignore-cross-compile | ||
|
||
use std::fs; | ||
|
||
use run_make_support::{llvm_filecheck, rfs, rustc}; | ||
|
||
fn main() { | ||
// Compile the Rust file with the required flags, capturing both stdout and stderr | ||
let output = rustc() | ||
.input("box.rs") | ||
.arg("-Zautodiff=Enable,PrintTAFn=callee") | ||
.arg("-Zautodiff=NoPostopt") | ||
.opt_level("3") | ||
.arg("-Clto=fat") | ||
.arg("-g") | ||
.run(); | ||
|
||
let stdout = output.stdout_utf8(); | ||
let stderr = output.stderr_utf8(); | ||
|
||
// Write the outputs to files | ||
rfs::write("box.stdout", stdout); | ||
rfs::write("box.stderr", stderr); | ||
|
||
// Run FileCheck on the stdout using the check file | ||
llvm_filecheck().patterns("box.check").stdin_buf(rfs::read("box.stdout")).run(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff_reverse; | ||
|
||
#[autodiff_reverse(d_square, Duplicated, Active)] | ||
#[no_mangle] | ||
fn callee(x: *const f32) -> f32 { | ||
unsafe { *x * *x } | ||
} | ||
|
||
fn main() { | ||
let x: f32 = 7.0; | ||
let out = callee(&x as *const f32); | ||
let mut df_dx: f32 = 0.0; | ||
let out_ = d_square(&x as *const f32, &mut df_dx, 1.0); | ||
assert_eq!(out, out_); | ||
assert_eq!(14.0, df_dx); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
//@ needs-enzyme | ||
//@ ignore-cross-compile | ||
|
||
use std::fs; | ||
|
||
use run_make_support::{llvm_filecheck, rfs, rustc}; | ||
|
||
fn main() { | ||
// Compile the Rust file with the required flags, capturing both stdout and stderr | ||
let output = rustc() | ||
.input("const_pointer.rs") | ||
.arg("-Zautodiff=Enable,PrintTAFn=callee") | ||
.arg("-Zautodiff=NoPostopt") | ||
.opt_level("3") | ||
.arg("-Clto=fat") | ||
.arg("-g") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: is the exact level of debuginfo important for these tests? If it is, consider explicitly spelling it out as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enzyme has a parser for debug info metadata. Enzyme should "translate" the debug info into "Enzyme Typetrees", which it will then use later for computing derivatives. In reality Enzyme's debug metadata parser was untested for too long and doesn't seem to do anything (based on the output of these tests). Still, I would hard-code the debug info level since it (in theory) could affect the test output. Also, these tests are meant to capture the status-quo. In the second half Karan will generate TypeTrees directly from MIR, so it's good to have a before/after comparison with the test cases. |
||
.run(); | ||
|
||
let stdout = output.stdout_utf8(); | ||
let stderr = output.stderr_utf8(); | ||
|
||
// Write the outputs to files | ||
rfs::write("const_pointer.stdout", stdout); | ||
rfs::write("const_pointer.stderr", stderr); | ||
|
||
// Run FileCheck on the stdout using the check file | ||
llvm_filecheck() | ||
.patterns("const_pointer.check") | ||
.stdin_buf(rfs::read("const_pointer.stdout")) | ||
.run(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
|
||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} | ||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{} | ||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float} | ||
|
||
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float} | ||
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#![feature(autodiff)] | ||
|
||
use std::autodiff::autodiff_reverse; | ||
|
||
#[autodiff_reverse(d_square, Duplicated, Active)] | ||
#[no_mangle] | ||
fn callee(x: &f32) -> f32 { | ||
*x * *x | ||
} | ||
|
||
fn main() { | ||
let x: f32 = 7.0; | ||
let mut df_dx: f32 = 0.0; | ||
let out = callee(&x); | ||
let out_ = d_square(&x, &mut df_dx, 1.0); | ||
|
||
assert_eq!(out, out_); | ||
assert_eq!(14.0, df_dx); | ||
} |
Uh oh!
There was an error while loading. Please reload this page.