-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgrad4.cxx
109 lines (86 loc) · 2.51 KB
/
grad4.cxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <dirent.h>

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include <json.hpp>
#include <apex/autodiff_codegen.hxx>
// Returns the part of `filename` that follows its last '.' character.
// A name with no '.' at all has no extension and yields an empty string
// (previously the whole name was returned, because npos + 1 wraps to 0).
// A name ending in '.' also yields an empty string.
inline std::string get_extension(const std::string& filename) {
  const std::string::size_type dot = filename.find_last_of('.');
  if(std::string::npos == dot)
    return std::string();
  return filename.substr(dot + 1);
}
// Reports whether the extension of `filename`, as computed by get_extension,
// is exactly equal to `ext` (case-sensitive comparison, no leading dot).
inline bool match_extension(const char* filename, const char* ext) {
  const std::string actual = get_extension(filename);
  return actual == ext;
}
// Bring nlohmann::json and apex::sq into scope under their short names.
using nlohmann::json;
using apex::sq;
// Plain aggregate of three doubles. In this file it carries both the
// evaluation point passed to the generated functions and the gradient
// they return.
struct vec3_t {
  double x;
  double y;
  double z;
};
// Compile-time (meta) list of the names of every function generated by
// gen_functions. Each name is appended as its functions are emitted, and
// eval() iterates this list to build its name-dispatch chain.
@meta std::vector<std::string> func_names;
// Macro that reads `filename` as JSON and, for each key/value item in it,
// emits two extern "C" functions into the expanding scope:
//   f_<key>(vec3_t)    - evaluates the item's value as an expression in
//                        the local variables x, y and z.
//   grad_<key>(vec3_t) - returns apex::autodiff_grad for the same
//                        expression string.
// Each key is also appended to func_names.
// NOTE(review): there is no check that the file opened successfully or
// that each value is a string; presumably a failure surfaces as an error
// from the JSON library while compiling - confirm.
@macro void gen_functions(const char* filename) {
// Open this file at compile time and parse as JSON.
@meta std::ifstream json_file(filename);
@meta json j;
@meta json_file>> j;
@meta for(auto& item : j.items()) {
// For each item in the json, the key becomes the function name suffix and
// the value is the expression text.
@meta std::string name = item.key();
@meta std::string f = item.value();
// Log what is being generated and which file it came from.
@meta std::cout<< "Injecting '"<< name<< "' : '"<< f<< "' from "<<
filename<< "\n";
// Generate a function from the expression.
extern "C" double @("f_" + name)(vec3_t v) {
// Unpack the components so the expression text can refer to x, y and z.
double x = v.x, y = v.y, z = v.z;
return @expression(f);
}
// Generate a function to return the gradient.
extern "C" vec3_t @("grad_" + name)(vec3_t v) {
return apex::autodiff_grad(f.c_str(), v);
}
// Remember the name so eval() can dispatch to the functions just emitted.
@meta func_names.push_back(name);
}
}
// Use Circle like a build system:
// At compile time, scan the current directory and expand gen_functions for
// every entry whose name has the "json" extension.
// Open the current directory.
// NOTE(review): the result of opendir is not checked before it is passed
// to readdir.
@meta DIR* dir = opendir(".");
// Loop over all files in the current directory.
@meta while(dirent* ent = readdir(dir)) {
// Match .json files.
@meta if(match_extension(ent->d_name, "json")) {
// Generate functions for all entries in this json file.
@macro gen_functions(ent->d_name);
}
}
// Release the directory handle once every entry has been visited.
@meta closedir(dir);
// Evaluates the generated function called `name` at point `v`.
// Returns a pair holding the result of f_<name>(v) and of grad_<name>(v).
// If `name` matches none of the recorded function names, prints
// "Unknown function <name>" and exits the process with status 1.
std::pair<double, vec3_t> eval(const char* name, vec3_t v) {
// Compile-time loop: one strcmp test is emitted per entry in func_names.
@meta for(const std::string& f : func_names) {
if(!strcmp(name, @string(f))) {
return {
@("f_" + f)(v),
@("grad_" + f)(v)
};
}
}
// No generated function matched the requested name.
printf("Unknown function %s\n", name);
exit(1);
}
// Writes the expected command-line form to stdout, then terminates the
// process with exit status 1. Never returns to the caller.
// NOTE(review): the message names "grad3"; confirm that this matches the
// name of the executable built from this file.
void print_usage() {
  fputs(" Usage: grad3 name x y z\n", stdout);
  exit(1);
}
// Program entry point.
// Expects exactly four arguments: a function name followed by the x, y and
// z coordinates. Evaluates the named function and its gradient at that
// point via eval() and prints both. Any malformed invocation ends in
// print_usage(), which exits with status 1.
int main(int argc, char** argv) {
  if(5 != argc)
    print_usage();

  const char* f = argv[1];

  // Parse the three coordinates strictly. atof would silently turn text
  // such as "abc" into 0.0, so use strtod and require that the whole
  // argument was consumed.
  double coords[3] = { 0.0, 0.0, 0.0 };
  for(int i = 0; i < 3; ++i) {
    const char* text = argv[2 + i];
    char* end = nullptr;
    coords[i] = std::strtod(text, &end);
    if(end == text || '\0' != *end) {
      std::fprintf(stderr, "Invalid number '%s'\n", text);
      print_usage();
    }
  }

  vec3_t v { coords[0], coords[1], coords[2] };
  auto result = eval(f, v);
  double val = result.first;
  vec3_t grad = result.second;

  std::printf(" f: %f\n", val);
  std::printf(" grad: { %f, %f, %f }\n", grad.x, grad.y, grad.z);
  return 0;
}