forked from yalue/onnxruntime_go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonnxruntime_wrapper.c
76 lines (63 loc) · 2.22 KB
/
onnxruntime_wrapper.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include "onnxruntime_wrapper.h"
// Function table for the ONNX Runtime C API, filled in by SetAPIFromBase.
// It has static storage duration, so it is zero-initialized (NULL) until
// then; every other wrapper in this file dereferences it unconditionally.
static const OrtApi *ort_api;
// Fetches the OrtApi table matching the ORT_API_VERSION this wrapper was
// compiled against and stores it for use by every other function here.
// Must succeed before any other wrapper is called.
// Returns 0 on success, 1 if api_base is NULL, or 2 if GetApi returned NULL
// (ORT does this when the loaded runtime does not support the requested
// API version).
int SetAPIFromBase(OrtApiBase *api_base) {
  if (api_base == NULL) {
    return 1;
  }
  ort_api = api_base->GetApi(ORT_API_VERSION);
  return (ort_api == NULL) ? 2 : 0;
}
// Frees an OrtStatus returned by one of the ORT API wrappers.
// Requires SetAPIFromBase to have succeeded.
void ReleaseOrtStatus(OrtStatus *status) {
  const OrtApi *api = ort_api;
  api->ReleaseStatus(status);
}
// Creates an OrtEnv identified by `name` whose minimum logging severity is
// ORT_LOGGING_LEVEL_ERROR. On success *env receives the environment and the
// result is NULL; otherwise the returned status must be freed with
// ReleaseOrtStatus.
OrtStatus *CreateOrtEnv(char *name, OrtEnv **env) {
  const OrtLoggingLevel min_severity = ORT_LOGGING_LEVEL_ERROR;
  return ort_api->CreateEnv(min_severity, name, env);
}
// Frees an environment created by CreateOrtEnv.
void ReleaseOrtEnv(OrtEnv *env) {
  const OrtApi *api = ort_api;
  api->ReleaseEnv(env);
}
// Creates an OrtMemoryInfo describing CPU memory, using the arena allocator
// type and the default memory type. *mem_info receives the result; the
// returned status is NULL on success.
OrtStatus *CreateOrtMemoryInfo(OrtMemoryInfo **mem_info) {
  const OrtAllocatorType allocator_kind = OrtArenaAllocator;
  const OrtMemType memory_kind = OrtMemTypeDefault;
  return ort_api->CreateCpuMemoryInfo(allocator_kind, memory_kind, mem_info);
}
// Frees memory info created by CreateOrtMemoryInfo.
void ReleaseOrtMemoryInfo(OrtMemoryInfo *info) {
  const OrtApi *api = ort_api;
  api->ReleaseMemoryInfo(info);
}
// Returns the message carried by `status`. A NULL status (ORT's "success")
// yields a fixed placeholder string without touching the API table. The
// returned pointer is owned by ORT/the status object, not the caller.
const char *GetErrorMessage(OrtStatus *status) {
  return (status == NULL) ? "No error (NULL status)"
                          : ort_api->GetErrorMessage(status);
}
// Creates an inference session from a serialized model held in memory
// (model_data, model_data_length bytes), using default session options.
// On success *out receives the session and NULL is returned; otherwise the
// returned status describes the failure.
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
  OrtEnv *env, OrtSession **out) {
  OrtSessionOptions *opts = NULL;
  OrtStatus *result = ort_api->CreateSessionOptions(&opts);
  if (result != NULL) {
    return result;
  }
  result = ort_api->CreateSessionFromArray(env, model_data, model_data_length,
    opts, out);
  // ORT copies the session options into the session it builds, so the
  // options object can be released immediately, whether or not session
  // creation succeeded.
  ort_api->ReleaseSessionOptions(opts);
  return result;
}
// Runs `session` once with default run options.
//   inputs/input_names:   input_count input tensors and their names.
//   outputs/output_names: output_count slots and the output names to fetch.
//     ORT fills each outputs[i]; a NULL slot makes ORT allocate the value.
// Returns NULL on success, or a status describing the failure (including
// ORT_INVALID_ARGUMENT for a negative count).
OrtStatus *RunOrtSession(OrtSession *session,
  OrtValue **inputs, char **input_names, int input_count,
  OrtValue **outputs, char **output_names, int output_count) {
  // OrtApi::Run takes size_t counts. A negative int would silently convert
  // to a huge count and make ORT read far past the caller's arrays.
  if ((input_count < 0) || (output_count < 0)) {
    return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
      "RunOrtSession: input and output counts must be non-negative");
  }
  return ort_api->Run(session, NULL, (const char* const*) input_names,
    (const OrtValue* const*) inputs, (size_t) input_count,
    (const char* const*) output_names, (size_t) output_count, outputs);
}
// Frees a session created by CreateSession.
void ReleaseOrtSession(OrtSession *session) {
  const OrtApi *api = ort_api;
  api->ReleaseSession(session);
}
// Frees an OrtValue (e.g. a tensor from CreateOrtTensorWithShape or an
// output allocated by RunOrtSession).
void ReleaseOrtValue(OrtValue *value) {
  const OrtApi *api = ort_api;
  api->ReleaseValue(value);
}
// Wraps a caller-owned buffer as an ORT tensor without copying it.
//   data, data_size: the element buffer and its length in bytes. ORT does not
//     take ownership, so the buffer must outlive the returned OrtValue;
//     ReleaseOrtValue does not free it.
//   shape, shape_size: the dimension array and the number of dimensions.
//   mem_info: where `data` lives (e.g. from CreateOrtMemoryInfo).
//   dtype: element type of the tensor.
//   out: receives the new OrtValue on success.
// Returns NULL on success, or a status describing the failure (including
// ORT_INVALID_ARGUMENT for a negative shape_size).
OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
  int64_t *shape, int64_t shape_size, OrtMemoryInfo *mem_info,
  ONNXTensorElementDataType dtype, OrtValue **out) {
  // The ORT API takes the dimension count as size_t. A negative int64_t
  // would convert to an enormous count and make ORT read past `shape`.
  if (shape_size < 0) {
    return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
      "CreateOrtTensorWithShape: shape_size must be non-negative");
  }
  return ort_api->CreateTensorWithDataAsOrtValue(mem_info, data, data_size,
    shape, (size_t) shape_size, dtype, out);
}