Skip to content

adding run-make test to autodiff #142444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ struct LLVMRustSanitizerOptions {
#ifdef ENZYME
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
/* augmentPassBuilder */ bool);

extern "C" {
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
}
#endif

extern "C" LLVMRustResult LLVMRustOptimize(
Expand Down Expand Up @@ -1069,6 +1073,15 @@ extern "C" LLVMRustResult LLVMRustOptimize(
return LLVMRustResult::Failure;
}

// Check if PrintTAFn was used and add type analysis pass if needed
if (!EnzymeFunctionToAnalyze.empty()) {
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}
}

if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
Expand Down
13 changes: 13 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Autodiff Type-Trees Type Analysis Tests

This directory contains run-make tests for the autodiff type-trees type analysis functionality. These tests verify that the autodiff compiler correctly analyzes and tracks type information for different Rust types during automatic differentiation.

## What These Tests Do

Each test compiles a simple Rust function with the `#[autodiff_reverse]` attribute and verifies that the compiler:

1. **Correctly identifies type information** in the generated LLVM IR
2. **Tracks type annotations** for variables and operations
3. **Preserves type context** through the autodiff transformation process

The tests capture the stdout from the autodiff compiler (which contains type analysis information) and verify it matches expected patterns using FileCheck.
26 changes: 26 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/array.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
20 changes: 20 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &[f32; 3]) -> f32 {
x[0] * x[0] + x[1] * x[1] + x[2] * x[2]
}

fn main() {
let x = [1.0f32, 2.0, 3.0];
let mut df_dx = [0.0f32; 3];
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(2.0, df_dx[0]);
assert_eq!(4.0, df_dx[1]);
assert_eq!(6.0, df_dx[2]);
}
28 changes: 28 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ needs-enzyme
//@ ignore-cross-compile

use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
Copy link
Member

@jieyouxu jieyouxu Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: by default, run-make test's rustc() will try to cross-compile for targets tested under CI, which includes some Tier 2 targets that do not support std. You may wish to limit this further if you only wish to check this as a host test, e.g. //@ ignore-cross-compile.

I've not looked at the exact tests, but you may also wish to restrict which platforms these run-make tests run against. For instance, apple and windows Tier 1 targets? Or i686 msvc, etc.

.input("array.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
rfs::write("array.stdout", stdout);
rfs::write("array.stderr", stderr);

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("array.stdout")).run();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &[[[f32; 2]; 2]; 2]) -> f32 {
let mut sum = 0.0;
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
sum += x[i][j][k] * x[i][j][k];
}
}
}
sum
}

fn main() {
let x = [[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
let mut df_dx = [[[0.0f32; 2]; 2]; 2];
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
assert_eq!(df_dx[i][j][k], 2.0 * x[i][j][k]);
}
}
}
}
28 changes: 28 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/array3d/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ needs-enzyme
//@ ignore-cross-compile

use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("array3d.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
rfs::write("array3d.stdout", stdout);
rfs::write("array3d.stderr", stderr);

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("array3d.check").stdin_buf(rfs::read("array3d.stdout")).run();
}
12 changes: 12 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/box.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
18 changes: 18 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/box.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &Box<f32>) -> f32 {
**x * **x
}

fn main() {
let x = Box::new(7.0f32);
let mut df_dx = Box::new(0.0f32);
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(14.0, *df_dx);
}
28 changes: 28 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/box/rmake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ needs-enzyme
//@ ignore-cross-compile

use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("box.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
rfs::write("box.stdout", stdout);
rfs::write("box.stderr", stderr);

// Run FileCheck on the stdout using the check file
llvm_filecheck().patterns("box.check").stdin_buf(rfs::read("box.stdout")).run();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: *const f32) -> f32 {
unsafe { *x * *x }
}

fn main() {
let x: f32 = 7.0;
let out = callee(&x as *const f32);
let mut df_dx: f32 = 0.0;
let out_ = d_square(&x as *const f32, &mut df_dx, 1.0);
assert_eq!(out, out_);
assert_eq!(14.0, df_dx);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//@ needs-enzyme
//@ ignore-cross-compile

use std::fs;

use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
// Compile the Rust file with the required flags, capturing both stdout and stderr
let output = rustc()
.input("const_pointer.rs")
.arg("-Zautodiff=Enable,PrintTAFn=callee")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: is the exact level of debuginfo important for these tests? If it is, consider explicitly spelling it out as debuginfo_level("2") with a comment, ignore this review comment otherwise.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enzyme has a parser for debug info metadata. Enzyme should "translate" the debug info into "Enzyme Typetrees", which it will then use later for computing derivatives. In reality Enzyme's debug metadata parser was untested for too long and doesn't seem to do anything (based on the output of these tests). Still, I would hard-code the debug info level since it (in theory) could affect the test output.

Also, these tests are meant to capture the status-quo. In the second half Karan will generate TypeTrees directly from MIR, so it's good to have a before/after comparison with the test cases.

.run();

let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();

// Write the outputs to files
rfs::write("const_pointer.stdout", stdout);
rfs::write("const_pointer.stderr", stderr);

// Run FileCheck on the stdout using the check file
llvm_filecheck()
.patterns("const_pointer.check")
.stdin_buf(rfs::read("const_pointer.stdout"))
.run();
}
12 changes: 12 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/f32/f32.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}

// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}

// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
19 changes: 19 additions & 0 deletions tests/run-make/autodiff/type-trees/type-analysis/f32/f32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn callee(x: &f32) -> f32 {
*x * *x
}

fn main() {
let x: f32 = 7.0;
let mut df_dx: f32 = 0.0;
let out = callee(&x);
let out_ = d_square(&x, &mut df_dx, 1.0);

assert_eq!(out, out_);
assert_eq!(14.0, df_dx);
}
Loading
Loading