-
Notifications
You must be signed in to change notification settings - Fork 337
/
Copy pathExecutionSession.hpp
133 lines (110 loc) · 4.86 KB
/
ExecutionSession.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===--------- ExecutionSession.hpp - ExecutionSession Declaration --------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declarations of ExecutionSession class, which helps C++
// programs interact with compiled binary model libraries.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_EXECUTION_SESSION_H
#define ONNX_MLIR_EXECUTION_SESSION_H
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "OnnxMlirRuntime.h"
// The handle type for a loaded shared library is platform dependent.
// On Windows the LLVM wrapper class llvm::sys::DynamicLibrary is used; on every
// other platform the handle is an opaque void *, so that the Runtime component
// does not depend on LLVM outside of Windows.
#if defined(_WIN32)
#include "llvm/Support/DynamicLibrary.h"
using DynamicLibraryHandleType = llvm::sys::DynamicLibrary;
#else
using DynamicLibraryHandleType = void *;
#endif
namespace onnx_mlir {
// Pointer to a function that takes an OMTensorList * and returns an
// OMTensorList *. Type of ExecutionSession::_entryPointFunc.
using entryPointFuncType = OMTensorList *(*)(OMTensorList *);
// Pointer to a function that takes an int64_t * and returns a const char **.
// Type of ExecutionSession::_queryEntryPointsFunc.
// NOTE(review): presumably the int64_t * receives the number of entry points
// and the result is the array of their names -- confirm against the model ABI.
using queryEntryPointsFuncType = const char **(*)(int64_t *);
// Pointer to a function that takes a const char * and returns a const char *.
// Type of ExecutionSession::_inputSignatureFunc and _outputSignatureFunc.
using signatureFuncType = const char *(*)(const char *);
// Owning pointer to an OMTensor whose deleter has the type of a pointer to
// omTensorDestroy. Because the deleter is a function pointer, it has to be
// passed explicitly at construction, e.g.
// OMTensorUniquePtr(tensor, omTensorDestroy).
using OMTensorUniquePtr = std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>;
/* ExecutionSession
 * Class that supports executing compiled models.
 *
 * When the execution session does not work for known reasons, this class will
 * throw std::runtime_error errors. Errno info will provide further info about
 * the specific error that was raised.
 *
 * EFAULT when it could not load the library or a needed symbol was not found.
 * EINVAL when it expected an entry point prior to executing a specific
 *        function.
 * EPERM  when the model executed on a machine without a compatible
 *        hardware/specialized accelerator.
 *
 * NOTE(review): a destructor is declared but no copy operations are, so a
 * copy constructor is still implicitly generated. If the destructor releases
 * the shared library handle, a copied session would presumably release it a
 * second time. Avoid copying sessions; consider deleting the copy operations
 * once it is confirmed that no caller relies on them.
 */
class ExecutionSession {
public:
  // Create an execution session using the model given in sharedLibPath.
  // This path must point to the actual file, local directory is not searched.
  //   sharedLibPath:     path to the compiled model library.
  //   tag:               tag used to compile the model (see the `tag` member).
  //   defaultEntryPoint: when false, setEntryPoint() has to be called before
  //                      running the session or querying signatures.
  ExecutionSession(std::string sharedLibPath, std::string tag = "",
      bool defaultEntryPoint = true);
  ~ExecutionSession();

  // Get a NULL-terminated array of entry point names.
  // For example {"run_addition", "run_subtraction", NULL}
  // In order to get the number of entry points, pass an integer pointer to the
  // function.
  // NOTE(review): the return type is a pointer to std::string, which cannot
  // literally hold a NULL element; confirm how the end of the array is marked
  // and prefer numOfEntryPoints to bound any iteration.
  const std::string *queryEntryPoints(int64_t *numOfEntryPoints) const;

  // Set entry point for this session.
  // Call this before running the session or querying signatures if
  // defaultEntryPoint is false or there are multiple entry points in the model.
  void setEntryPoint(const std::string &entryPointName);

  // Give access to the handle of the loaded shared library. The handle is
  // returned by non-const reference, so callers can modify it.
  DynamicLibraryHandleType &getSharedLibraryHandle() {
    return _sharedLibraryHandle;
  }

  // Run with inputs and outputs held by OMTensorUniquePtr. The input vector is
  // taken by value, so callers have to move it in.
  // Use custom deleter since forward declared OMTensor hides destructor
  std::vector<OMTensorUniquePtr> run(std::vector<OMTensorUniquePtr>);

  // Run using public interface. Explicit calls are needed to free tensor &
  // tensor lists.
  OMTensorList *run(OMTensorList *input);

  // Get input and output signature as a Json string. For example for mnist:
  // `[ { "type" : "f32" , "dims" : [1 , 1 , 28 , 28] , "name" : "image" } ]`
  const std::string inputSignature() const;
  const std::string outputSignature() const;

protected:
  // Constructor that build the object without initialization (for use by
  // subclass only).
  ExecutionSession() = default;

  // Initialization of library. Called by public constructor, or by subclasses.
  // Takes the same arguments as the public constructor.
  void Init(std::string sharedLibPath, std::string tag, bool defaultEntryPoint);

  // Error reporting processing when throwing runtime errors. Set errno as
  // appropriate. Each helper returns a std::string (presumably the message of
  // the error being raised).
  std::string reportInitError() const;
  std::string reportLibraryOpeningError(const std::string &libraryName) const;
  std::string reportSymbolLoadingError(const std::string &symbolName) const;
  std::string reportUndefinedEntryPointIn(
      const std::string &functionName) const;
  std::string reportErrnoError() const;
  std::string reportCompilerError(const std::string &errorMessage) const;

  // Track if Init was called or not.
  bool isInitialized = false;

  // Handler to the shared library file being loaded.
  // Value-initialized so that a session built through the protected default
  // constructor never holds an indeterminate handle: the void * variant
  // becomes nullptr, and the llvm::sys::DynamicLibrary variant is
  // default-constructed exactly as before.
  DynamicLibraryHandleType _sharedLibraryHandle{};

  // Tag used to compile the model. By default, it is the model filename without
  // extension.
  std::string tag;

  // Entry point function.
  std::string _entryPointName;
  entryPointFuncType _entryPointFunc = nullptr;

  // Query entry point function.
  const std::string _queryEntryPointsName = "omQueryEntryPoints";
  queryEntryPointsFuncType _queryEntryPointsFunc = nullptr;

  // Entry point for input/output signatures
  const std::string _inputSignatureName = "omInputSignature";
  const std::string _outputSignatureName = "omOutputSignature";
  signatureFuncType _inputSignatureFunc = nullptr;
  signatureFuncType _outputSignatureFunc = nullptr;
};
} // namespace onnx_mlir
#endif